aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compile.c21
-rw-r--r--environment.c15
-rw-r--r--environment.h3
-rw-r--r--test/lambdas.tm10
4 files changed, 40 insertions, 9 deletions
diff --git a/compile.c b/compile.c
index a49ab29c..938c1c0c 100644
--- a/compile.c
+++ b/compile.c
@@ -62,13 +62,15 @@ static table_t *get_closed_vars(env_t *env, ast_t *lambda_ast)
{
auto lambda = Match(lambda_ast, Lambda);
env_t *body_scope = fresh_scope(env);
+ body_scope->code = new(compilation_unit_t); // Don't put any code in the headers or anything
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("$", arg->name)));
}
fn_ctx_t fn_ctx = (fn_ctx_t){
- .closure_scope=body_scope->locals->fallback,
+ .parent=env->fn_ctx,
+ .closure_scope=env->locals,
.closed_vars=new(table_t),
};
body_scope->fn_ctx = &fn_ctx;
@@ -596,6 +598,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
fn_ctx_t fn_ctx = (fn_ctx_t){
+ .parent=NULL,
.return_type=ret_t,
.closure_scope=NULL,
.closed_vars=NULL,
@@ -717,6 +720,9 @@ CORD compile_statement(env_t *env, ast_t *ast)
code = CORD_all(
code, compile_declaration(entry->b->type, defer_name), " = ", entry->b->code, ";\n");
set_binding(defer_env, entry->name, new(binding_t, .type=entry->b->type, .code=defer_name));
+
+ if (env->fn_ctx->closed_vars)
+ Table$str_set(env->fn_ctx->closed_vars, entry->name, entry->b);
}
env->deferred = new(deferral_t, .defer_env=defer_env, .block=body, .next=env->deferred);
return code;
@@ -1738,7 +1744,8 @@ CORD compile(env_t *env, ast_t *ast)
}
fn_ctx_t fn_ctx = (fn_ctx_t){
- .closure_scope=body_scope->locals->fallback,
+ .parent=env->fn_ctx,
+ .closure_scope=env->locals,
.closed_vars=new(table_t),
};
body_scope->fn_ctx = &fn_ctx;
@@ -1749,6 +1756,14 @@ CORD compile(env_t *env, ast_t *ast)
ret_t = Match(ret_t, ReturnType)->ret;
fn_ctx.return_type = ret_t;
+ if (env->fn_ctx->closed_vars) {
+ for (int64_t i = 1; i <= Table$length(*env->fn_ctx->closed_vars); i++) {
+ struct { const char *name; binding_t *b; } *entry = Table$entry(*env->fn_ctx->closed_vars, i);
+ set_binding(body_scope, entry->name, new(binding_t, .type=entry->b->type, .code=CORD_cat("userdata->", entry->name)));
+ Table$str_set(fn_ctx.closed_vars, entry->name, entry->b);
+ }
+ }
+
CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "(");
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
@@ -1769,7 +1784,7 @@ CORD compile(env_t *env, ast_t *ast)
struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i);
if (entry->b->type->tag == ModuleType)
continue;
- userdata = CORD_all(userdata, ", ", entry->b->code);
+ userdata = CORD_all(userdata, ", ", get_binding(env, entry->name)->code);
}
userdata = CORD_all(userdata, ")");
code = CORD_all(code, name, "$userdata_t *userdata)");
diff --git a/environment.c b/environment.c
index 65130c37..1fe698d2 100644
--- a/environment.c
+++ b/environment.c
@@ -372,11 +372,16 @@ env_t *namespace_env(env_t *env, const char *namespace_name)
binding_t *get_binding(env_t *env, const char *name)
{
binding_t *b = Table$str_get(*env->locals, name);
- if (!b && env->fn_ctx && env->fn_ctx->closure_scope) {
- b = Table$str_get(*env->fn_ctx->closure_scope, name);
- if (b) {
- Table$str_set(env->fn_ctx->closed_vars, name, b);
- return new(binding_t, .type=b->type, .code=CORD_all("userdata->", name));
+ if (!b) {
+ for (fn_ctx_t *fn_ctx = env->fn_ctx; fn_ctx; fn_ctx = fn_ctx->parent) {
+ if (!fn_ctx->closure_scope) continue;
+ b = Table$str_get(*fn_ctx->closure_scope, name);
+ if (b) {
+ Table$str_set(env->fn_ctx->closed_vars, name, b);
+ binding_t *b2 = new(binding_t, .type=b->type, .code=CORD_all("userdata->", name));
+ Table$str_set(env->locals, name, b2);
+ return b2;
+ }
}
}
return b;
diff --git a/environment.h b/environment.h
index be87b857..80723b0f 100644
--- a/environment.h
+++ b/environment.h
@@ -14,7 +14,8 @@ typedef struct {
CORD typeinfos;
} compilation_unit_t;
-typedef struct {
+typedef struct fn_ctx_s {
+ struct fn_ctx_s *parent;
type_t *return_type;
table_t *closure_scope;
table_t *closed_vars;
diff --git a/test/lambdas.tm b/test/lambdas.tm
index 15a8f23b..cb5bf3d4 100644
--- a/test/lambdas.tm
+++ b/test/lambdas.tm
@@ -30,3 +30,13 @@ func main():
>> abs100 := mul_func(100, Int.abs)
>> abs100(-5)
= 500
+
+ // Test nested lambdas:
+ outer := "Hello"
+ fn := func():
+ return func():
+ return func():
+ defer: |{outer}
+ return outer
+ >> fn()()()
+ = "Hello"