From d3f14cf53cf857b90184900a726e3ee0875dea80 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sun, 14 Jul 2024 14:13:23 -0400 Subject: Support nested lambda closures --- compile.c | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'compile.c') diff --git a/compile.c b/compile.c index a49ab29c..938c1c0c 100644 --- a/compile.c +++ b/compile.c @@ -62,13 +62,15 @@ static table_t *get_closed_vars(env_t *env, ast_t *lambda_ast) { auto lambda = Match(lambda_ast, Lambda); env_t *body_scope = fresh_scope(env); + body_scope->code = new(compilation_unit_t); // Don't put any code in the headers or anything for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("$", arg->name))); } fn_ctx_t fn_ctx = (fn_ctx_t){ - .closure_scope=body_scope->locals->fallback, + .parent=env->fn_ctx, + .closure_scope=env->locals, .closed_vars=new(table_t), }; body_scope->fn_ctx = &fn_ctx; @@ -596,6 +598,7 @@ CORD compile_statement(env_t *env, ast_t *ast) } fn_ctx_t fn_ctx = (fn_ctx_t){ + .parent=NULL, .return_type=ret_t, .closure_scope=NULL, .closed_vars=NULL, @@ -717,6 +720,9 @@ CORD compile_statement(env_t *env, ast_t *ast) code = CORD_all( code, compile_declaration(entry->b->type, defer_name), " = ", entry->b->code, ";\n"); set_binding(defer_env, entry->name, new(binding_t, .type=entry->b->type, .code=defer_name)); + + if (env->fn_ctx->closed_vars) + Table$str_set(env->fn_ctx->closed_vars, entry->name, entry->b); } env->deferred = new(deferral_t, .defer_env=defer_env, .block=body, .next=env->deferred); return code; @@ -1738,7 +1744,8 @@ CORD compile(env_t *env, ast_t *ast) } fn_ctx_t fn_ctx = (fn_ctx_t){ - .closure_scope=body_scope->locals->fallback, + .parent=env->fn_ctx, + .closure_scope=env->locals, .closed_vars=new(table_t), }; body_scope->fn_ctx = &fn_ctx; @@ -1749,6 +1756,14 @@ CORD compile(env_t *env, ast_t *ast) ret_t = Match(ret_t, ReturnType)->ret; fn_ctx.return_type = ret_t; + if (env->fn_ctx->closed_vars) { + for (int64_t i = 1; i <= Table$length(*env->fn_ctx->closed_vars); i++) { + struct { const char *name; binding_t *b; } *entry = Table$entry(*env->fn_ctx->closed_vars, i); + set_binding(body_scope, entry->name, new(binding_t, .type=entry->b->type, .code=CORD_cat("userdata->", entry->name))); + Table$str_set(fn_ctx.closed_vars, entry->name, entry->b); + } + } + CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "("); for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); @@ -1769,7 +1784,7 @@ CORD compile(env_t *env, ast_t *ast) struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i); if (entry->b->type->tag == ModuleType) continue; - userdata = CORD_all(userdata, ", ", entry->b->code); + userdata = CORD_all(userdata, ", ", get_binding(env, entry->name)->code); } userdata = CORD_all(userdata, ")"); code = CORD_all(code, name, "$userdata_t *userdata)"); -- cgit v1.2.3