aboutsummaryrefslogtreecommitdiff
path: root/src/compile
diff options
context:
space:
mode:
Diffstat (limited to 'src/compile')
-rw-r--r--src/compile/functions.c376
-rw-r--r--src/compile/functions.h2
2 files changed, 378 insertions, 0 deletions
diff --git a/src/compile/functions.c b/src/compile/functions.c
index 1a2e5d3d..4015f0ca 100644
--- a/src/compile/functions.c
+++ b/src/compile/functions.c
@@ -1,6 +1,8 @@
+#include "functions.h"
#include "../ast.h"
#include "../compile.h"
#include "../environment.h"
+#include "../naming.h"
#include "../stdlib/datatypes.h"
#include "../stdlib/integers.h"
#include "../stdlib/nums.h"
@@ -195,3 +197,377 @@ Text_t compile_function_call(env_t *env, ast_t *ast) {
code_err(call->fn, "This is not a function, it's a ", type_to_str(fn_t));
}
}
+
+public
+Text_t compile_lambda(env_t *env, ast_t *ast) {
+ DeclareMatch(lambda, ast, Lambda);
+ Text_t name = namespace_name(env, env->namespace, Texts("lambda$", String(lambda->id)));
+
+ env_t *body_scope = fresh_scope(env);
+ body_scope->deferred = NULL;
+ for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
+ type_t *arg_type = get_arg_ast_type(env, arg);
+ set_binding(body_scope, arg->name, arg_type, Texts("_$", arg->name));
+ }
+
+ type_t *ret_t = get_type(body_scope, lambda->body);
+ if (ret_t->tag == ReturnType) ret_t = Match(ret_t, ReturnType)->ret;
+
+ if (lambda->ret_type) {
+ type_t *declared = parse_type_ast(env, lambda->ret_type);
+ if (can_promote(ret_t, declared)) ret_t = declared;
+ else
+ code_err(ast, "This function was declared to return a value of type ", type_to_str(declared),
+ ", but actually returns a value of type ", type_to_str(ret_t));
+ }
+
+ body_scope->fn_ret = ret_t;
+
+ Table_t closed_vars = get_closed_vars(env, lambda->args, ast);
+ if (Table$length(closed_vars) > 0) { // Create a typedef for the lambda's closure userdata
+ Text_t def = Text("typedef struct {");
+ for (int64_t i = 0; i < closed_vars.entries.length; i++) {
+ struct {
+ const char *name;
+ binding_t *b;
+ } *entry = closed_vars.entries.data + closed_vars.entries.stride * i;
+ if (has_stack_memory(entry->b->type))
+ code_err(ast, "This function is holding onto a reference to ", type_to_str(entry->b->type),
+ " stack memory in the variable `", entry->name,
+ "`, but the function may outlive the stack memory");
+ if (entry->b->type->tag == ModuleType) continue;
+ set_binding(body_scope, entry->name, entry->b->type, Texts("userdata->", entry->name));
+ def = Texts(def, compile_declaration(entry->b->type, Text$from_str(entry->name)), "; ");
+ }
+ def = Texts(def, "} ", name, "$userdata_t;");
+ env->code->local_typedefs = Texts(env->code->local_typedefs, def);
+ }
+
+ Text_t code = Texts("static ", compile_type(ret_t), " ", name, "(");
+ for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
+ type_t *arg_type = get_arg_ast_type(env, arg);
+ code = Texts(code, compile_type(arg_type), " _$", arg->name, ", ");
+ }
+
+ Text_t userdata;
+ if (Table$length(closed_vars) == 0) {
+ code = Texts(code, "void *_)");
+ userdata = Text("NULL");
+ } else {
+ userdata = Texts("new(", name, "$userdata_t");
+ for (int64_t i = 0; i < closed_vars.entries.length; i++) {
+ struct {
+ const char *name;
+ binding_t *b;
+ } *entry = closed_vars.entries.data + closed_vars.entries.stride * i;
+ if (entry->b->type->tag == ModuleType) continue;
+ binding_t *b = get_binding(env, entry->name);
+ assert(b);
+ Text_t binding_code = b->code;
+ if (entry->b->type->tag == ListType) userdata = Texts(userdata, ", LIST_COPY(", binding_code, ")");
+ else if (entry->b->type->tag == TableType || entry->b->type->tag == SetType)
+ userdata = Texts(userdata, ", TABLE_COPY(", binding_code, ")");
+ else userdata = Texts(userdata, ", ", binding_code);
+ }
+ userdata = Texts(userdata, ")");
+ code = Texts(code, name, "$userdata_t *userdata)");
+ }
+
+ Text_t body = EMPTY_TEXT;
+ for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) {
+ if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType
+ || get_type(body_scope, stmt->ast)->tag == ReturnType)
+ body = Texts(body, compile_statement(body_scope, stmt->ast), "\n");
+ else body = Texts(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n");
+ bind_statement(body_scope, stmt->ast);
+ }
+ if ((ret_t->tag == VoidType || ret_t->tag == AbortType) && body_scope->deferred)
+ body = Texts(body, compile_statement(body_scope, FakeAST(Return)), "\n");
+
+ env->code->lambdas = Texts(env->code->lambdas, code, " {\n", body, "\n}\n");
+ return Texts("((Closure_t){", name, ", ", userdata, "})");
+}
+
+static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t *env, ast_t *ast) {
+ if (ast == NULL) return;
+
+ switch (ast->tag) {
+ case Var: {
+ binding_t *b = get_binding(enclosing_scope, Match(ast, Var)->name);
+ if (b) {
+ binding_t *shadow = get_binding(env, Match(ast, Var)->name);
+ if (!shadow || shadow == b) Table$str_set(closed_vars, Match(ast, Var)->name, b);
+ }
+ break;
+ }
+ case TextJoin: {
+ for (ast_list_t *child = Match(ast, TextJoin)->children; child; child = child->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, child->ast);
+ break;
+ }
+ case Declare: {
+ ast_t *value = Match(ast, Declare)->value;
+ add_closed_vars(closed_vars, enclosing_scope, env, value);
+ bind_statement(env, ast);
+ break;
+ }
+ case Assign: {
+ for (ast_list_t *target = Match(ast, Assign)->targets; target; target = target->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, target->ast);
+ for (ast_list_t *value = Match(ast, Assign)->values; value; value = value->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, value->ast);
+ break;
+ }
+ case BINOP_CASES: {
+ binary_operands_t binop = BINARY_OPERANDS(ast);
+ add_closed_vars(closed_vars, enclosing_scope, env, binop.lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, binop.rhs);
+ break;
+ }
+ case Not:
+ case Negative:
+ case HeapAllocate:
+ case StackReference: {
+ // UNSAFE:
+ ast_t *value = ast->__data.Not.value;
+ // END UNSAFE
+ add_closed_vars(closed_vars, enclosing_scope, env, value);
+ break;
+ }
+ case Min: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->rhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Min)->key);
+ break;
+ }
+ case Max: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->lhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->rhs);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Max)->key);
+ break;
+ }
+ case List: {
+ for (ast_list_t *item = Match(ast, List)->items; item; item = item->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, item->ast);
+ break;
+ }
+ case Set: {
+ for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, item->ast);
+ break;
+ }
+ case Table: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->default_value);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Table)->fallback);
+ for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, entry->ast);
+ break;
+ }
+ case TableEntry: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->key);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, TableEntry)->value);
+ break;
+ }
+ case Comprehension: {
+ DeclareMatch(comp, ast, Comprehension);
+ if (comp->expr->tag == Comprehension) { // Nested comprehension
+ ast_t *body = comp->filter ? WrapAST(ast, If, .condition = comp->filter, .body = comp->expr) : comp->expr;
+ ast_t *loop = WrapAST(ast, For, .vars = comp->vars, .iter = comp->iter, .body = body);
+ return add_closed_vars(closed_vars, enclosing_scope, env, loop);
+ }
+
+ // List/Set/Table comprehension:
+ ast_t *body = comp->expr;
+ if (comp->filter) body = WrapAST(comp->expr, If, .condition = comp->filter, .body = body);
+ ast_t *loop = WrapAST(ast, For, .vars = comp->vars, .iter = comp->iter, .body = body);
+ add_closed_vars(closed_vars, enclosing_scope, env, loop);
+ break;
+ }
+ case Lambda: {
+ DeclareMatch(lambda, ast, Lambda);
+ env_t *lambda_scope = fresh_scope(env);
+ for (arg_ast_t *arg = lambda->args; arg; arg = arg->next)
+ set_binding(lambda_scope, arg->name, get_arg_ast_type(env, arg), Texts("_$", arg->name));
+ add_closed_vars(closed_vars, enclosing_scope, lambda_scope, lambda->body);
+ break;
+ }
+ case FunctionCall: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FunctionCall)->fn);
+ for (arg_ast_t *arg = Match(ast, FunctionCall)->args; arg; arg = arg->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, arg->value);
+ break;
+ }
+ case MethodCall: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, MethodCall)->self);
+ for (arg_ast_t *arg = Match(ast, MethodCall)->args; arg; arg = arg->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, arg->value);
+ break;
+ }
+ case Block: {
+ env = fresh_scope(env);
+ for (ast_list_t *statement = Match(ast, Block)->statements; statement; statement = statement->next)
+ add_closed_vars(closed_vars, enclosing_scope, env, statement->ast);
+ break;
+ }
+ case For: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->iter);
+ env_t *body_scope = for_scope(env, ast);
+ add_closed_vars(closed_vars, enclosing_scope, body_scope, Match(ast, For)->body);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, For)->empty);
+ break;
+ }
+ case While: {
+ DeclareMatch(while_, ast, While);
+ add_closed_vars(closed_vars, enclosing_scope, env, while_->condition);
+ env_t *scope = fresh_scope(env);
+ add_closed_vars(closed_vars, enclosing_scope, scope, while_->body);
+ break;
+ }
+ case If: {
+ DeclareMatch(if_, ast, If);
+ ast_t *condition = if_->condition;
+ if (condition->tag == Declare) {
+ env_t *truthy_scope = fresh_scope(env);
+ bind_statement(truthy_scope, condition);
+ if (!Match(condition, Declare)->value)
+ code_err(condition, "This declared variable must have an initial value");
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(condition, Declare)->value);
+ ast_t *var = Match(condition, Declare)->var;
+ type_t *cond_t = get_type(truthy_scope, var);
+ if (cond_t->tag == OptionalType) {
+ set_binding(truthy_scope, Match(var, Var)->name, Match(cond_t, OptionalType)->type, EMPTY_TEXT);
+ }
+ add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body);
+ add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body);
+ } else {
+ add_closed_vars(closed_vars, enclosing_scope, env, condition);
+ env_t *truthy_scope = env;
+ type_t *cond_t = get_type(env, condition);
+ if (condition->tag == Var && cond_t->tag == OptionalType) {
+ truthy_scope = fresh_scope(env);
+ set_binding(truthy_scope, Match(condition, Var)->name, Match(cond_t, OptionalType)->type, EMPTY_TEXT);
+ }
+ add_closed_vars(closed_vars, enclosing_scope, truthy_scope, if_->body);
+ add_closed_vars(closed_vars, enclosing_scope, env, if_->else_body);
+ }
+ break;
+ }
+ case When: {
+ DeclareMatch(when, ast, When);
+ add_closed_vars(closed_vars, enclosing_scope, env, when->subject);
+ type_t *subject_t = get_type(env, when->subject);
+
+ if (subject_t->tag != EnumType) {
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ add_closed_vars(closed_vars, enclosing_scope, env, clause->pattern);
+ add_closed_vars(closed_vars, enclosing_scope, env, clause->body);
+ }
+
+ if (when->else_body) add_closed_vars(closed_vars, enclosing_scope, env, when->else_body);
+ return;
+ }
+
+ DeclareMatch(enum_t, subject_t, EnumType);
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ const char *clause_tag_name;
+ if (clause->pattern->tag == Var) clause_tag_name = Match(clause->pattern, Var)->name;
+ else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var)
+ clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name;
+ else code_err(clause->pattern, "This is not a valid pattern for a ", type_to_str(subject_t), " enum");
+
+ type_t *tag_type = NULL;
+ for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
+ if (streq(tag->name, clause_tag_name)) {
+ tag_type = tag->type;
+ break;
+ }
+ }
+ assert(tag_type);
+ env_t *scope = when_clause_scope(env, subject_t, clause);
+ add_closed_vars(closed_vars, enclosing_scope, scope, clause->body);
+ }
+ if (when->else_body) add_closed_vars(closed_vars, enclosing_scope, env, when->else_body);
+ break;
+ }
+ case Repeat: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Repeat)->body);
+ break;
+ }
+ case Reduction: {
+ DeclareMatch(reduction, ast, Reduction);
+ static int64_t next_id = 1;
+ ast_t *item = FakeAST(Var, String("$it", next_id++));
+ ast_t *loop =
+ FakeAST(For, .vars = new (ast_list_t, .ast = item), .iter = reduction->iter, .body = FakeAST(Pass));
+ env_t *scope = for_scope(env, loop);
+ add_closed_vars(closed_vars, enclosing_scope, scope, reduction->key ? reduction->key : item);
+ break;
+ }
+ case Defer: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Defer)->body);
+ break;
+ }
+ case Return: {
+ ast_t *ret = Match(ast, Return)->value;
+ if (ret) add_closed_vars(closed_vars, enclosing_scope, env, ret);
+ break;
+ }
+ case Index: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->indexed);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Index)->index);
+ break;
+ }
+ case FieldAccess: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, FieldAccess)->fielded);
+ break;
+ }
+ case Optional: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Optional)->value);
+ break;
+ }
+ case NonOptional: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, NonOptional)->value);
+ break;
+ }
+ case DocTest: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, DocTest)->expr);
+ break;
+ }
+ case Assert: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->expr);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->message);
+ break;
+ }
+ case Deserialize: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value);
+ break;
+ }
+ case ExplicitlyTyped: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, ExplicitlyTyped)->ast);
+ break;
+ }
+ case Use:
+ case FunctionDef:
+ case ConvertDef:
+ case StructDef:
+ case EnumDef:
+ case LangDef:
+ case Extend: {
+ errx(1, "Definitions should not be reachable in a closure.");
+ }
+ default: break;
+ }
+}
+
+public
+Table_t get_closed_vars(env_t *env, arg_ast_t *args, ast_t *block) {
+ env_t *body_scope = fresh_scope(env);
+ for (arg_ast_t *arg = args; arg; arg = arg->next) {
+ type_t *arg_type = get_arg_ast_type(env, arg);
+ set_binding(body_scope, arg->name, arg_type, Texts("_$", arg->name));
+ }
+
+ Table_t closed_vars = {};
+ add_closed_vars(&closed_vars, env, body_scope, block);
+ return closed_vars;
+}
diff --git a/src/compile/functions.h b/src/compile/functions.h
index fdab495f..f7edd2aa 100644
--- a/src/compile/functions.h
+++ b/src/compile/functions.h
@@ -5,3 +5,5 @@
Text_t compile_function_call(env_t *env, ast_t *ast);
Text_t compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args);
+Text_t compile_lambda(env_t *env, ast_t *ast);
+Table_t get_closed_vars(env_t *env, arg_ast_t *args, ast_t *block);