about summary refs log tree commit diff
path: root/compile.c
diff options
context:
space:
mode:
author Bruce Hill <bruce@bruce-hill.com> 2025-02-09 13:57:54 -0500
committer Bruce Hill <bruce@bruce-hill.com> 2025-02-09 13:57:54 -0500
commit be87d8169d98eb358891d155ce84cff311ec27c2 (patch)
tree a609244e8c1eae1f2a8679da4b8bdc83fdc8cf20 /compile.c
parent 6310f0565641dd64ae435c6a352c76e54fb9ddba (diff)
Convert the logic for finding closed variables to a more pure functional
style with fewer side effects
Diffstat (limited to 'compile.c')
-rw-r--r-- compile.c 377
1 files changed, 303 insertions, 74 deletions
diff --git a/compile.c b/compile.c
index 9c1a16f0..ecd03b88 100644
--- a/compile.c
+++ b/compile.c
@@ -179,40 +179,291 @@ CORD compile_maybe_incref(env_t *env, ast_t *ast, type_t *t)
return compile_to_type(env, ast, t);
}
+static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t *env, ast_t *ast)
+{
+ if (ast == NULL)
+ return;
+
+ switch (ast->tag) {
+ case Var: {
+ binding_t *b = get_binding(enclosing_scope, Match(ast, Var)->name);
+ if (b) {
+ binding_t *shadow = get_binding(env, Match(ast, Var)->name);
+ if (!shadow || shadow == b)
+ Table$str_set(closed_vars, Match(ast, Var)->name, b);
+ }
+ break;
+ }
+ case TextJoin: {
+ for (ast_list_t *child = Match(ast, TextJoin)->children; child; child = child->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, child->ast);
+ break;
+ }
+ case PrintStatement: {
+ for (ast_list_t *child = Match(ast, PrintStatement)->to_print; child; child = child->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, child->ast);
+ break;
+ }
+ case Declare: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Declare)->value);
+ bind_statement(env, ast);
+ break;
+ }
+ case Assign: {
+ for (ast_list_t *target = Match(ast, Assign)->targets; target; target = target->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, target->ast);
+ for (ast_list_t *value = Match(ast, Assign)->values; value; value = value->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, value->ast);
+ break;
+ }
+ case BinaryOp: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->rhs);
+ break;
+ }
+ case UpdateAssign: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->rhs);
+ break;
+ }
+ case Not: case Negative: case HeapAllocate: case StackReference: case Mutexed: {
+ // UNSAFE:
+ ast_t *value = ast->__data.Not.value;
+ // END UNSAFE
+ add_closed_vars(closed_vars, enclosing_scope, env, value);
+ break;
+ }
+ case Holding: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Holding)->mutexed);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Holding)->body);
+ break;
+ }
+ case Min: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->rhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->key);
+ break;
+ }
+ case Max: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->rhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->key);
+ break;
+ }
+ case Array: {
+ for (ast_list_t *item = Match(ast, Array)->items; item; item = item->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, item->ast);
+ break;
+ }
+ case Set: {
+ for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, item->ast);
+ break;
+ }
+ case Table: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->default_value);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->fallback);
+ for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, entry->ast);
+ break;
+ }
+ case TableEntry: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->key);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->value);
+ break;
+ }
+ case Comprehension: {
+ auto comp = Match(ast, Comprehension);
+ if (comp->expr->tag == Comprehension) { // Nested comprehension
+ ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr;
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
+ return add_closed_vars(closed_vars, enclosing_scope, env, loop);
+ }
+
+ // Array/Set/Table comprehension:
+ ast_t *body = comp->expr;
+ if (comp->filter)
+ body = WrapAST(comp->expr, If, .condition=comp->filter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
+ add_closed_vars(closed_vars, enclosing_scope, env, loop);
+ break;
+ }
+ case Lambda: {
+ auto lambda = Match(ast, Lambda);
+ env_t *lambda_scope = fresh_scope(env);
+ for (arg_ast_t *arg = lambda->args; arg; arg = arg->next)
+ set_binding(lambda_scope, arg->name, new(binding_t, .type=get_arg_ast_type(env, arg)));
+ add_closed_vars(closed_vars, enclosing_scope, lambda_scope, lambda->body);
+ break;
+ }
+ case FunctionCall: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FunctionCall)->fn);
+ for (arg_ast_t *arg = Match(ast, FunctionCall)->args; arg; arg = arg->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, arg->value);
+ break;
+ }
+ case MethodCall: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, MethodCall)->self);
+ for (arg_ast_t *arg = Match(ast, MethodCall)->args; arg; arg = arg->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, arg->value);
+ break;
+ }
+ case Block: {
+ env = fresh_scope(env);
+ for (ast_list_t *statement = Match(ast, Block)->statements; statement; statement = statement->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, statement->ast);
+ break;
+ }
+ case For: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->iter);
+ env_t *body_scope = for_scope(env, ast);
+ add_closed_vars(closed_vars, enclosing_scope, body_scope, Match(ast, For)->body);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->empty);
+ break;
+ }
+ case While: {
+ auto while_ = Match(ast, While);
+ add_closed_vars(closed_vars, enclosing_scope, env, while_->condition);
+ env_t *scope = fresh_scope(env);
+ add_closed_vars(closed_vars, enclosing_scope, scope, while_->body);
+ break;
+ }
+ case If: {
+ auto if_ = Match(ast, If);
+ ast_t *condition = if_->condition;
+ add_closed_vars(closed_vars, enclosing_scope, env, condition);
+ if (condition->tag == Declare) {
+ env_t *truthy_scope = fresh_scope(env);
+ bind_statement(truthy_scope, condition);
+ ast_t *var = Match(condition, Declare)->var;
+ type_t *cond_t = get_type(truthy_scope, var);
+ if (cond_t->tag == OptionalType) {
+ set_binding(truthy_scope, Match(var, Var)->name,
+ new(binding_t, .type=Match(cond_t, OptionalType)->type));
+ }
+ add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body);
+ add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body);
+ } else {
+ env_t *truthy_scope = env;
+ type_t *cond_t = get_type(env, condition);
+ if (condition->tag == Var && cond_t->tag == OptionalType) {
+ truthy_scope = fresh_scope(env);
+ set_binding(truthy_scope, Match(condition, Var)->name,
+ new(binding_t, .type=Match(cond_t, OptionalType)->type));
+ }
+ add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body);
+ add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body);
+ }
+ break;
+ }
+ case When: {
+ auto when = Match(ast, When);
+ add_closed_vars(closed_vars, enclosing_scope, env, when->subject);
+ type_t *subject_t = get_type(env, when->subject);
+
+ auto enum_t = Match(subject_t, EnumType);
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ const char *clause_tag_name = Match(clause->tag_name, Var)->name;
+ type_t *tag_type = NULL;
+ for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
+ if (streq(tag->name, clause_tag_name)) {
+ tag_type = tag->type;
+ break;
+ }
+ }
+ assert(tag_type);
+ env_t *scope = env;
-static Table_t *get_closed_vars(env_t *env, ast_t *lambda_ast)
+ auto tag_struct = Match(tag_type, StructType);
+ if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
+ scope = fresh_scope(scope);
+ set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
+ } else if (clause->args) {
+ scope = fresh_scope(scope);
+ ast_list_t *var = clause->args;
+ arg_t *field = tag_struct->fields;
+ while (var || field) {
+ if (!var)
+ code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name);
+ if (!field)
+ code_err(var->ast, "This is one more field than %T has", subject_t);
+ set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
+ var = var->next;
+ field = field->next;
+ }
+ }
+ add_closed_vars(closed_vars, enclosing_scope, scope, clause->body);
+ }
+ if (when->else_body)
+ add_closed_vars(closed_vars, enclosing_scope, env, when->else_body);
+ break;
+ }
+ case Repeat: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Repeat)->body);
+ break;
+ }
+ case Reduction: {
+ auto reduction = Match(ast, Reduction);
+ static int64_t next_id = 1;
+ ast_t *item = FakeAST(Var, heap_strf("$it%ld", next_id++));
+ ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=FakeAST(Pass));
+ env_t *scope = for_scope(env, loop);
+ add_closed_vars(closed_vars, enclosing_scope, scope, reduction->key ? reduction->key : item);
+ break;
+ }
+ case Defer: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Defer)->body);
+ break;
+ }
+ case Return: {
+ ast_t *ret = Match(ast, Return)->value;
+ if (ret) add_closed_vars(closed_vars, enclosing_scope, env, ret);
+ break;
+ }
+ case Index: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->indexed);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->index);
+ break;
+ }
+ case FieldAccess: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FieldAccess)->fielded);
+ break;
+ }
+ case Optional: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Optional)->value);
+ break;
+ }
+ case NonOptional: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, NonOptional)->value);
+ break;
+ }
+ case DocTest: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, DocTest)->expr);
+ break;
+ }
+ case Deserialize: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value);
+ break;
+ }
+ case Use: case FunctionDef: case StructDef: case EnumDef: case LangDef: {
+ errx(1, "Definitions should not be reachable in a closure.");
+ }
+ default:
+ break;
+ }
+}
+
+static Table_t get_closed_vars(env_t *env, arg_ast_t *args, ast_t *block)
{
- auto lambda = Match(lambda_ast, Lambda);
env_t *body_scope = fresh_scope(env);
- body_scope->code = new(compilation_unit_t); // Don't put any code in the headers or anything
- for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
+ for (arg_ast_t *arg = args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("_$", arg->name)));
}
- fn_ctx_t fn_ctx = (fn_ctx_t){
- .parent=env->fn_ctx,
- .closure_scope=env->locals,
- .closed_vars=new(Table_t),
- };
- body_scope->fn_ctx = &fn_ctx;
- body_scope->locals->fallback = env->globals;
- type_t *ret_t = get_type(body_scope, lambda->body);
- if (ret_t->tag == ReturnType)
- ret_t = Match(ret_t, ReturnType)->ret;
- fn_ctx.return_type = ret_t;
-
- // Find which variables are captured in the closure:
- env_t *tmp_scope = fresh_scope(body_scope);
- for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) {
- type_t *stmt_type = get_type(tmp_scope, stmt->ast);
- if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || get_type(tmp_scope, stmt->ast)->tag == ReturnType))
- (void)compile_statement(tmp_scope, stmt->ast);
- else
- (void)compile(tmp_scope, stmt->ast);
- bind_statement(tmp_scope, stmt->ast);
- }
- return fn_ctx.closed_vars;
+ Table_t closed_vars = {};
+ add_closed_vars(&closed_vars, env, body_scope, block);
+ return closed_vars;
}
CORD compile_declaration(type_t *t, CORD name)
@@ -953,13 +1204,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("_$", arg->name)));
}
- fn_ctx_t fn_ctx = (fn_ctx_t){
- .parent=NULL,
- .return_type=ret_t,
- .closure_scope=NULL,
- .closed_vars=NULL,
- };
- body_scope->fn_ctx = &fn_ctx;
+ body_scope->fn_ret = ret_t;
type_t *body_type = get_type(body_scope, fndef->body);
if (ret_t->tag != VoidType && ret_t->tag != AbortType && body_type->tag != AbortType && body_type->tag != ReturnType)
@@ -1132,13 +1377,13 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
case Pass: return ";";
case Defer: {
ast_t *body = Match(ast, Defer)->body;
- Table_t *closed_vars = get_closed_vars(env, FakeAST(Lambda, .args=NULL, .body=body));
+ Table_t closed_vars = get_closed_vars(env, NULL, body);
static int defer_id = 0;
env_t *defer_env = fresh_scope(env);
CORD code = CORD_EMPTY;
- for (int64_t i = 1; i <= Table$length(*closed_vars); i++) {
- struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i);
+ for (int64_t i = 1; i <= Table$length(closed_vars); i++) {
+ struct { const char *name; binding_t *b; } *entry = Table$entry(closed_vars, i);
if (entry->b->type->tag == ModuleType)
continue;
if (CORD_ncmp(entry->b->code, 0, "userdata->", 0, strlen("userdata->")) == 0) {
@@ -1149,9 +1394,6 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
code, compile_declaration(entry->b->type, defer_name), " = ", entry->b->code, ";\n");
set_binding(defer_env, entry->name, new(binding_t, .type=entry->b->type, .code=defer_name));
}
-
- if (env->fn_ctx->closed_vars)
- Table$str_set(env->fn_ctx->closed_vars, entry->name, entry->b);
}
env->deferred = new(deferral_t, .defer_env=defer_env, .block=body, .next=env->deferred);
return code;
@@ -1175,9 +1417,8 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
return CORD_cat(code, ", yes);");
}
case Return: {
- if (!env->fn_ctx) code_err(ast, "This return statement is not inside any function");
+ if (!env->fn_ret) code_err(ast, "This return statement is not inside any function");
auto ret = Match(ast, Return)->value;
- assert(env->fn_ctx->return_type);
CORD code = CORD_EMPTY;
for (deferral_t *deferred = env->deferred; deferred; deferred = deferred->next) {
@@ -1185,20 +1426,20 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
}
if (ret) {
- if (env->fn_ctx->return_type->tag == VoidType || env->fn_ctx->return_type->tag == AbortType)
+ if (env->fn_ret->tag == VoidType || env->fn_ret->tag == AbortType)
code_err(ast, "This function is not supposed to return any values, according to its type signature");
- env = with_enum_scope(env, env->fn_ctx->return_type);
- CORD value = compile_to_type(env, ret, env->fn_ctx->return_type);
+ env = with_enum_scope(env, env->fn_ret);
+ CORD value = compile_to_type(env, ret, env->fn_ret);
if (env->deferred) {
- code = CORD_all(compile_declaration(env->fn_ctx->return_type, "ret"), " = ", value, ";\n", code);
+ code = CORD_all(compile_declaration(env->fn_ret, "ret"), " = ", value, ";\n", code);
value = "ret";
}
return CORD_all(code, "return ", value, ";");
} else {
- if (env->fn_ctx->return_type->tag != VoidType)
- code_err(ast, "This function expects you to return a %T value", env->fn_ctx->return_type);
+ if (env->fn_ret->tag != VoidType)
+ code_err(ast, "This function expects you to return a %T value", env->fn_ret);
return CORD_all(code, "return;");
}
}
@@ -2826,19 +3067,12 @@ CORD compile(env_t *env, ast_t *ast)
name, CORD_quoted(type_to_cord(get_type(env, ast))), file_base_name(ast->file->filename), get_line_number(ast->file, ast->start)));
env_t *body_scope = fresh_scope(env);
+ body_scope->deferred = NULL;
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
- set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_cat("_$", arg->name)));
+ set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=CORD_all("_$", arg->name)));
}
- fn_ctx_t fn_ctx = (fn_ctx_t){
- .parent=env->fn_ctx,
- .closure_scope=env->locals,
- .closed_vars=new(Table_t),
- };
- body_scope->fn_ctx = &fn_ctx;
- body_scope->locals->fallback = env->globals;
- body_scope->deferred = NULL;
type_t *ret_t = get_type(body_scope, lambda->body);
if (ret_t->tag == ReturnType)
ret_t = Match(ret_t, ReturnType)->ret;
@@ -2852,21 +3086,13 @@ CORD compile(env_t *env, ast_t *ast)
declared, ret_t);
}
- fn_ctx.return_type = ret_t;
-
- if (env->fn_ctx->closed_vars) {
- for (int64_t i = 1; i <= Table$length(*env->fn_ctx->closed_vars); i++) {
- struct { const char *name; binding_t *b; } *entry = Table$entry(*env->fn_ctx->closed_vars, i);
- set_binding(body_scope, entry->name, new(binding_t, .type=entry->b->type, .code=CORD_cat("userdata->", entry->name)));
- Table$str_set(fn_ctx.closed_vars, entry->name, entry->b);
- }
- }
+ body_scope->fn_ret = ret_t;
- Table_t *closed_vars = get_closed_vars(env, ast);
- if (Table$length(*closed_vars) > 0) { // Create a typedef for the lambda's closure userdata
+ Table_t closed_vars = get_closed_vars(env, lambda->args, ast);
+ if (Table$length(closed_vars) > 0) { // Create a typedef for the lambda's closure userdata
CORD def = "typedef struct {";
- for (int64_t i = 1; i <= Table$length(*closed_vars); i++) {
- struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i);
+ for (int64_t i = 1; i <= Table$length(closed_vars); i++) {
+ struct { const char *name; binding_t *b; } *entry = Table$entry(closed_vars, i);
if (has_stack_memory(entry->b->type))
code_err(ast, "This function is holding onto a reference to %T stack memory in the variable `%s`, but the function may outlive the stack memory",
entry->b->type, entry->name);
@@ -2886,16 +3112,19 @@ CORD compile(env_t *env, ast_t *ast)
}
CORD userdata;
- if (Table$length(*closed_vars) == 0) {
+ if (Table$length(closed_vars) == 0) {
code = CORD_cat(code, "void *)");
userdata = "NULL";
} else {
userdata = CORD_all("new(", name, "$userdata_t");
- for (int64_t i = 1; i <= Table$length(*closed_vars); i++) {
- struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i);
+ for (int64_t i = 1; i <= Table$length(closed_vars); i++) {
+ struct { const char *name; binding_t *b; } *entry = Table$entry(closed_vars, i);
if (entry->b->type->tag == ModuleType)
continue;
- CORD binding_code = get_binding(env, entry->name)->code;
+ binding_t *b = get_binding(env, entry->name);
+ if (!b) printf("Couldn't find: %s\n", entry->name);
+ assert(b);
+ CORD binding_code = b->code;
if (entry->b->type->tag == ArrayType)
userdata = CORD_all(userdata, ", ARRAY_COPY(", binding_code, ")");
else if (entry->b->type->tag == TableType || entry->b->type->tag == SetType)