diff options
| -rw-r--r-- | compile.c | 21 | ||||
| -rw-r--r-- | environment.c | 15 | ||||
| -rw-r--r-- | environment.h | 3 | ||||
| -rw-r--r-- | test/lambdas.tm | 10 |
4 files changed, 40 insertions, 9 deletions
@@ -62,13 +62,15 @@ static table_t *get_closed_vars(env_t *env, ast_t *lambda_ast) { auto lambda = Match(lambda_ast, Lambda); env_t *body_scope = fresh_scope(env); + body_scope->code = new(compilation_unit_t); // Don't put any code in the headers or anything for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("$", arg->name))); } fn_ctx_t fn_ctx = (fn_ctx_t){ - .closure_scope=body_scope->locals->fallback, + .parent=env->fn_ctx, + .closure_scope=env->locals, .closed_vars=new(table_t), }; body_scope->fn_ctx = &fn_ctx; @@ -596,6 +598,7 @@ CORD compile_statement(env_t *env, ast_t *ast) } fn_ctx_t fn_ctx = (fn_ctx_t){ + .parent=NULL, .return_type=ret_t, .closure_scope=NULL, .closed_vars=NULL, @@ -717,6 +720,9 @@ CORD compile_statement(env_t *env, ast_t *ast) code = CORD_all( code, compile_declaration(entry->b->type, defer_name), " = ", entry->b->code, ";\n"); set_binding(defer_env, entry->name, new(binding_t, .type=entry->b->type, .code=defer_name)); + + if (env->fn_ctx->closed_vars) + Table$str_set(env->fn_ctx->closed_vars, entry->name, entry->b); } env->deferred = new(deferral_t, .defer_env=defer_env, .block=body, .next=env->deferred); return code; @@ -1738,7 +1744,8 @@ CORD compile(env_t *env, ast_t *ast) } fn_ctx_t fn_ctx = (fn_ctx_t){ - .closure_scope=body_scope->locals->fallback, + .parent=env->fn_ctx, + .closure_scope=env->locals, .closed_vars=new(table_t), }; body_scope->fn_ctx = &fn_ctx; @@ -1749,6 +1756,14 @@ CORD compile(env_t *env, ast_t *ast) ret_t = Match(ret_t, ReturnType)->ret; fn_ctx.return_type = ret_t; + if (env->fn_ctx->closed_vars) { + for (int64_t i = 1; i <= Table$length(*env->fn_ctx->closed_vars); i++) { + struct { const char *name; binding_t *b; } *entry = Table$entry(*env->fn_ctx->closed_vars, i); + set_binding(body_scope, entry->name, new(binding_t, .type=entry->b->type, .code=CORD_cat("userdata->", entry->name))); + Table$str_set(fn_ctx.closed_vars, entry->name, entry->b); + } + } + CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "("); for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); @@ -1769,7 +1784,7 @@ CORD compile(env_t *env, ast_t *ast) struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i); if (entry->b->type->tag == ModuleType) continue; - userdata = CORD_all(userdata, ", ", entry->b->code); + userdata = CORD_all(userdata, ", ", get_binding(env, entry->name)->code); } userdata = CORD_all(userdata, ")"); code = CORD_all(code, name, "$userdata_t *userdata)"); diff --git a/environment.c b/environment.c index 65130c37..1fe698d2 100644 --- a/environment.c +++ b/environment.c @@ -372,11 +372,16 @@ env_t *namespace_env(env_t *env, const char *namespace_name) binding_t *get_binding(env_t *env, const char *name) { binding_t *b = Table$str_get(*env->locals, name); - if (!b && env->fn_ctx && env->fn_ctx->closure_scope) { - b = Table$str_get(*env->fn_ctx->closure_scope, name); - if (b) { - Table$str_set(env->fn_ctx->closed_vars, name, b); - return new(binding_t, .type=b->type, .code=CORD_all("userdata->", name)); + if (!b) { + for (fn_ctx_t *fn_ctx = env->fn_ctx; fn_ctx; fn_ctx = fn_ctx->parent) { + if (!fn_ctx->closure_scope) continue; + b = Table$str_get(*fn_ctx->closure_scope, name); + if (b) { + Table$str_set(env->fn_ctx->closed_vars, name, b); + binding_t *b2 = new(binding_t, .type=b->type, .code=CORD_all("userdata->", name)); + Table$str_set(env->locals, name, b2); + return b2; + } } } return b; diff --git a/environment.h b/environment.h index be87b857..80723b0f 100644 --- a/environment.h +++ b/environment.h @@ -14,7 +14,8 @@ typedef struct { CORD typeinfos; } compilation_unit_t; -typedef struct { +typedef struct fn_ctx_s { + struct fn_ctx_s *parent; type_t *return_type; table_t *closure_scope; table_t *closed_vars; diff --git a/test/lambdas.tm b/test/lambdas.tm index 15a8f23b..cb5bf3d4 100644 --- a/test/lambdas.tm +++ b/test/lambdas.tm @@ -30,3 +30,13 @@ func main(): >> abs100 := mul_func(100, Int.abs) >> abs100(-5) = 500 + + // Test nested lambdas: + outer := "Hello" + fn := func(): + return func(): + return func(): + defer: |{outer} + return outer + >> fn()()() + = "Hello" |
