diff options
Diffstat (limited to 'src/compile/functions.c')
| -rw-r--r-- | src/compile/functions.c | 376 |
1 files changed, 376 insertions, 0 deletions
diff --git a/src/compile/functions.c b/src/compile/functions.c index 1a2e5d3d..4015f0ca 100644 --- a/src/compile/functions.c +++ b/src/compile/functions.c @@ -1,6 +1,8 @@ +#include "functions.h" #include "../ast.h" #include "../compile.h" #include "../environment.h" +#include "../naming.h" #include "../stdlib/datatypes.h" #include "../stdlib/integers.h" #include "../stdlib/nums.h" @@ -195,3 +197,377 @@ Text_t compile_function_call(env_t *env, ast_t *ast) { code_err(call->fn, "This is not a function, it's a ", type_to_str(fn_t)); } } + +public +Text_t compile_lambda(env_t *env, ast_t *ast) { + DeclareMatch(lambda, ast, Lambda); + Text_t name = namespace_name(env, env->namespace, Texts("lambda$", String(lambda->id))); + + env_t *body_scope = fresh_scope(env); + body_scope->deferred = NULL; + for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { + type_t *arg_type = get_arg_ast_type(env, arg); + set_binding(body_scope, arg->name, arg_type, Texts("_$", arg->name)); + } + + type_t *ret_t = get_type(body_scope, lambda->body); + if (ret_t->tag == ReturnType) ret_t = Match(ret_t, ReturnType)->ret; + + if (lambda->ret_type) { + type_t *declared = parse_type_ast(env, lambda->ret_type); + if (can_promote(ret_t, declared)) ret_t = declared; + else + code_err(ast, "This function was declared to return a value of type ", type_to_str(declared), + ", but actually returns a value of type ", type_to_str(ret_t)); + } + + body_scope->fn_ret = ret_t; + + Table_t closed_vars = get_closed_vars(env, lambda->args, ast); + if (Table$length(closed_vars) > 0) { // Create a typedef for the lambda's closure userdata + Text_t def = Text("typedef struct {"); + for (int64_t i = 0; i < closed_vars.entries.length; i++) { + struct { + const char *name; + binding_t *b; + } *entry = closed_vars.entries.data + closed_vars.entries.stride * i; + if (has_stack_memory(entry->b->type)) + code_err(ast, "This function is holding onto a reference to ", type_to_str(entry->b->type), + " stack memory in the variable `", entry->name, + "`, but the function may outlive the stack memory"); + if (entry->b->type->tag == ModuleType) continue; + set_binding(body_scope, entry->name, entry->b->type, Texts("userdata->", entry->name)); + def = Texts(def, compile_declaration(entry->b->type, Text$from_str(entry->name)), "; "); + } + def = Texts(def, "} ", name, "$userdata_t;"); + env->code->local_typedefs = Texts(env->code->local_typedefs, def); + } + + Text_t code = Texts("static ", compile_type(ret_t), " ", name, "("); + for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { + type_t *arg_type = get_arg_ast_type(env, arg); + code = Texts(code, compile_type(arg_type), " _$", arg->name, ", "); + } + + Text_t userdata; + if (Table$length(closed_vars) == 0) { + code = Texts(code, "void *_)"); + userdata = Text("NULL"); + } else { + userdata = Texts("new(", name, "$userdata_t"); + for (int64_t i = 0; i < closed_vars.entries.length; i++) { + struct { + const char *name; + binding_t *b; + } *entry = closed_vars.entries.data + closed_vars.entries.stride * i; + if (entry->b->type->tag == ModuleType) continue; + binding_t *b = get_binding(env, entry->name); + assert(b); + Text_t binding_code = b->code; + if (entry->b->type->tag == ListType) userdata = Texts(userdata, ", LIST_COPY(", binding_code, ")"); + else if (entry->b->type->tag == TableType || entry->b->type->tag == SetType) + userdata = Texts(userdata, ", TABLE_COPY(", binding_code, ")"); + else userdata = Texts(userdata, ", ", binding_code); + } + userdata = Texts(userdata, ")"); + code = Texts(code, name, "$userdata_t *userdata)"); + } + + Text_t body = EMPTY_TEXT; + for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { + if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType + || get_type(body_scope, stmt->ast)->tag == ReturnType) + body = Texts(body, compile_statement(body_scope, stmt->ast), "\n"); + else body = Texts(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n"); + bind_statement(body_scope, stmt->ast); + } + if ((ret_t->tag == VoidType || ret_t->tag == AbortType) && body_scope->deferred) + body = Texts(body, compile_statement(body_scope, FakeAST(Return)), "\n"); + + env->code->lambdas = Texts(env->code->lambdas, code, " {\n", body, "\n}\n"); + return Texts("((Closure_t){", name, ", ", userdata, "})"); +} + +static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t *env, ast_t *ast) { + if (ast == NULL) return; + + switch (ast->tag) { + case Var: { + binding_t *b = get_binding(enclosing_scope, Match(ast, Var)->name); + if (b) { + binding_t *shadow = get_binding(env, Match(ast, Var)->name); + if (!shadow || shadow == b) Table$str_set(closed_vars, Match(ast, Var)->name, b); + } + break; + } + case TextJoin: { + for (ast_list_t *child = Match(ast, TextJoin)->children; child; child = child->next) + add_closed_vars(closed_vars, enclosing_scope, env, child->ast); + break; + } + case Declare: { + ast_t *value = Match(ast, Declare)->value; + add_closed_vars(closed_vars, enclosing_scope, env, value); + bind_statement(env, ast); + break; + } + case Assign: { + for (ast_list_t *target = Match(ast, Assign)->targets; target; target = target->next) + add_closed_vars(closed_vars, enclosing_scope, env, target->ast); + for (ast_list_t *value = Match(ast, Assign)->values; value; value = value->next) + add_closed_vars(closed_vars, enclosing_scope, env, value->ast); + break; + } + case BINOP_CASES: { + binary_operands_t binop = BINARY_OPERANDS(ast); + add_closed_vars(closed_vars, enclosing_scope, env, binop.lhs); + add_closed_vars(closed_vars, enclosing_scope, env, binop.rhs); + break; + } + case Not: + case Negative: + case HeapAllocate: + case StackReference: { + // UNSAFE: + ast_t *value = ast->__data.Not.value; + // END UNSAFE + add_closed_vars(closed_vars, enclosing_scope, env, value); + break; + } + case Min: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->lhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->rhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->key); + break; + } + case Max: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->lhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->rhs); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->key); + break; + } + case List: { + for (ast_list_t *item = Match(ast, List)->items; item; item = item->next) + add_closed_vars(closed_vars, enclosing_scope, env, item->ast); + break; + } + case Set: { + for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next) + add_closed_vars(closed_vars, enclosing_scope, env, item->ast); + break; + } + case Table: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->default_value); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->fallback); + for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next) + add_closed_vars(closed_vars, enclosing_scope, env, entry->ast); + break; + } + case TableEntry: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->key); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->value); + break; + } + case Comprehension: { + DeclareMatch(comp, ast, Comprehension); + if (comp->expr->tag == Comprehension) { // Nested comprehension + ast_t *body = comp->filter ? WrapAST(ast, If, .condition = comp->filter, .body = comp->expr) : comp->expr; + ast_t *loop = WrapAST(ast, For, .vars = comp->vars, .iter = comp->iter, .body = body); + return add_closed_vars(closed_vars, enclosing_scope, env, loop); + } + + // List/Set/Table comprehension: + ast_t *body = comp->expr; + if (comp->filter) body = WrapAST(comp->expr, If, .condition = comp->filter, .body = body); + ast_t *loop = WrapAST(ast, For, .vars = comp->vars, .iter = comp->iter, .body = body); + add_closed_vars(closed_vars, enclosing_scope, env, loop); + break; + } + case Lambda: { + DeclareMatch(lambda, ast, Lambda); + env_t *lambda_scope = fresh_scope(env); + for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) + set_binding(lambda_scope, arg->name, get_arg_ast_type(env, arg), Texts("_$", arg->name)); + add_closed_vars(closed_vars, enclosing_scope, lambda_scope, lambda->body); + break; + } + case FunctionCall: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FunctionCall)->fn); + for (arg_ast_t *arg = Match(ast, FunctionCall)->args; arg; arg = arg->next) + add_closed_vars(closed_vars, enclosing_scope, env, arg->value); + break; + } + case MethodCall: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, MethodCall)->self); + for (arg_ast_t *arg = Match(ast, MethodCall)->args; arg; arg = arg->next) + add_closed_vars(closed_vars, enclosing_scope, env, arg->value); + break; + } + case Block: { + env = fresh_scope(env); + for (ast_list_t *statement = Match(ast, Block)->statements; statement; statement = statement->next) + add_closed_vars(closed_vars, enclosing_scope, env, statement->ast); + break; + } + case For: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->iter); + env_t *body_scope = for_scope(env, ast); + add_closed_vars(closed_vars, enclosing_scope, body_scope, Match(ast, For)->body); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->empty); + break; + } + case While: { + DeclareMatch(while_, ast, While); + add_closed_vars(closed_vars, enclosing_scope, env, while_->condition); + env_t *scope = fresh_scope(env); + add_closed_vars(closed_vars, enclosing_scope, scope, while_->body); + break; + } + case If: { + DeclareMatch(if_, ast, If); + ast_t *condition = if_->condition; + if (condition->tag == Declare) { + env_t *truthy_scope = fresh_scope(env); + bind_statement(truthy_scope, condition); + if (!Match(condition, Declare)->value) + code_err(condition, "This declared variable must have an initial value"); + add_closed_vars(closed_vars, enclosing_scope, env, Match(condition, Declare)->value); + ast_t *var = Match(condition, Declare)->var; + type_t *cond_t = get_type(truthy_scope, var); + if (cond_t->tag == OptionalType) { + set_binding(truthy_scope, Match(var, Var)->name, Match(cond_t, OptionalType)->type, EMPTY_TEXT); + } + add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body); + add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body); + } else { + add_closed_vars(closed_vars, enclosing_scope, env, condition); + env_t *truthy_scope = env; + type_t *cond_t = get_type(env, condition); + if (condition->tag == Var && cond_t->tag == OptionalType) { + truthy_scope = fresh_scope(env); + set_binding(truthy_scope, Match(condition, Var)->name, Match(cond_t, OptionalType)->type, EMPTY_TEXT); + } + add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body); + add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body); + } + break; + } + case When: { + DeclareMatch(when, ast, When); + add_closed_vars(closed_vars, enclosing_scope, env, when->subject); + type_t *subject_t = get_type(env, when->subject); + + if (subject_t->tag != EnumType) { + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + add_closed_vars(closed_vars, enclosing_scope, env, clause->pattern); + add_closed_vars(closed_vars, enclosing_scope, env, clause->body); + } + + if (when->else_body) add_closed_vars(closed_vars, enclosing_scope, env, when->else_body); + return; + } + + DeclareMatch(enum_t, subject_t, EnumType); + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + const char *clause_tag_name; + if (clause->pattern->tag == Var) clause_tag_name = Match(clause->pattern, Var)->name; + else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var) + clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name; + else code_err(clause->pattern, "This is not a valid pattern for a ", type_to_str(subject_t), " enum"); + + type_t *tag_type = NULL; + for (tag_t *tag = enum_t->tags; tag; tag = tag->next) { + if (streq(tag->name, clause_tag_name)) { + tag_type = tag->type; + break; + } + } + assert(tag_type); + env_t *scope = when_clause_scope(env, subject_t, clause); + add_closed_vars(closed_vars, enclosing_scope, scope, clause->body); + } + if (when->else_body) add_closed_vars(closed_vars, enclosing_scope, env, when->else_body); + break; + } + case Repeat: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Repeat)->body); + break; + } + case Reduction: { + DeclareMatch(reduction, ast, Reduction); + static int64_t next_id = 1; + ast_t *item = FakeAST(Var, String("$it", next_id++)); + ast_t *loop = + FakeAST(For, .vars = new (ast_list_t, .ast = item), .iter = reduction->iter, .body = FakeAST(Pass)); + env_t *scope = for_scope(env, loop); + add_closed_vars(closed_vars, enclosing_scope, scope, reduction->key ? reduction->key : item); + break; + } + case Defer: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Defer)->body); + break; + } + case Return: { + ast_t *ret = Match(ast, Return)->value; + if (ret) add_closed_vars(closed_vars, enclosing_scope, env, ret); + break; + } + case Index: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->indexed); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->index); + break; + } + case FieldAccess: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FieldAccess)->fielded); + break; + } + case Optional: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Optional)->value); + break; + } + case NonOptional: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, NonOptional)->value); + break; + } + case DocTest: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, DocTest)->expr); + break; + } + case Assert: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->expr); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->message); + break; + } + case Deserialize: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value); + break; + } + case ExplicitlyTyped: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, ExplicitlyTyped)->ast); + break; + } + case Use: + case FunctionDef: + case ConvertDef: + case StructDef: + case EnumDef: + case LangDef: + case Extend: { + errx(1, "Definitions should not be reachable in a closure."); + } + default: break; + } +} + +public +Table_t get_closed_vars(env_t *env, arg_ast_t *args, ast_t *block) { + env_t *body_scope = fresh_scope(env); + for (arg_ast_t *arg = args; arg; arg = arg->next) { + type_t *arg_type = get_arg_ast_type(env, arg); + set_binding(body_scope, arg->name, arg_type, Texts("_$", arg->name)); + } + + Table_t closed_vars = {}; + add_closed_vars(&closed_vars, env, body_scope, block); + return closed_vars; +} |
