From be87d8169d98eb358891d155ce84cff311ec27c2 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sun, 9 Feb 2025 13:57:54 -0500 Subject: [PATCH] Convert the logic for finding closed variables to a more pure functional style with fewer side effects --- compile.c | 379 ++++++++++++++++++++++++++++++++++++++++---------- environment.c | 19 +-- environment.h | 9 +- repl.c | 6 - typecheck.c | 6 +- 5 files changed, 311 insertions(+), 108 deletions(-) diff --git a/compile.c b/compile.c index 9c1a16f..ecd03b8 100644 --- a/compile.c +++ b/compile.c @@ -179,40 +179,291 @@ CORD compile_maybe_incref(env_t *env, ast_t *ast, type_t *t) return compile_to_type(env, ast, t); } - -static Table_t *get_closed_vars(env_t *env, ast_t *lambda_ast) +static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t *env, ast_t *ast) +{ + if (ast == NULL) + return; + + switch (ast->tag) { + case Var: { + binding_t *b = get_binding(enclosing_scope, Match(ast, Var)->name); + if (b) { + binding_t *shadow = get_binding(env, Match(ast, Var)->name); + if (!shadow || shadow == b) + Table$str_set(closed_vars, Match(ast, Var)->name, b); + } + break; + } + case TextJoin: { + for (ast_list_t *child = Match(ast, TextJoin)->children; child; child = child->next) + add_closed_vars(closed_vars, enclosing_scope, env, child->ast); + break; + } + case PrintStatement: { + for (ast_list_t *child = Match(ast, PrintStatement)->to_print; child; child = child->next) + add_closed_vars(closed_vars, enclosing_scope, env, child->ast); + break; + } + case Declare: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Declare)->value); + bind_statement(env, ast); + break; + } + case Assign: { + for (ast_list_t *target = Match(ast, Assign)->targets; target; target = target->next) + add_closed_vars(closed_vars, enclosing_scope, env, target->ast); + for (ast_list_t *value = Match(ast, Assign)->values; value; value = value->next) + add_closed_vars(closed_vars, enclosing_scope, env, value->ast); + break; + } + case BinaryOp: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->lhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->rhs); + break; + } + case UpdateAssign: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->lhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->rhs); + break; + } + case Not: case Negative: case HeapAllocate: case StackReference: case Mutexed: { + // UNSAFE: + ast_t *value = ast->__data.Not.value; + // END UNSAFE + add_closed_vars(closed_vars, enclosing_scope, env, value); + break; + } + case Holding: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Holding)->mutexed); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Holding)->body); + break; + } + case Min: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->lhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->rhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->key); + break; + } + case Max: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->lhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->rhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->key); + break; + } + case Array: { + for (ast_list_t *item = Match(ast, Array)->items; item; item = item->next) + add_closed_vars(closed_vars, enclosing_scope, env, item->ast); + break; + } + case Set: { + for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next) + add_closed_vars(closed_vars, enclosing_scope, env, item->ast); + break; + } + case Table: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->default_value); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->fallback); + for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next) + add_closed_vars(closed_vars, enclosing_scope, env, entry->ast); + break; + } + case TableEntry: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->key); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->value); + break; + } + case Comprehension: { + auto comp = Match(ast, Comprehension); + if (comp->expr->tag == Comprehension) { // Nested comprehension + ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr; + ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); + return add_closed_vars(closed_vars, enclosing_scope, env, loop); + } + + // Array/Set/Table comprehension: + ast_t *body = comp->expr; + if (comp->filter) + body = WrapAST(comp->expr, If, .condition=comp->filter, .body=body); + ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); + add_closed_vars(closed_vars, enclosing_scope, env, loop); + break; + } + case Lambda: { + auto lambda = Match(ast, Lambda); + env_t *lambda_scope = fresh_scope(env); + for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) + set_binding(lambda_scope, arg->name, new(binding_t, .type=get_arg_ast_type(env, arg))); + add_closed_vars(closed_vars, enclosing_scope, lambda_scope, lambda->body); + break; + } + case FunctionCall: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FunctionCall)->fn); + for (arg_ast_t *arg = Match(ast, FunctionCall)->args; arg; arg = arg->next) + add_closed_vars(closed_vars, enclosing_scope, env, arg->value); + break; + } + case MethodCall: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, MethodCall)->self); + for (arg_ast_t *arg = Match(ast, MethodCall)->args; arg; arg = arg->next) + add_closed_vars(closed_vars, enclosing_scope, env, arg->value); + break; + } + case Block: { + env = fresh_scope(env); + for (ast_list_t *statement = Match(ast, Block)->statements; statement; statement = statement->next) + add_closed_vars(closed_vars, enclosing_scope, env, statement->ast); + break; + } + case For: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->iter); + env_t *body_scope = for_scope(env, ast); + add_closed_vars(closed_vars, enclosing_scope, body_scope, Match(ast, For)->body); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->empty); + break; + } + case While: { + auto while_ = Match(ast, While); + add_closed_vars(closed_vars, enclosing_scope, env, while_->condition); + env_t *scope = fresh_scope(env); + add_closed_vars(closed_vars, enclosing_scope, scope, while_->body); + break; + } + case If: { + auto if_ = Match(ast, If); + ast_t *condition = if_->condition; + add_closed_vars(closed_vars, enclosing_scope, env, condition); + if (condition->tag == Declare) { + env_t *truthy_scope = fresh_scope(env); + bind_statement(truthy_scope, condition); + ast_t *var = Match(condition, Declare)->var; + type_t *cond_t = get_type(truthy_scope, var); + if (cond_t->tag == OptionalType) { + set_binding(truthy_scope, Match(var, Var)->name, + new(binding_t, .type=Match(cond_t, OptionalType)->type)); + } + add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body); + add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body); + } else { + env_t *truthy_scope = env; + type_t *cond_t = get_type(env, condition); + if (condition->tag == Var && cond_t->tag == OptionalType) { + truthy_scope = fresh_scope(env); + set_binding(truthy_scope, Match(condition, Var)->name, + new(binding_t, .type=Match(cond_t, OptionalType)->type)); + } + add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body); + add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body); + } + break; + } + case When: { + auto when = Match(ast, When); + add_closed_vars(closed_vars, enclosing_scope, env, when->subject); + type_t *subject_t = get_type(env, when->subject); + + auto enum_t = Match(subject_t, EnumType); + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + const char *clause_tag_name = Match(clause->tag_name, Var)->name; + type_t *tag_type = NULL; + for (tag_t *tag = enum_t->tags; tag; tag = tag->next) { + if (streq(tag->name, clause_tag_name)) { + tag_type = tag->type; + break; + } + } + assert(tag_type); + env_t *scope = env; + + auto tag_struct = Match(tag_type, StructType); + if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) { + scope = fresh_scope(scope); + set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type)); + } else if (clause->args) { + scope = fresh_scope(scope); + ast_list_t *var = clause->args; + arg_t *field = tag_struct->fields; + while (var || field) { + if (!var) + code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name); + if (!field) + code_err(var->ast, "This is one more field than %T has", subject_t); + set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type)); + var = var->next; + field = field->next; + } + } + add_closed_vars(closed_vars, enclosing_scope, scope, clause->body); + } + if (when->else_body) + add_closed_vars(closed_vars, enclosing_scope, env, when->else_body); + break; + } + case Repeat: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Repeat)->body); + break; + } + case Reduction: { + auto reduction = Match(ast, Reduction); + static int64_t next_id = 1; + ast_t *item = FakeAST(Var, heap_strf("$it%ld", next_id++)); + ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=FakeAST(Pass)); + env_t *scope = for_scope(env, loop); + add_closed_vars(closed_vars, enclosing_scope, scope, reduction->key ? reduction->key : item); + break; + } + case Defer: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Defer)->body); + break; + } + case Return: { + ast_t *ret = Match(ast, Return)->value; + if (ret) add_closed_vars(closed_vars, enclosing_scope, env, ret); + break; + } + case Index: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->indexed); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->index); + break; + } + case FieldAccess: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FieldAccess)->fielded); + break; + } + case Optional: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Optional)->value); + break; + } + case NonOptional: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, NonOptional)->value); + break; + } + case DocTest: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, DocTest)->expr); + break; + } + case Deserialize: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value); + break; + } + case Use: case FunctionDef: case StructDef: case EnumDef: case LangDef: { + errx(1, "Definitions should not be reachable in a closure."); + } + default: + break; + } +} + +static Table_t get_closed_vars(env_t *env, arg_ast_t *args, ast_t *block) { - auto lambda = Match(lambda_ast, Lambda); env_t *body_scope = fresh_scope(env); - body_scope->code = new(compilation_unit_t); // Don't put any code in the headers or anything - for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { + for (arg_ast_t *arg = args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("_$", arg->name))); } - fn_ctx_t fn_ctx = (fn_ctx_t){ - .parent=env->fn_ctx, - .closure_scope=env->locals, - .closed_vars=new(Table_t), - }; - body_scope->fn_ctx = &fn_ctx; - body_scope->locals->fallback = env->globals; - type_t *ret_t = get_type(body_scope, lambda->body); - if (ret_t->tag == ReturnType) - ret_t = Match(ret_t, ReturnType)->ret; - fn_ctx.return_type = ret_t; - - // Find which variables are captured in the closure: - env_t *tmp_scope = fresh_scope(body_scope); - for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { - type_t *stmt_type = get_type(tmp_scope, stmt->ast); - if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || get_type(tmp_scope, stmt->ast)->tag == ReturnType)) - (void)compile_statement(tmp_scope, stmt->ast); - else - (void)compile(tmp_scope, stmt->ast); - bind_statement(tmp_scope, stmt->ast); - } - return fn_ctx.closed_vars; + Table_t closed_vars = {}; + add_closed_vars(&closed_vars, env, body_scope, block); + return closed_vars; } CORD compile_declaration(type_t *t, CORD name) @@ -953,13 +1204,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("_$", arg->name))); } - fn_ctx_t fn_ctx = (fn_ctx_t){ - .parent=NULL, - .return_type=ret_t, - .closure_scope=NULL, - .closed_vars=NULL, - }; - body_scope->fn_ctx = &fn_ctx; + body_scope->fn_ret = ret_t; type_t *body_type = get_type(body_scope, fndef->body); if (ret_t->tag != VoidType && ret_t->tag != AbortType && body_type->tag != AbortType && body_type->tag != ReturnType) @@ -1132,13 +1377,13 @@ static CORD _compile_statement(env_t *env, ast_t *ast) case Pass: return ";"; case Defer: { ast_t *body = Match(ast, Defer)->body; - Table_t *closed_vars = get_closed_vars(env, FakeAST(Lambda, .args=NULL, .body=body)); + Table_t closed_vars = get_closed_vars(env, NULL, body); static int defer_id = 0; env_t *defer_env = fresh_scope(env); CORD code = CORD_EMPTY; - for (int64_t i = 1; i <= Table$length(*closed_vars); i++) { - struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i); + for (int64_t i = 1; i <= Table$length(closed_vars); i++) { + struct { const char *name; binding_t *b; } *entry = Table$entry(closed_vars, i); if (entry->b->type->tag == ModuleType) continue; if (CORD_ncmp(entry->b->code, 0, "userdata->", 0, strlen("userdata->")) == 0) { @@ -1149,9 +1394,6 @@ static CORD _compile_statement(env_t *env, ast_t *ast) code, compile_declaration(entry->b->type, defer_name), " = ", entry->b->code, ";\n"); set_binding(defer_env, entry->name, new(binding_t, .type=entry->b->type, .code=defer_name)); } - - if (env->fn_ctx->closed_vars) - Table$str_set(env->fn_ctx->closed_vars, entry->name, entry->b); } env->deferred = new(deferral_t, .defer_env=defer_env, .block=body, .next=env->deferred); return code; @@ -1175,9 +1417,8 @@ static CORD _compile_statement(env_t *env, ast_t *ast) return CORD_cat(code, ", yes);"); } case Return: { - if (!env->fn_ctx) code_err(ast, "This return statement is not inside any function"); + if (!env->fn_ret) code_err(ast, "This return statement is not inside any function"); auto ret = Match(ast, Return)->value; - assert(env->fn_ctx->return_type); CORD code = CORD_EMPTY; for (deferral_t *deferred = env->deferred; deferred; deferred = deferred->next) { @@ -1185,20 +1426,20 @@ static CORD _compile_statement(env_t *env, ast_t *ast) } if (ret) { - if (env->fn_ctx->return_type->tag == VoidType || env->fn_ctx->return_type->tag == AbortType) + if (env->fn_ret->tag == VoidType || env->fn_ret->tag == AbortType) code_err(ast, "This function is not supposed to return any values, according to its type signature"); - env = with_enum_scope(env, env->fn_ctx->return_type); - CORD value = compile_to_type(env, ret, env->fn_ctx->return_type); + env = with_enum_scope(env, env->fn_ret); + CORD value = compile_to_type(env, ret, env->fn_ret); if (env->deferred) { - code = CORD_all(compile_declaration(env->fn_ctx->return_type, "ret"), " = ", value, ";\n", code); + code = CORD_all(compile_declaration(env->fn_ret, "ret"), " = ", value, ";\n", code); value = "ret"; } return CORD_all(code, "return ", value, ";"); } else { - if (env->fn_ctx->return_type->tag != VoidType) - code_err(ast, "This function expects you to return a %T value", env->fn_ctx->return_type); + if (env->fn_ret->tag != VoidType) + code_err(ast, "This function expects you to return a %T value", env->fn_ret); return CORD_all(code, "return;"); } } @@ -2826,19 +3067,12 @@ CORD compile(env_t *env, ast_t *ast) name, CORD_quoted(type_to_cord(get_type(env, ast))), file_base_name(ast->file->filename), get_line_number(ast->file, ast->start))); env_t *body_scope = fresh_scope(env); + body_scope->deferred = NULL; for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); - set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("_$", arg->name))); + set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_all("_$", arg->name))); } - fn_ctx_t fn_ctx = (fn_ctx_t){ - .parent=env->fn_ctx, - .closure_scope=env->locals, - .closed_vars=new(Table_t), - }; - body_scope->fn_ctx = &fn_ctx; - body_scope->locals->fallback = env->globals; - body_scope->deferred = NULL; type_t *ret_t = get_type(body_scope, lambda->body); if (ret_t->tag == ReturnType) ret_t = Match(ret_t, ReturnType)->ret; @@ -2852,21 +3086,13 @@ CORD compile(env_t *env, ast_t *ast) declared, ret_t); } - fn_ctx.return_type = ret_t; + body_scope->fn_ret = ret_t; - if (env->fn_ctx->closed_vars) { - for (int64_t i = 1; i <= Table$length(*env->fn_ctx->closed_vars); i++) { - struct { const char *name; binding_t *b; } *entry = Table$entry(*env->fn_ctx->closed_vars, i); - set_binding(body_scope, entry->name, new(binding_t, .type=entry->b->type, .code=CORD_cat("userdata->", entry->name))); - Table$str_set(fn_ctx.closed_vars, entry->name, entry->b); - } - } - - Table_t *closed_vars = get_closed_vars(env, ast); - if (Table$length(*closed_vars) > 0) { // Create a typedef for the lambda's closure userdata + Table_t closed_vars = get_closed_vars(env, lambda->args, ast); + if (Table$length(closed_vars) > 0) { // Create a typedef for the lambda's closure userdata CORD def = "typedef struct {"; - for (int64_t i = 1; i <= Table$length(*closed_vars); i++) { - struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i); + for (int64_t i = 1; i <= Table$length(closed_vars); i++) { + struct { const char *name; binding_t *b; } *entry = Table$entry(closed_vars, i); if (has_stack_memory(entry->b->type)) code_err(ast, "This function is holding onto a reference to %T stack memory in the variable `%s`, but the function may outlive the stack memory", entry->b->type, entry->name); @@ -2886,16 +3112,19 @@ CORD compile(env_t *env, ast_t *ast) } CORD userdata; - if (Table$length(*closed_vars) == 0) { + if (Table$length(closed_vars) == 0) { code = CORD_cat(code, "void *)"); userdata = "NULL"; } else { userdata = CORD_all("new(", name, "$userdata_t"); - for (int64_t i = 1; i <= Table$length(*closed_vars); i++) { - struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i); + for (int64_t i = 1; i <= Table$length(closed_vars); i++) { + struct { const char *name; binding_t *b; } *entry = Table$entry(closed_vars, i); if (entry->b->type->tag == ModuleType) continue; - CORD binding_code = get_binding(env, entry->name)->code; + binding_t *b = get_binding(env, entry->name); + if (!b) printf("Couldn't find: %s\n", entry->name); + assert(b); + CORD binding_code = b->code; if (entry->b->type->tag == ArrayType) userdata = CORD_all(userdata, ", ARRAY_COPY(", binding_code, ")"); else if (entry->b->type->tag == TableType || entry->b->type->tag == SetType) diff --git a/environment.c b/environment.c index 86a7ff1..1b33fd7 100644 --- a/environment.c +++ b/environment.c @@ -646,22 +646,9 @@ env_t *namespace_env(env_t *env, const char *namespace_name) return ns_env; } -binding_t *get_binding(env_t *env, const char *name) +PUREFUNC binding_t *get_binding(env_t *env, const char *name) { - binding_t *b = Table$str_get(*env->locals, name); - if (!b) { - for (fn_ctx_t *fn_ctx = env->fn_ctx; fn_ctx; fn_ctx = fn_ctx->parent) { - if (!fn_ctx->closure_scope) continue; - b = Table$str_get(*fn_ctx->closure_scope, name); - if (b) { - Table$str_set(env->fn_ctx->closed_vars, name, b); - binding_t *b2 = new(binding_t, .type=b->type, .code=CORD_all("userdata->", name)); - Table$str_set(env->locals, name, b2); - return b2; - } - } - } - return b; + return Table$str_get(*env->locals, name); } binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name) @@ -702,7 +689,7 @@ binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name) return NULL; } -binding_t *get_lang_escape_function(env_t *env, const char *lang_name, type_t *type_to_escape) +PUREFUNC binding_t *get_lang_escape_function(env_t *env, const char *lang_name, type_t *type_to_escape) { if (!lang_name) lang_name = "Text"; binding_t *typeinfo = get_binding(env, lang_name); diff --git a/environment.h b/environment.h index 616bfc5..465ecfa 100644 --- a/environment.h +++ b/environment.h @@ -16,13 +16,6 @@ typedef struct { CORD function_naming; } compilation_unit_t; -typedef struct fn_ctx_s { - struct fn_ctx_s *parent; - type_t *return_type; - Table_t *closure_scope; - Table_t *closed_vars; -} fn_ctx_t; - typedef struct deferral_s { struct deferral_s *next; struct env_s *defer_env; @@ -49,7 +42,7 @@ typedef struct env_s { // - Raw 'use' string for module imports Table_t *imports; compilation_unit_t *code; - fn_ctx_t *fn_ctx; + type_t *fn_ret; loop_ctx_t *loop_ctx; deferral_t *deferred; CORD libname; // Currently compiling library name (if any) diff --git a/repl.c b/repl.c index 29ed407..3b2e1af 100644 --- a/repl.c +++ b/repl.c @@ -36,12 +36,6 @@ typedef struct { static PUREFUNC repl_binding_t *get_repl_binding(env_t *env, const char *name) { repl_binding_t *b = Table$str_get(*env->locals, name); - if (b) return b; - for (fn_ctx_t *fn_ctx = env->fn_ctx; fn_ctx; fn_ctx = fn_ctx->parent) { - if (!fn_ctx->closure_scope) continue; - b = Table$str_get(*fn_ctx->closure_scope, name); - if (b) return b; - } return b; } diff --git a/typecheck.c b/typecheck.c index 877d56a..a2530bc 100644 --- a/typecheck.c +++ b/typecheck.c @@ -908,9 +908,9 @@ type_t *get_type(env_t *env, ast_t *ast) case Return: { ast_t *val = Match(ast, Return)->value; // Support unqualified enum return values: - if (env->fn_ctx && env->fn_ctx->return_type && env->fn_ctx->return_type->tag == EnumType) { + if (env->fn_ret && env->fn_ret->tag == EnumType) { env = fresh_scope(env); - auto enum_ = Match(env->fn_ctx->return_type, EnumType); + auto enum_ = Match(env->fn_ret, EnumType); env_t *ns_env = enum_->env; for (tag_t *tag = enum_->tags; tag; tag = tag->next) { if (get_binding(env, tag->name)) @@ -1163,7 +1163,7 @@ type_t *get_type(env_t *env, ast_t *ast) arg_t *args = NULL; env_t *scope = fresh_scope(env); // For now, just use closed variables in scope normally for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { - type_t *t = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->value); + type_t *t = get_arg_ast_type(env, arg); args = new(arg_t, .name=arg->name, .type=t, .next=args); set_binding(scope, arg->name, new(binding_t, .type=t)); }