diff --git a/compile.c b/compile.c index 068afb4..66aba7a 100644 --- a/compile.c +++ b/compile.c @@ -23,7 +23,6 @@ typedef ast_t* (*comprehension_body_t)(ast_t*, ast_t*); static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_depth, bool needs_incref); -static env_t *with_enum_scope(env_t *env, type_t *t); static CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); static CORD compile_string(env_t *env, ast_t *ast, CORD color); static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args); @@ -1908,21 +1907,6 @@ CORD compile_to_type(env_t *env, ast_t *ast, type_t *t) return code; } -env_t *with_enum_scope(env_t *env, type_t *t) -{ - if (t->tag != EnumType) return env; - env = fresh_scope(env); - env_t *ns_env = Match(t, EnumType)->env; - for (tag_t *tag = Match(t, EnumType)->tags; tag; tag = tag->next) { - if (get_binding(env, tag->name)) - continue; - binding_t *b = get_binding(ns_env, tag->name); - assert(b); - Table$str_set(env->locals, tag->name, b); - } - return env; -} - CORD compile_int_to_type(env_t *env, ast_t *ast, type_t *target) { if (ast->tag != Int) { @@ -4104,6 +4088,11 @@ CORD compile_function(env_t *env, CORD name_code, ast_t *ast, CORD *staticdefs) } env_t *body_scope = fresh_scope(env); + while (body_scope->namespace && body_scope->namespace->parent) { + body_scope->locals->fallback = body_scope->locals->fallback->fallback; + body_scope->namespace = body_scope->namespace->parent; + } + body_scope->deferred = NULL; body_scope->namespace = NULL; for (arg_ast_t *arg = args; arg; arg = arg->next) { @@ -4121,7 +4110,7 @@ CORD compile_function(env_t *env, CORD name_code, ast_t *ast, CORD *staticdefs) if (body_type->tag == AbortType) code_err(ast, "This function will always abort before it reaches the end, but it's declared as having a Void return. It should be declared as an Abort return instead."); } else { - if (body_type->tag != ReturnType) + if (body_type->tag != ReturnType && body_type->tag != AbortType) code_err(ast, "This function can reach the end without returning a %T value!", ret_t); } diff --git a/environment.c b/environment.c index b871bff..e09ef75 100644 --- a/environment.c +++ b/environment.c @@ -641,6 +641,14 @@ env_t *load_module_env(env_t *env, ast_t *ast) return module_env; } +env_t *global_scope(env_t *env) +{ + env_t *scope = new(env_t); + *scope = *env; + scope->locals = new(Table_t, .fallback=env->globals); + return scope; +} + env_t *namespace_scope(env_t *env) { env_t *scope = new(env_t); @@ -657,6 +665,25 @@ env_t *fresh_scope(env_t *env) return scope; } +env_t *with_enum_scope(env_t *env, type_t *t) +{ + while (t->tag == OptionalType) + t = Match(t, OptionalType)->type; + + if (t->tag != EnumType) return env; + env = fresh_scope(env); + env_t *ns_env = Match(t, EnumType)->env; + for (tag_t *tag = Match(t, EnumType)->tags; tag; tag = tag->next) { + if (get_binding(env, tag->name)) + continue; + binding_t *b = get_binding(ns_env, tag->name); + assert(b); + Table$str_set(env->locals, tag->name, b); + } + return env; +} + + env_t *for_scope(env_t *env, ast_t *ast) { auto for_ = Match(ast, For); diff --git a/environment.h b/environment.h index 4566d37..a6fabc8 100644 --- a/environment.h +++ b/environment.h @@ -59,9 +59,11 @@ env_t *new_compilation_unit(CORD libname); env_t *load_module_env(env_t *env, ast_t *ast); CORD namespace_prefix(env_t *env, namespace_t *ns); env_t *get_namespace_by_type(env_t *env, type_t *t); +env_t *global_scope(env_t *env); env_t *namespace_scope(env_t *env); env_t *fresh_scope(env_t *env); env_t *for_scope(env_t *env, ast_t *ast); +env_t *with_enum_scope(env_t *env, type_t *t); env_t *namespace_env(env_t *env, const char *namespace_name); __attribute__((format(printf, 4, 5))) _Noreturn void compiler_err(file_t *f, const char *start, const char *end, const char *fmt, ...); diff --git a/typecheck.c b/typecheck.c index 5f97b4d..9d3fab7 100644 --- a/typecheck.c +++ b/typecheck.c @@ -300,9 +300,9 @@ void bind_statement(env_t *env, ast_t *statement) auto def = Match(statement, FunctionDef); const char *name = Match(def->name, Var)->name; type_t *type = get_function_def_type(env, statement); - binding_t *clobber = get_binding(env, name); - if (clobber) - code_err(def->name, "A %T called '%s' has already been defined", clobber->type, name); + // binding_t *clobber = get_binding(env, name); + // if (clobber) + // code_err(def->name, "A %T called '%s' has already been defined", clobber->type, name); CORD code = CORD_all(namespace_prefix(env, env->namespace), name); set_binding(env, name, type, code); break; @@ -958,19 +958,8 @@ type_t *get_type(env_t *env, ast_t *ast) } case Return: { ast_t *val = Match(ast, Return)->value; - // Support unqualified enum return values: - if (env->fn_ret && env->fn_ret->tag == EnumType) { - env = fresh_scope(env); - auto enum_ = Match(env->fn_ret, EnumType); - env_t *ns_env = enum_->env; - for (tag_t *tag = enum_->tags; tag; tag = tag->next) { - if (get_binding(env, tag->name)) - continue; - binding_t *b = get_binding(ns_env, tag->name); - assert(b); - Table$str_set(env->locals, tag->name, b); - } - } + if (env->fn_ret) + env = with_enum_scope(env, env->fn_ret); return Type(ReturnType, .ret=(val ? get_type(env, val) : Type(VoidType))); } case Stop: case Skip: {