aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2025-02-09 13:57:54 -0500
committerBruce Hill <bruce@bruce-hill.com>2025-02-09 13:57:54 -0500
commitbe87d8169d98eb358891d155ce84cff311ec27c2 (patch)
treea609244e8c1eae1f2a8679da4b8bdc83fdc8cf20
parent6310f0565641dd64ae435c6a352c76e54fb9ddba (diff)
Convert the logic for finding closed variables to a more pure functional
style with fewer side effects
-rw-r--r--compile.c377
-rw-r--r--environment.c19
-rw-r--r--environment.h9
-rw-r--r--repl.c6
-rw-r--r--typecheck.c6
5 files changed, 310 insertions, 107 deletions
diff --git a/compile.c b/compile.c
index 9c1a16f0..ecd03b88 100644
--- a/compile.c
+++ b/compile.c
@@ -179,40 +179,291 @@ CORD compile_maybe_incref(env_t *env, ast_t *ast, type_t *t)
return compile_to_type(env, ast, t);
}
+static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t *env, ast_t *ast)
+{
+ if (ast == NULL)
+ return;
+
+ switch (ast->tag) {
+ case Var: {
+ binding_t *b = get_binding(enclosing_scope, Match(ast, Var)->name);
+ if (b) {
+ binding_t *shadow = get_binding(env, Match(ast, Var)->name);
+ if (!shadow || shadow == b)
+ Table$str_set(closed_vars, Match(ast, Var)->name, b);
+ }
+ break;
+ }
+ case TextJoin: {
+ for (ast_list_t *child = Match(ast, TextJoin)->children; child; child = child->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, child->ast);
+ break;
+ }
+ case PrintStatement: {
+ for (ast_list_t *child = Match(ast, PrintStatement)->to_print; child; child = child->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, child->ast);
+ break;
+ }
+ case Declare: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Declare)->value);
+ bind_statement(env, ast);
+ break;
+ }
+ case Assign: {
+ for (ast_list_t *target = Match(ast, Assign)->targets; target; target = target->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, target->ast);
+ for (ast_list_t *value = Match(ast, Assign)->values; value; value = value->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, value->ast);
+ break;
+ }
+ case BinaryOp: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->rhs);
+ break;
+ }
+ case UpdateAssign: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->rhs);
+ break;
+ }
+ case Not: case Negative: case HeapAllocate: case StackReference: case Mutexed: {
+ // UNSAFE:
+ ast_t *value = ast->__data.Not.value;
+ // END UNSAFE
+ add_closed_vars(closed_vars, enclosing_scope, env, value);
+ break;
+ }
+ case Holding: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Holding)->mutexed);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Holding)->body);
+ break;
+ }
+ case Min: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->rhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->key);
+ break;
+ }
+ case Max: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->rhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->key);
+ break;
+ }
+ case Array: {
+ for (ast_list_t *item = Match(ast, Array)->items; item; item = item->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, item->ast);
+ break;
+ }
+ case Set: {
+ for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, item->ast);
+ break;
+ }
+ case Table: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->default_value);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->fallback);
+ for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, entry->ast);
+ break;
+ }
+ case TableEntry: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->key);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->value);
+ break;
+ }
+ case Comprehension: {
+ auto comp = Match(ast, Comprehension);
+ if (comp->expr->tag == Comprehension) { // Nested comprehension
+ ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr;
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
+ return add_closed_vars(closed_vars, enclosing_scope, env, loop);
+ }
+
+ // Array/Set/Table comprehension:
+ ast_t *body = comp->expr;
+ if (comp->filter)
+ body = WrapAST(comp->expr, If, .condition=comp->filter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
+ add_closed_vars(closed_vars, enclosing_scope, env, loop);
+ break;
+ }
+ case Lambda: {
+ auto lambda = Match(ast, Lambda);
+ env_t *lambda_scope = fresh_scope(env);
+ for (arg_ast_t *arg = lambda->args; arg; arg = arg->next)
+ set_binding(lambda_scope, arg->name, new(binding_t, .type=get_arg_ast_type(env, arg)));
+ add_closed_vars(closed_vars, enclosing_scope, lambda_scope, lambda->body);
+ break;
+ }
+ case FunctionCall: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FunctionCall)->fn);
+ for (arg_ast_t *arg = Match(ast, FunctionCall)->args; arg; arg = arg->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, arg->value);
+ break;
+ }
+ case MethodCall: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, MethodCall)->self);
+ for (arg_ast_t *arg = Match(ast, MethodCall)->args; arg; arg = arg->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, arg->value);
+ break;
+ }
+ case Block: {
+ env = fresh_scope(env);
+ for (ast_list_t *statement = Match(ast, Block)->statements; statement; statement = statement->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, statement->ast);
+ break;
+ }
+ case For: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->iter);
+ env_t *body_scope = for_scope(env, ast);
+ add_closed_vars(closed_vars, enclosing_scope, body_scope, Match(ast, For)->body);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->empty);
+ break;
+ }
+ case While: {
+ auto while_ = Match(ast, While);
+ add_closed_vars(closed_vars, enclosing_scope, env, while_->condition);
+ env_t *scope = fresh_scope(env);
+ add_closed_vars(closed_vars, enclosing_scope, scope, while_->body);
+ break;
+ }
+ case If: {
+ auto if_ = Match(ast, If);
+ ast_t *condition = if_->condition;
+ add_closed_vars(closed_vars, enclosing_scope, env, condition);
+ if (condition->tag == Declare) {
+ env_t *truthy_scope = fresh_scope(env);
+ bind_statement(truthy_scope, condition);
+ ast_t *var = Match(condition, Declare)->var;
+ type_t *cond_t = get_type(truthy_scope, var);
+ if (cond_t->tag == OptionalType) {
+ set_binding(truthy_scope, Match(var, Var)->name,
+ new(binding_t, .type=Match(cond_t, OptionalType)->type));
+ }
+ add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body);
+ add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body);
+ } else {
+ env_t *truthy_scope = env;
+ type_t *cond_t = get_type(env, condition);
+ if (condition->tag == Var && cond_t->tag == OptionalType) {
+ truthy_scope = fresh_scope(env);
+ set_binding(truthy_scope, Match(condition, Var)->name,
+ new(binding_t, .type=Match(cond_t, OptionalType)->type));
+ }
+ add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body);
+ add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body);
+ }
+ break;
+ }
+ case When: {
+ auto when = Match(ast, When);
+ add_closed_vars(closed_vars, enclosing_scope, env, when->subject);
+ type_t *subject_t = get_type(env, when->subject);
+
+ auto enum_t = Match(subject_t, EnumType);
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ const char *clause_tag_name = Match(clause->tag_name, Var)->name;
+ type_t *tag_type = NULL;
+ for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
+ if (streq(tag->name, clause_tag_name)) {
+ tag_type = tag->type;
+ break;
+ }
+ }
+ assert(tag_type);
+ env_t *scope = env;
-static Table_t *get_closed_vars(env_t *env, ast_t *lambda_ast)
+ auto tag_struct = Match(tag_type, StructType);
+ if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
+ scope = fresh_scope(scope);
+ set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
+ } else if (clause->args) {
+ scope = fresh_scope(scope);
+ ast_list_t *var = clause->args;
+ arg_t *field = tag_struct->fields;
+ while (var || field) {
+ if (!var)
+ code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name);
+ if (!field)
+ code_err(var->ast, "This is one more field than %T has", subject_t);
+ set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
+ var = var->next;
+ field = field->next;
+ }
+ }
+ add_closed_vars(closed_vars, enclosing_scope, scope, clause->body);
+ }
+ if (when->else_body)
+ add_closed_vars(closed_vars, enclosing_scope, env, when->else_body);
+ break;
+ }
+ case Repeat: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Repeat)->body);
+ break;
+ }
+ case Reduction: {
+ auto reduction = Match(ast, Reduction);
+ static int64_t next_id = 1;
+ ast_t *item = FakeAST(Var, heap_strf("$it%ld", next_id++));
+ ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=FakeAST(Pass));
+ env_t *scope = for_scope(env, loop);
+ add_closed_vars(closed_vars, enclosing_scope, scope, reduction->key ? reduction->key : item);
+ break;
+ }
+ case Defer: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Defer)->body);
+ break;
+ }
+ case Return: {
+ ast_t *ret = Match(ast, Return)->value;
+ if (ret) add_closed_vars(closed_vars, enclosing_scope, env, ret);
+ break;
+ }
+ case Index: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->indexed);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->index);
+ break;
+ }
+ case FieldAccess: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FieldAccess)->fielded);
+ break;
+ }
+ case Optional: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Optional)->value);
+ break;
+ }
+ case NonOptional: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, NonOptional)->value);
+ break;
+ }
+ case DocTest: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, DocTest)->expr);
+ break;
+ }
+ case Deserialize: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value);
+ break;
+ }
+ case Use: case FunctionDef: case StructDef: case EnumDef: case LangDef: {
+ errx(1, "Definitions should not be reachable in a closure.");
+ }
+ default:
+ break;
+ }
+}
+
+static Table_t get_closed_vars(env_t *env, arg_ast_t *args, ast_t *block)
{
- auto lambda = Match(lambda_ast, Lambda);
env_t *body_scope = fresh_scope(env);
- body_scope->code = new(compilation_unit_t); // Don't put any code in the headers or anything
- for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
+ for (arg_ast_t *arg = args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("_$", arg->name)));
}
- fn_ctx_t fn_ctx = (fn_ctx_t){
- .parent=env->fn_ctx,
- .closure_scope=env->locals,
- .closed_vars=new(Table_t),
- };
- body_scope->fn_ctx = &fn_ctx;
- body_scope->locals->fallback = env->globals;
- type_t *ret_t = get_type(body_scope, lambda->body);
- if (ret_t->tag == ReturnType)
- ret_t = Match(ret_t, ReturnType)->ret;
- fn_ctx.return_type = ret_t;
-
- // Find which variables are captured in the closure:
- env_t *tmp_scope = fresh_scope(body_scope);
- for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) {
- type_t *stmt_type = get_type(tmp_scope, stmt->ast);
- if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || get_type(tmp_scope, stmt->ast)->tag == ReturnType))
- (void)compile_statement(tmp_scope, stmt->ast);
- else
- (void)compile(tmp_scope, stmt->ast);
- bind_statement(tmp_scope, stmt->ast);
- }
- return fn_ctx.closed_vars;
+ Table_t closed_vars = {};
+ add_closed_vars(&closed_vars, env, body_scope, block);
+ return closed_vars;
}
CORD compile_declaration(type_t *t, CORD name)
@@ -953,13 +1204,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("_$", arg->name)));
}
- fn_ctx_t fn_ctx = (fn_ctx_t){
- .parent=NULL,
- .return_type=ret_t,
- .closure_scope=NULL,
- .closed_vars=NULL,
- };
- body_scope->fn_ctx = &fn_ctx;
+ body_scope->fn_ret = ret_t;
type_t *body_type = get_type(body_scope, fndef->body);
if (ret_t->tag != VoidType && ret_t->tag != AbortType && body_type->tag != AbortType && body_type->tag != ReturnType)
@@ -1132,13 +1377,13 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
case Pass: return ";";
case Defer: {
ast_t *body = Match(ast, Defer)->body;
- Table_t *closed_vars = get_closed_vars(env, FakeAST(Lambda, .args=NULL, .body=body));
+ Table_t closed_vars = get_closed_vars(env, NULL, body);
static int defer_id = 0;
env_t *defer_env = fresh_scope(env);
CORD code = CORD_EMPTY;
- for (int64_t i = 1; i <= Table$length(*closed_vars); i++) {
- struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i);
+ for (int64_t i = 1; i <= Table$length(closed_vars); i++) {
+ struct { const char *name; binding_t *b; } *entry = Table$entry(closed_vars, i);
if (entry->b->type->tag == ModuleType)
continue;
if (CORD_ncmp(entry->b->code, 0, "userdata->", 0, strlen("userdata->")) == 0) {
@@ -1149,9 +1394,6 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
code, compile_declaration(entry->b->type, defer_name), " = ", entry->b->code, ";\n");
set_binding(defer_env, entry->name, new(binding_t, .type=entry->b->type, .code=defer_name));
}
-
- if (env->fn_ctx->closed_vars)
- Table$str_set(env->fn_ctx->closed_vars, entry->name, entry->b);
}
env->deferred = new(deferral_t, .defer_env=defer_env, .block=body, .next=env->deferred);
return code;
@@ -1175,9 +1417,8 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
return CORD_cat(code, ", yes);");
}
case Return: {
- if (!env->fn_ctx) code_err(ast, "This return statement is not inside any function");
+ if (!env->fn_ret) code_err(ast, "This return statement is not inside any function");
auto ret = Match(ast, Return)->value;
- assert(env->fn_ctx->return_type);
CORD code = CORD_EMPTY;
for (deferral_t *deferred = env->deferred; deferred; deferred = deferred->next) {
@@ -1185,20 +1426,20 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
}
if (ret) {
- if (env->fn_ctx->return_type->tag == VoidType || env->fn_ctx->return_type->tag == AbortType)
+ if (env->fn_ret->tag == VoidType || env->fn_ret->tag == AbortType)
code_err(ast, "This function is not supposed to return any values, according to its type signature");
- env = with_enum_scope(env, env->fn_ctx->return_type);
- CORD value = compile_to_type(env, ret, env->fn_ctx->return_type);
+ env = with_enum_scope(env, env->fn_ret);
+ CORD value = compile_to_type(env, ret, env->fn_ret);
if (env->deferred) {
- code = CORD_all(compile_declaration(env->fn_ctx->return_type, "ret"), " = ", value, ";\n", code);
+ code = CORD_all(compile_declaration(env->fn_ret, "ret"), " = ", value, ";\n", code);
value = "ret";
}
return CORD_all(code, "return ", value, ";");
} else {
- if (env->fn_ctx->return_type->tag != VoidType)
- code_err(ast, "This function expects you to return a %T value", env->fn_ctx->return_type);
+ if (env->fn_ret->tag != VoidType)
+ code_err(ast, "This function expects you to return a %T value", env->fn_ret);
return CORD_all(code, "return;");
}
}
@@ -2826,19 +3067,12 @@ CORD compile(env_t *env, ast_t *ast)
name, CORD_quoted(type_to_cord(get_type(env, ast))), file_base_name(ast->file->filename), get_line_number(ast->file, ast->start)));
env_t *body_scope = fresh_scope(env);
+ body_scope->deferred = NULL;
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
- set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("_$", arg->name)));
+ set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_all("_$", arg->name)));
}
- fn_ctx_t fn_ctx = (fn_ctx_t){
- .parent=env->fn_ctx,
- .closure_scope=env->locals,
- .closed_vars=new(Table_t),
- };
- body_scope->fn_ctx = &fn_ctx;
- body_scope->locals->fallback = env->globals;
- body_scope->deferred = NULL;
type_t *ret_t = get_type(body_scope, lambda->body);
if (ret_t->tag == ReturnType)
ret_t = Match(ret_t, ReturnType)->ret;
@@ -2852,21 +3086,13 @@ CORD compile(env_t *env, ast_t *ast)
declared, ret_t);
}
- fn_ctx.return_type = ret_t;
-
- if (env->fn_ctx->closed_vars) {
- for (int64_t i = 1; i <= Table$length(*env->fn_ctx->closed_vars); i++) {
- struct { const char *name; binding_t *b; } *entry = Table$entry(*env->fn_ctx->closed_vars, i);
- set_binding(body_scope, entry->name, new(binding_t, .type=entry->b->type, .code=CORD_cat("userdata->", entry->name)));
- Table$str_set(fn_ctx.closed_vars, entry->name, entry->b);
- }
- }
+ body_scope->fn_ret = ret_t;
- Table_t *closed_vars = get_closed_vars(env, ast);
- if (Table$length(*closed_vars) > 0) { // Create a typedef for the lambda's closure userdata
+ Table_t closed_vars = get_closed_vars(env, lambda->args, ast);
+ if (Table$length(closed_vars) > 0) { // Create a typedef for the lambda's closure userdata
CORD def = "typedef struct {";
- for (int64_t i = 1; i <= Table$length(*closed_vars); i++) {
- struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i);
+ for (int64_t i = 1; i <= Table$length(closed_vars); i++) {
+ struct { const char *name; binding_t *b; } *entry = Table$entry(closed_vars, i);
if (has_stack_memory(entry->b->type))
code_err(ast, "This function is holding onto a reference to %T stack memory in the variable `%s`, but the function may outlive the stack memory",
entry->b->type, entry->name);
@@ -2886,16 +3112,19 @@ CORD compile(env_t *env, ast_t *ast)
}
CORD userdata;
- if (Table$length(*closed_vars) == 0) {
+ if (Table$length(closed_vars) == 0) {
code = CORD_cat(code, "void *)");
userdata = "NULL";
} else {
userdata = CORD_all("new(", name, "$userdata_t");
- for (int64_t i = 1; i <= Table$length(*closed_vars); i++) {
- struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i);
+ for (int64_t i = 1; i <= Table$length(closed_vars); i++) {
+ struct { const char *name; binding_t *b; } *entry = Table$entry(closed_vars, i);
if (entry->b->type->tag == ModuleType)
continue;
- CORD binding_code = get_binding(env, entry->name)->code;
+ binding_t *b = get_binding(env, entry->name);
+ if (!b) printf("Couldn't find: %s\n", entry->name);
+ assert(b);
+ CORD binding_code = b->code;
if (entry->b->type->tag == ArrayType)
userdata = CORD_all(userdata, ", ARRAY_COPY(", binding_code, ")");
else if (entry->b->type->tag == TableType || entry->b->type->tag == SetType)
diff --git a/environment.c b/environment.c
index 86a7ff16..1b33fd7a 100644
--- a/environment.c
+++ b/environment.c
@@ -646,22 +646,9 @@ env_t *namespace_env(env_t *env, const char *namespace_name)
return ns_env;
}
-binding_t *get_binding(env_t *env, const char *name)
+PUREFUNC binding_t *get_binding(env_t *env, const char *name)
{
- binding_t *b = Table$str_get(*env->locals, name);
- if (!b) {
- for (fn_ctx_t *fn_ctx = env->fn_ctx; fn_ctx; fn_ctx = fn_ctx->parent) {
- if (!fn_ctx->closure_scope) continue;
- b = Table$str_get(*fn_ctx->closure_scope, name);
- if (b) {
- Table$str_set(env->fn_ctx->closed_vars, name, b);
- binding_t *b2 = new(binding_t, .type=b->type, .code=CORD_all("userdata->", name));
- Table$str_set(env->locals, name, b2);
- return b2;
- }
- }
- }
- return b;
+ return Table$str_get(*env->locals, name);
}
binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name)
@@ -702,7 +689,7 @@ binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name)
return NULL;
}
-binding_t *get_lang_escape_function(env_t *env, const char *lang_name, type_t *type_to_escape)
+PUREFUNC binding_t *get_lang_escape_function(env_t *env, const char *lang_name, type_t *type_to_escape)
{
if (!lang_name) lang_name = "Text";
binding_t *typeinfo = get_binding(env, lang_name);
diff --git a/environment.h b/environment.h
index 616bfc58..465ecfaf 100644
--- a/environment.h
+++ b/environment.h
@@ -16,13 +16,6 @@ typedef struct {
CORD function_naming;
} compilation_unit_t;
-typedef struct fn_ctx_s {
- struct fn_ctx_s *parent;
- type_t *return_type;
- Table_t *closure_scope;
- Table_t *closed_vars;
-} fn_ctx_t;
-
typedef struct deferral_s {
struct deferral_s *next;
struct env_s *defer_env;
@@ -49,7 +42,7 @@ typedef struct env_s {
// - Raw 'use' string for module imports
Table_t *imports;
compilation_unit_t *code;
- fn_ctx_t *fn_ctx;
+ type_t *fn_ret;
loop_ctx_t *loop_ctx;
deferral_t *deferred;
CORD libname; // Currently compiling library name (if any)
diff --git a/repl.c b/repl.c
index 29ed407f..3b2e1af3 100644
--- a/repl.c
+++ b/repl.c
@@ -36,12 +36,6 @@ typedef struct {
static PUREFUNC repl_binding_t *get_repl_binding(env_t *env, const char *name)
{
repl_binding_t *b = Table$str_get(*env->locals, name);
- if (b) return b;
- for (fn_ctx_t *fn_ctx = env->fn_ctx; fn_ctx; fn_ctx = fn_ctx->parent) {
- if (!fn_ctx->closure_scope) continue;
- b = Table$str_get(*fn_ctx->closure_scope, name);
- if (b) return b;
- }
return b;
}
diff --git a/typecheck.c b/typecheck.c
index 877d56ab..a2530bcf 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -908,9 +908,9 @@ type_t *get_type(env_t *env, ast_t *ast)
case Return: {
ast_t *val = Match(ast, Return)->value;
// Support unqualified enum return values:
- if (env->fn_ctx && env->fn_ctx->return_type && env->fn_ctx->return_type->tag == EnumType) {
+ if (env->fn_ret && env->fn_ret->tag == EnumType) {
env = fresh_scope(env);
- auto enum_ = Match(env->fn_ctx->return_type, EnumType);
+ auto enum_ = Match(env->fn_ret, EnumType);
env_t *ns_env = enum_->env;
for (tag_t *tag = enum_->tags; tag; tag = tag->next) {
if (get_binding(env, tag->name))
@@ -1163,7 +1163,7 @@ type_t *get_type(env_t *env, ast_t *ast)
arg_t *args = NULL;
env_t *scope = fresh_scope(env); // For now, just use closed variables in scope normally
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
- type_t *t = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->value);
+ type_t *t = get_arg_ast_type(env, arg);
args = new(arg_t, .name=arg->name, .type=t, .next=args);
set_binding(scope, arg->name, new(binding_t, .type=t));
}