diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2025-08-24 16:21:16 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2025-08-24 16:21:16 -0400 |
| commit | 5e82b074667627d1e6f44247396bae80f1c02fff (patch) | |
| tree | 4c23178f6ea9e400507f6345b0b58a7ce54ab0aa /src | |
| parent | bd312d20901b968e0603359e8cd01276d114031a (diff) | |
Move function logic into functions files
Diffstat (limited to 'src')
| -rw-r--r-- | src/compile.c | 372 | ||||
| -rw-r--r-- | src/compile/functions.c | 376 | ||||
| -rw-r--r-- | src/compile/functions.h | 2 |
3 files changed, 379 insertions, 371 deletions
diff --git a/src/compile.c b/src/compile.c index c906112b..fa47e63c 100644 --- a/src/compile.c +++ b/src/compile.c @@ -54,289 +54,6 @@ Text_t compile_maybe_incref(env_t *env, ast_t *ast, type_t *t) { return compile_to_type(env, ast, t); } -static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t *env, ast_t *ast) { - if (ast == NULL) return; - - switch (ast->tag) { - case Var: { - binding_t *b = get_binding(enclosing_scope, Match(ast, Var)->name); - if (b) { - binding_t *shadow = get_binding(env, Match(ast, Var)->name); - if (!shadow || shadow == b) Table$str_set(closed_vars, Match(ast, Var)->name, b); - } - break; - } - case TextJoin: { - for (ast_list_t *child = Match(ast, TextJoin)->children; child; child = child->next) - add_closed_vars(closed_vars, enclosing_scope, env, child->ast); - break; - } - case Declare: { - ast_t *value = Match(ast, Declare)->value; - add_closed_vars(closed_vars, enclosing_scope, env, value); - bind_statement(env, ast); - break; - } - case Assign: { - for (ast_list_t *target = Match(ast, Assign)->targets; target; target = target->next) - add_closed_vars(closed_vars, enclosing_scope, env, target->ast); - for (ast_list_t *value = Match(ast, Assign)->values; value; value = value->next) - add_closed_vars(closed_vars, enclosing_scope, env, value->ast); - break; - } - case BINOP_CASES: { - binary_operands_t binop = BINARY_OPERANDS(ast); - add_closed_vars(closed_vars, enclosing_scope, env, binop.lhs); - add_closed_vars(closed_vars, enclosing_scope, env, binop.rhs); - break; - } - case Not: - case Negative: - case HeapAllocate: - case StackReference: { - // UNSAFE: - ast_t *value = ast->__data.Not.value; - // END UNSAFE - add_closed_vars(closed_vars, enclosing_scope, env, value); - break; - } - case Min: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->lhs); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->rhs); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->key); - break; - } - case Max: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->lhs); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->rhs); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->key); - break; - } - case List: { - for (ast_list_t *item = Match(ast, List)->items; item; item = item->next) - add_closed_vars(closed_vars, enclosing_scope, env, item->ast); - break; - } - case Set: { - for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next) - add_closed_vars(closed_vars, enclosing_scope, env, item->ast); - break; - } - case Table: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->default_value); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->fallback); - for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next) - add_closed_vars(closed_vars, enclosing_scope, env, entry->ast); - break; - } - case TableEntry: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->key); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->value); - break; - } - case Comprehension: { - DeclareMatch(comp, ast, Comprehension); - if (comp->expr->tag == Comprehension) { // Nested comprehension - ast_t *body = comp->filter ? WrapAST(ast, If, .condition = comp->filter, .body = comp->expr) : comp->expr; - ast_t *loop = WrapAST(ast, For, .vars = comp->vars, .iter = comp->iter, .body = body); - return add_closed_vars(closed_vars, enclosing_scope, env, loop); - } - - // List/Set/Table comprehension: - ast_t *body = comp->expr; - if (comp->filter) body = WrapAST(comp->expr, If, .condition = comp->filter, .body = body); - ast_t *loop = WrapAST(ast, For, .vars = comp->vars, .iter = comp->iter, .body = body); - add_closed_vars(closed_vars, enclosing_scope, env, loop); - break; - } - case Lambda: { - DeclareMatch(lambda, ast, Lambda); - env_t *lambda_scope = fresh_scope(env); - for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) - set_binding(lambda_scope, arg->name, get_arg_ast_type(env, arg), Texts("_$", arg->name)); - add_closed_vars(closed_vars, enclosing_scope, lambda_scope, lambda->body); - break; - } - case FunctionCall: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FunctionCall)->fn); - for (arg_ast_t *arg = Match(ast, FunctionCall)->args; arg; arg = arg->next) - add_closed_vars(closed_vars, enclosing_scope, env, arg->value); - break; - } - case MethodCall: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, MethodCall)->self); - for (arg_ast_t *arg = Match(ast, MethodCall)->args; arg; arg = arg->next) - add_closed_vars(closed_vars, enclosing_scope, env, arg->value); - break; - } - case Block: { - env = fresh_scope(env); - for (ast_list_t *statement = Match(ast, Block)->statements; statement; statement = statement->next) - add_closed_vars(closed_vars, enclosing_scope, env, statement->ast); - break; - } - case For: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->iter); - env_t *body_scope = for_scope(env, ast); - add_closed_vars(closed_vars, enclosing_scope, body_scope, Match(ast, For)->body); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->empty); - break; - } - case While: { - DeclareMatch(while_, ast, While); - add_closed_vars(closed_vars, enclosing_scope, env, while_->condition); - env_t *scope = fresh_scope(env); - add_closed_vars(closed_vars, enclosing_scope, scope, while_->body); - break; - } - case If: { - DeclareMatch(if_, ast, If); - ast_t *condition = if_->condition; - if (condition->tag == Declare) { - env_t *truthy_scope = fresh_scope(env); - bind_statement(truthy_scope, condition); - if (!Match(condition, Declare)->value) - code_err(condition, "This declared variable must have an initial value"); - add_closed_vars(closed_vars, enclosing_scope, env, Match(condition, Declare)->value); - ast_t *var = Match(condition, Declare)->var; - type_t *cond_t = get_type(truthy_scope, var); - if (cond_t->tag == OptionalType) { - set_binding(truthy_scope, Match(var, Var)->name, Match(cond_t, OptionalType)->type, EMPTY_TEXT); - } - add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body); - add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body); - } else { - add_closed_vars(closed_vars, enclosing_scope, env, condition); - env_t *truthy_scope = env; - type_t *cond_t = get_type(env, condition); - if (condition->tag == Var && cond_t->tag == OptionalType) { - truthy_scope = fresh_scope(env); - set_binding(truthy_scope, Match(condition, Var)->name, Match(cond_t, OptionalType)->type, EMPTY_TEXT); - } - add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body); - add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body); - } - break; - } - case When: { - DeclareMatch(when, ast, When); - add_closed_vars(closed_vars, enclosing_scope, env, when->subject); - type_t *subject_t = get_type(env, when->subject); - - if (subject_t->tag != EnumType) { - for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - add_closed_vars(closed_vars, enclosing_scope, env, clause->pattern); - add_closed_vars(closed_vars, enclosing_scope, env, clause->body); - } - - if (when->else_body) add_closed_vars(closed_vars, enclosing_scope, env, when->else_body); - return; - } - - DeclareMatch(enum_t, subject_t, EnumType); - for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - const char *clause_tag_name; - if (clause->pattern->tag == Var) clause_tag_name = Match(clause->pattern, Var)->name; - else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var) - clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name; - else code_err(clause->pattern, "This is not a valid pattern for a ", type_to_str(subject_t), " enum"); - - type_t *tag_type = NULL; - for (tag_t *tag = enum_t->tags; tag; tag = tag->next) { - if (streq(tag->name, clause_tag_name)) { - tag_type = tag->type; - break; - } - } - assert(tag_type); - env_t *scope = when_clause_scope(env, subject_t, clause); - add_closed_vars(closed_vars, enclosing_scope, scope, clause->body); - } - if (when->else_body) add_closed_vars(closed_vars, enclosing_scope, env, when->else_body); - break; - } - case Repeat: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Repeat)->body); - break; - } - case Reduction: { - DeclareMatch(reduction, ast, Reduction); - static int64_t next_id = 1; - ast_t *item = FakeAST(Var, String("$it", next_id++)); - ast_t *loop = - FakeAST(For, .vars = new (ast_list_t, .ast = item), .iter = reduction->iter, .body = FakeAST(Pass)); - env_t *scope = for_scope(env, loop); - add_closed_vars(closed_vars, enclosing_scope, scope, reduction->key ? reduction->key : item); - break; - } - case Defer: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Defer)->body); - break; - } - case Return: { - ast_t *ret = Match(ast, Return)->value; - if (ret) add_closed_vars(closed_vars, enclosing_scope, env, ret); - break; - } - case Index: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->indexed); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->index); - break; - } - case FieldAccess: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FieldAccess)->fielded); - break; - } - case Optional: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Optional)->value); - break; - } - case NonOptional: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, NonOptional)->value); - break; - } - case DocTest: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, DocTest)->expr); - break; - } - case Assert: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->expr); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->message); - break; - } - case Deserialize: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value); - break; - } - case ExplicitlyTyped: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, ExplicitlyTyped)->ast); - break; - } - case Use: - case FunctionDef: - case ConvertDef: - case StructDef: - case EnumDef: - case LangDef: - case Extend: { - errx(1, "Definitions should not be reachable in a closure."); - } - default: break; - } -} - -static Table_t get_closed_vars(env_t *env, arg_ast_t *args, ast_t *block) { - env_t *body_scope = fresh_scope(env); - for (arg_ast_t *arg = args; arg; arg = arg->next) { - type_t *arg_type = get_arg_ast_type(env, arg); - set_binding(body_scope, arg->name, arg_type, Texts("_$", arg->name)); - } - - Table_t closed_vars = {}; - add_closed_vars(&closed_vars, env, body_scope, block); - return closed_vars; -} - Text_t compile_declaration(type_t *t, Text_t name) { if (t->tag == FunctionType) { DeclareMatch(fn, t, FunctionType); @@ -2236,94 +1953,7 @@ Text_t compile(env_t *env, ast_t *ast) { if (base->tag == TableEntry) return compile(env, WrapAST(ast, Table, .entries = new (ast_list_t, .ast = ast))); else return compile(env, WrapAST(ast, List, .items = new (ast_list_t, .ast = ast))); } - case Lambda: { - DeclareMatch(lambda, ast, Lambda); - Text_t name = namespace_name(env, env->namespace, Texts("lambda$", String(lambda->id))); - - env_t *body_scope = fresh_scope(env); - body_scope->deferred = NULL; - for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { - type_t *arg_type = get_arg_ast_type(env, arg); - set_binding(body_scope, arg->name, arg_type, Texts("_$", arg->name)); - } - - type_t *ret_t = get_type(body_scope, lambda->body); - if (ret_t->tag == ReturnType) ret_t = Match(ret_t, ReturnType)->ret; - - if (lambda->ret_type) { - type_t *declared = parse_type_ast(env, lambda->ret_type); - if (can_promote(ret_t, declared)) ret_t = declared; - else - code_err(ast, "This function was declared to return a value of type ", type_to_str(declared), - ", but actually returns a value of type ", type_to_str(ret_t)); - } - - body_scope->fn_ret = ret_t; - - Table_t closed_vars = get_closed_vars(env, lambda->args, ast); - if (Table$length(closed_vars) > 0) { // Create a typedef for the lambda's closure userdata - Text_t def = Text("typedef struct {"); - for (int64_t i = 0; i < closed_vars.entries.length; i++) { - struct { - const char *name; - binding_t *b; - } *entry = closed_vars.entries.data + closed_vars.entries.stride * i; - if (has_stack_memory(entry->b->type)) - code_err(ast, "This function is holding onto a reference to ", type_to_str(entry->b->type), - " stack memory in the variable `", entry->name, - "`, but the function may outlive the stack memory"); - if (entry->b->type->tag == ModuleType) continue; - set_binding(body_scope, entry->name, entry->b->type, Texts("userdata->", entry->name)); - def = Texts(def, compile_declaration(entry->b->type, Text$from_str(entry->name)), "; "); - } - def = Texts(def, "} ", name, "$userdata_t;"); - env->code->local_typedefs = Texts(env->code->local_typedefs, def); - } - - Text_t code = Texts("static ", compile_type(ret_t), " ", name, "("); - for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { - type_t *arg_type = get_arg_ast_type(env, arg); - code = Texts(code, compile_type(arg_type), " _$", arg->name, ", "); - } - - Text_t userdata; - if (Table$length(closed_vars) == 0) { - code = Texts(code, "void *_)"); - userdata = Text("NULL"); - } else { - userdata = Texts("new(", name, "$userdata_t"); - for (int64_t i = 0; i < closed_vars.entries.length; i++) { - struct { - const char *name; - binding_t *b; - } *entry = closed_vars.entries.data + closed_vars.entries.stride * i; - if (entry->b->type->tag == ModuleType) continue; - binding_t *b = get_binding(env, entry->name); - assert(b); - Text_t binding_code = b->code; - if (entry->b->type->tag == ListType) userdata = Texts(userdata, ", LIST_COPY(", binding_code, ")"); - else if (entry->b->type->tag == TableType || entry->b->type->tag == SetType) - userdata = Texts(userdata, ", TABLE_COPY(", binding_code, ")"); - else userdata = Texts(userdata, ", ", binding_code); - } - userdata = Texts(userdata, ")"); - code = Texts(code, name, "$userdata_t *userdata)"); - } - - Text_t body = EMPTY_TEXT; - for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { - if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType - || get_type(body_scope, stmt->ast)->tag == ReturnType) - body = Texts(body, compile_statement(body_scope, stmt->ast), "\n"); - else body = Texts(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n"); - bind_statement(body_scope, stmt->ast); - } - if ((ret_t->tag == VoidType || ret_t->tag == AbortType) && body_scope->deferred) - body = Texts(body, compile_statement(body_scope, FakeAST(Return)), "\n"); - - env->code->lambdas = Texts(env->code->lambdas, code, " {\n", body, "\n}\n"); - return Texts("((Closure_t){", name, ", ", userdata, "})"); - } + case Lambda: return compile_lambda(env, ast); case MethodCall: { DeclareMatch(call, ast, MethodCall); type_t *self_t = get_type(env, call->self); diff --git a/src/compile/functions.c b/src/compile/functions.c index 1a2e5d3d..4015f0ca 100644 --- a/src/compile/functions.c +++ b/src/compile/functions.c @@ -1,6 +1,8 @@ +#include "functions.h" #include "../ast.h" #include "../compile.h" #include "../environment.h" +#include "../naming.h" #include "../stdlib/datatypes.h" #include "../stdlib/integers.h" #include "../stdlib/nums.h" @@ -195,3 +197,377 @@ Text_t compile_function_call(env_t *env, ast_t *ast) { code_err(call->fn, "This is not a function, it's a ", type_to_str(fn_t)); } } + +public +Text_t compile_lambda(env_t *env, ast_t *ast) { + DeclareMatch(lambda, ast, Lambda); + Text_t name = namespace_name(env, env->namespace, Texts("lambda$", String(lambda->id))); + + env_t *body_scope = fresh_scope(env); + body_scope->deferred = NULL; + for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { + type_t *arg_type = get_arg_ast_type(env, arg); + set_binding(body_scope, arg->name, arg_type, Texts("_$", arg->name)); + } + + type_t *ret_t = get_type(body_scope, lambda->body); + if (ret_t->tag == ReturnType) ret_t = Match(ret_t, ReturnType)->ret; + + if (lambda->ret_type) { + type_t *declared = parse_type_ast(env, lambda->ret_type); + if (can_promote(ret_t, declared)) ret_t = declared; + else + code_err(ast, "This function was declared to return a value of type ", type_to_str(declared), + ", but actually returns a value of type ", type_to_str(ret_t)); + } + + body_scope->fn_ret = ret_t; + + Table_t closed_vars = get_closed_vars(env, lambda->args, ast); + if (Table$length(closed_vars) > 0) { // Create a typedef for the lambda's closure userdata + Text_t def = Text("typedef struct {"); + for (int64_t i = 0; i < closed_vars.entries.length; i++) { + struct { + const char *name; + binding_t *b; + } *entry = closed_vars.entries.data + closed_vars.entries.stride * i; + if (has_stack_memory(entry->b->type)) + code_err(ast, "This function is holding onto a reference to ", type_to_str(entry->b->type), + " stack memory in the variable `", entry->name, + "`, but the function may outlive the stack memory"); + if (entry->b->type->tag == ModuleType) continue; + set_binding(body_scope, entry->name, entry->b->type, Texts("userdata->", entry->name)); + def = Texts(def, compile_declaration(entry->b->type, Text$from_str(entry->name)), "; "); + } + def = Texts(def, "} ", name, "$userdata_t;"); + env->code->local_typedefs = Texts(env->code->local_typedefs, def); + } + + Text_t code = Texts("static ", compile_type(ret_t), " ", name, "("); + for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { + type_t *arg_type = get_arg_ast_type(env, arg); + code = Texts(code, compile_type(arg_type), " _$", arg->name, ", "); + } + + Text_t userdata; + if (Table$length(closed_vars) == 0) { + code = Texts(code, "void *_)"); + userdata = Text("NULL"); + } else { + userdata = Texts("new(", name, "$userdata_t"); + for (int64_t i = 0; i < closed_vars.entries.length; i++) { + struct { + const char *name; + binding_t *b; + } *entry = closed_vars.entries.data + closed_vars.entries.stride * i; + if (entry->b->type->tag == ModuleType) continue; + binding_t *b = get_binding(env, entry->name); + assert(b); + Text_t binding_code = b->code; + if (entry->b->type->tag == ListType) userdata = Texts(userdata, ", LIST_COPY(", binding_code, ")"); + else if (entry->b->type->tag == TableType || entry->b->type->tag == SetType) + userdata = Texts(userdata, ", TABLE_COPY(", binding_code, ")"); + else userdata = Texts(userdata, ", ", binding_code); + } + userdata = Texts(userdata, ")"); + code = Texts(code, name, "$userdata_t *userdata)"); + } + + Text_t body = EMPTY_TEXT; + for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { + if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType + || get_type(body_scope, stmt->ast)->tag == ReturnType) + body = Texts(body, compile_statement(body_scope, stmt->ast), "\n"); + else body = Texts(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n"); + bind_statement(body_scope, stmt->ast); + } + if ((ret_t->tag == VoidType || ret_t->tag == AbortType) && body_scope->deferred) + body = Texts(body, compile_statement(body_scope, FakeAST(Return)), "\n"); + + env->code->lambdas = Texts(env->code->lambdas, code, " {\n", body, "\n}\n"); + return Texts("((Closure_t){", name, ", ", userdata, "})"); +} + +static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t *env, ast_t *ast) { + if (ast == NULL) return; + + switch (ast->tag) { + case Var: { + binding_t *b = get_binding(enclosing_scope, Match(ast, Var)->name); + if (b) { + binding_t *shadow = get_binding(env, Match(ast, Var)->name); + if (!shadow || shadow == b) Table$str_set(closed_vars, Match(ast, Var)->name, b); + } + break; + } + case TextJoin: { + for (ast_list_t *child = Match(ast, TextJoin)->children; child; child = child->next) + add_closed_vars(closed_vars, enclosing_scope, env, child->ast); + break; + } + case Declare: { + ast_t *value = Match(ast, Declare)->value; + add_closed_vars(closed_vars, enclosing_scope, env, value); + bind_statement(env, ast); + break; + } + case Assign: { + for (ast_list_t *target = Match(ast, Assign)->targets; target; target = target->next) + add_closed_vars(closed_vars, enclosing_scope, env, target->ast); + for (ast_list_t *value = Match(ast, Assign)->values; value; value = value->next) + add_closed_vars(closed_vars, enclosing_scope, env, value->ast); + break; + } + case BINOP_CASES: { + binary_operands_t binop = BINARY_OPERANDS(ast); + add_closed_vars(closed_vars, enclosing_scope, env, binop.lhs); + add_closed_vars(closed_vars, enclosing_scope, env, binop.rhs); + break; + } + case Not: + case Negative: + case HeapAllocate: + case StackReference: { + // UNSAFE: + ast_t *value = ast->__data.Not.value; + // END UNSAFE + add_closed_vars(closed_vars, enclosing_scope, env, value); + break; + } + case Min: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->lhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->rhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->key); + break; + } + case Max: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->lhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->rhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->key); + break; + } + case List: { + for (ast_list_t *item = Match(ast, List)->items; item; item = item->next) + add_closed_vars(closed_vars, enclosing_scope, env, item->ast); + break; + } + case Set: { + for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next) + add_closed_vars(closed_vars, enclosing_scope, env, item->ast); + break; + } + case Table: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->default_value); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->fallback); + for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next) + add_closed_vars(closed_vars, enclosing_scope, env, entry->ast); + break; + } + case TableEntry: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->key); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->value); + break; + } + case Comprehension: { + DeclareMatch(comp, ast, Comprehension); + if (comp->expr->tag == Comprehension) { // Nested comprehension + ast_t *body = comp->filter ? WrapAST(ast, If, .condition = comp->filter, .body = comp->expr) : comp->expr; + ast_t *loop = WrapAST(ast, For, .vars = comp->vars, .iter = comp->iter, .body = body); + return add_closed_vars(closed_vars, enclosing_scope, env, loop); + } + + // List/Set/Table comprehension: + ast_t *body = comp->expr; + if (comp->filter) body = WrapAST(comp->expr, If, .condition = comp->filter, .body = body); + ast_t *loop = WrapAST(ast, For, .vars = comp->vars, .iter = comp->iter, .body = body); + add_closed_vars(closed_vars, enclosing_scope, env, loop); + break; + } + case Lambda: { + DeclareMatch(lambda, ast, Lambda); + env_t *lambda_scope = fresh_scope(env); + for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) + set_binding(lambda_scope, arg->name, get_arg_ast_type(env, arg), Texts("_$", arg->name)); + add_closed_vars(closed_vars, enclosing_scope, lambda_scope, lambda->body); + break; + } + case FunctionCall: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FunctionCall)->fn); + for (arg_ast_t *arg = Match(ast, FunctionCall)->args; arg; arg = arg->next) + add_closed_vars(closed_vars, enclosing_scope, env, arg->value); + break; + } + case MethodCall: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, MethodCall)->self); + for (arg_ast_t *arg = Match(ast, MethodCall)->args; arg; arg = arg->next) + add_closed_vars(closed_vars, enclosing_scope, env, arg->value); + break; + } + case Block: { + env = fresh_scope(env); + for (ast_list_t *statement = Match(ast, Block)->statements; statement; statement = statement->next) + add_closed_vars(closed_vars, enclosing_scope, env, statement->ast); + break; + } + case For: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->iter); + env_t *body_scope = for_scope(env, ast); + add_closed_vars(closed_vars, enclosing_scope, body_scope, Match(ast, For)->body); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->empty); + break; + } + case While: { + DeclareMatch(while_, ast, While); + add_closed_vars(closed_vars, enclosing_scope, env, while_->condition); + env_t *scope = fresh_scope(env); + add_closed_vars(closed_vars, enclosing_scope, scope, while_->body); + break; + } + case If: { + DeclareMatch(if_, ast, If); + ast_t *condition = if_->condition; + if (condition->tag == Declare) { + env_t *truthy_scope = fresh_scope(env); + bind_statement(truthy_scope, condition); + if (!Match(condition, Declare)->value) + code_err(condition, "This declared variable must have an initial value"); + add_closed_vars(closed_vars, enclosing_scope, env, Match(condition, Declare)->value); + ast_t *var = Match(condition, Declare)->var; + type_t *cond_t = get_type(truthy_scope, var); + if (cond_t->tag == OptionalType) { + set_binding(truthy_scope, Match(var, Var)->name, Match(cond_t, OptionalType)->type, EMPTY_TEXT); + } + add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body); + add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body); + } else { + add_closed_vars(closed_vars, enclosing_scope, env, condition); + env_t *truthy_scope = env; + type_t *cond_t = get_type(env, condition); + if (condition->tag == Var && cond_t->tag == OptionalType) { + truthy_scope = fresh_scope(env); + set_binding(truthy_scope, Match(condition, Var)->name, Match(cond_t, OptionalType)->type, EMPTY_TEXT); + } + add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body); + add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body); + } + break; + } + case When: { + DeclareMatch(when, ast, When); + add_closed_vars(closed_vars, enclosing_scope, env, when->subject); + type_t *subject_t = get_type(env, when->subject); + + if (subject_t->tag != EnumType) { + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + add_closed_vars(closed_vars, enclosing_scope, env, clause->pattern); + add_closed_vars(closed_vars, enclosing_scope, env, clause->body); + } + + if (when->else_body) add_closed_vars(closed_vars, enclosing_scope, env, when->else_body); + return; + } + + DeclareMatch(enum_t, subject_t, EnumType); + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + const char *clause_tag_name; + if (clause->pattern->tag == Var) clause_tag_name = Match(clause->pattern, Var)->name; + else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var) + clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name; + else code_err(clause->pattern, "This is not a valid pattern for a ", type_to_str(subject_t), " enum"); + + type_t *tag_type = NULL; + for (tag_t *tag = enum_t->tags; tag; tag = tag->next) { + if (streq(tag->name, clause_tag_name)) { + tag_type = tag->type; + break; + } + } + assert(tag_type); + env_t *scope = when_clause_scope(env, subject_t, clause); + add_closed_vars(closed_vars, enclosing_scope, scope, clause->body); + } + if (when->else_body) add_closed_vars(closed_vars, enclosing_scope, env, when->else_body); + break; + } + case Repeat: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Repeat)->body); + break; + } + case Reduction: { + DeclareMatch(reduction, ast, Reduction); + static int64_t next_id = 1; + ast_t *item = FakeAST(Var, String("$it", next_id++)); + ast_t *loop = + FakeAST(For, .vars = new (ast_list_t, .ast = item), .iter = reduction->iter, .body = FakeAST(Pass)); + env_t *scope = for_scope(env, loop); + add_closed_vars(closed_vars, enclosing_scope, scope, reduction->key ? reduction->key : item); + break; + } + case Defer: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Defer)->body); + break; + } + case Return: { + ast_t *ret = Match(ast, Return)->value; + if (ret) add_closed_vars(closed_vars, enclosing_scope, env, ret); + break; + } + case Index: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->indexed); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->index); + break; + } + case FieldAccess: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FieldAccess)->fielded); + break; + } + case Optional: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Optional)->value); + break; + } + case NonOptional: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, NonOptional)->value); + break; + } + case DocTest: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, DocTest)->expr); + break; + } + case Assert: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->expr); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->message); + break; + } + case Deserialize: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value); + break; + } + case ExplicitlyTyped: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, ExplicitlyTyped)->ast); + break; + } + case Use: + case FunctionDef: + case ConvertDef: + case StructDef: + case EnumDef: + case LangDef: + case Extend: { + errx(1, "Definitions should not be reachable in a closure."); + } + default: break; + } +} + +public +Table_t get_closed_vars(env_t *env, arg_ast_t *args, ast_t *block) { + env_t *body_scope = fresh_scope(env); + for (arg_ast_t *arg = args; arg; arg = arg->next) { + type_t *arg_type = get_arg_ast_type(env, arg); + set_binding(body_scope, arg->name, arg_type, Texts("_$", arg->name)); + } + + Table_t closed_vars = {}; + add_closed_vars(&closed_vars, env, body_scope, block); + return closed_vars; +} diff --git a/src/compile/functions.h b/src/compile/functions.h index fdab495f..f7edd2aa 100644 --- a/src/compile/functions.h +++ b/src/compile/functions.h @@ -5,3 +5,5 @@ Text_t compile_function_call(env_t *env, ast_t *ast); Text_t compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args); +Text_t compile_lambda(env_t *env, ast_t *ast); +Table_t get_closed_vars(env_t *env, arg_ast_t *args, ast_t *block); |
