aboutsummaryrefslogtreecommitdiff
path: root/compile.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-03-09 16:03:38 -0500
committerBruce Hill <bruce@bruce-hill.com>2024-03-09 16:03:38 -0500
commitb04a1b30903636bf97a8e04efe88e8a177855bcb (patch)
treed85a765c07c3cc686485d30ef7a4216f4d830cad /compile.c
parent42da91936e50ab652d140677689b519fe9deb3fe (diff)
Implement lambdas and closures
Diffstat (limited to 'compile.c')
-rw-r--r--compile.c137
1 files changed, 113 insertions, 24 deletions
diff --git a/compile.c b/compile.c
index 5d43bd24..7f70d366 100644
--- a/compile.c
+++ b/compile.c
@@ -26,6 +26,37 @@ CORD compile_type_ast(type_ast_t *t)
}
}
+static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed)
+{
+ if (type_eq(actual, needed))
+ return true;
+
+ if (!can_promote(actual, needed))
+ return false;
+
+ if (actual->tag == IntType || actual->tag == NumType)
+ return true;
+
+ // Automatic dereferencing:
+ if (actual->tag == PointerType && !Match(actual, PointerType)->is_optional
+ && can_promote(Match(actual, PointerType)->pointed, needed)) {
+ *code = CORD_all("*(", *code, ")");
+ return promote(env, code, Match(actual, PointerType)->pointed, needed);
+ }
+
+ // Optional promotion:
+ if (actual->tag == PointerType && needed->tag == PointerType)
+ return true;
+
+ if (needed->tag == ClosureType && actual->tag == FunctionType && type_eq(actual, Match(needed, ClosureType)->fn)) {
+ *code = CORD_all("(closure_t){", *code, ", NULL}");
+ return true;
+ }
+
+ return false;
+}
+
+
CORD compile_declaration(type_t *t, const char *name)
{
if (t->tag == FunctionType) {
@@ -175,19 +206,19 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg
{
table_t used_args = {};
CORD code = CORD_EMPTY;
- env_t *default_scope = fresh_scope(env);
- default_scope->locals->fallback = env->globals;
+ env_t *default_scope = global_scope(env);
for (arg_t *spec_arg = spec_args; spec_arg; spec_arg = spec_arg->next) {
// Find keyword:
if (spec_arg->name) {
for (arg_ast_t *call_arg = call_args; call_arg; call_arg = call_arg->next) {
if (call_arg->name && streq(call_arg->name, spec_arg->name)) {
type_t *actual_t = get_type(env, call_arg->value);
- if (!can_promote(actual_t, spec_arg->type))
+ CORD value = compile(env, call_arg->value);
+ if (!promote(env, &value, actual_t, spec_arg->type))
code_err(call_arg->value, "This argument is supposed to be a %T, but this value is a %T", spec_arg->type, actual_t);
Table_str_set(&used_args, call_arg->name, call_arg);
if (code) code = CORD_cat(code, ", ");
- code = CORD_cat(code, compile(env, call_arg->value));
+ code = CORD_cat(code, value);
goto found_it;
}
}
@@ -199,11 +230,12 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg
const char *pseudoname = heap_strf("%ld", i++);
if (!Table_str_get(used_args, pseudoname)) {
type_t *actual_t = get_type(env, call_arg->value);
- if (!can_promote(actual_t, spec_arg->type))
+ CORD value = compile(env, call_arg->value);
+ if (!promote(env, &value, actual_t, spec_arg->type))
code_err(call_arg->value, "This argument is supposed to be a %T, but this value is a %T", spec_arg->type, actual_t);
Table_str_set(&used_args, pseudoname, call_arg);
if (code) code = CORD_cat(code, ", ");
- code = CORD_cat(code, compile(env, call_arg->value));
+ code = CORD_cat(code, value);
goto found_it;
}
}
@@ -299,9 +331,9 @@ CORD compile(env_t *env, ast_t *ast)
type_t *lhs_t = get_type(env, binop->lhs);
type_t *rhs_t = get_type(env, binop->rhs);
type_t *operand_t;
- if (can_promote(rhs_t, lhs_t))
+ if (promote(env, &rhs, rhs_t, lhs_t))
operand_t = lhs_t;
- else if (can_promote(lhs_t, rhs_t))
+ else if (promote(env, &lhs, lhs_t, rhs_t))
operand_t = rhs_t;
else
code_err(ast, "I can't do operations between %T and %T", lhs_t, rhs_t);
@@ -450,11 +482,11 @@ CORD compile(env_t *env, ast_t *ast)
type_t *lhs_t = get_type(env, update->lhs);
type_t *rhs_t = get_type(env, update->rhs);
type_t *operand_t;
- if (can_promote(rhs_t, lhs_t))
+ if (promote(env, &rhs, rhs_t, lhs_t))
operand_t = lhs_t;
- else if (can_promote(lhs_t, rhs_t))
+ else if (promote(env, &lhs, lhs_t, rhs_t))
operand_t = rhs_t;
- else if (lhs_t->tag == ArrayType && can_promote(rhs_t, Match(lhs_t, ArrayType)->item_type))
+ else if (lhs_t->tag == ArrayType && promote(env, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type))
operand_t = lhs_t;
else
code_err(ast, "I can't do operations between %T and %T", lhs_t, rhs_t);
@@ -497,7 +529,7 @@ CORD compile(env_t *env, ast_t *ast)
if (operand_t->tag == TextType) {
return CORD_asprintf("%r = CORD_cat(%r, %r);", lhs, lhs, rhs);
} else if (operand_t->tag == ArrayType) {
- if (can_promote(rhs_t, Match(lhs_t, ArrayType)->item_type)) {
+ if (promote(env, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type)) {
// arr ++= item
if (update->lhs->tag == Var)
return CORD_all("Array__insert(&", lhs, ", $stack(", rhs, "), 0, ", compile_type_info(env, operand_t), ")");
@@ -691,7 +723,8 @@ CORD compile(env_t *env, ast_t *ast)
case FunctionDef: {
auto fndef = Match(ast, FunctionDef);
CORD name = compile(env, fndef->name);
- CORD signature = CORD_all(fndef->ret_type ? compile_type_ast(fndef->ret_type) : "void", " ", name, "(");
+ type_t *ret_t = fndef->ret_type ? parse_type_ast(env, fndef->ret_type) : Type(VoidType);
+ CORD signature = CORD_all(compile_type(ret_t), " ", name, "(");
for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
signature = CORD_cat(signature, compile_declaration(arg_type, arg->name));
@@ -709,13 +742,19 @@ CORD compile(env_t *env, ast_t *ast)
if (!fndef->is_private)
code = CORD_cat("public ", code);
- env_t *body_scope = fresh_scope(env);
- body_scope->locals->fallback = env->globals;
+ env_t *body_scope = global_scope(env);
for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name));
}
+ fn_context_t fn_ctx = (fn_context_t){
+ .return_type=ret_t,
+ .closure_scope=NULL,
+ .closed_vars=NULL,
+ };
+ body_scope->fn_ctx = &fn_ctx;
+
CORD body = compile(body_scope, fndef->body);
if (CORD_fetch(body, 0) != '{')
body = CORD_asprintf("{\n%r\n}", body);
@@ -728,19 +767,47 @@ CORD compile(env_t *env, ast_t *ast)
CORD name = CORD_asprintf("lambda$%ld", lambda_number++);
env_t *body_scope = fresh_scope(env);
- body_scope->locals->fallback = env->globals;
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name));
}
+
type_t *ret_t = get_type(body_scope, lambda->body);
+ fn_context_t fn_ctx = (fn_context_t){
+ .return_type=ret_t,
+ .closure_scope=body_scope->locals->fallback,
+ .closed_vars=new(table_t),
+ };
+ body_scope->fn_ctx = &fn_ctx;
+ body_scope->locals->fallback = env->globals;
+
CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "(");
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
code = CORD_all(code, compile_type(arg_type), " ", arg->name, ", ");
}
- code = CORD_cat(code, "void *$userdata)");
+
+ for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next)
+ (void)compile_statement(body_scope, stmt->ast);
+
+ CORD userdata;
+ if (Table_length(*fn_ctx.closed_vars) == 0) {
+ code = CORD_cat(code, "void *$userdata)");
+ userdata = "NULL";
+ } else {
+ CORD def = "typedef struct {";
+ userdata = CORD_all("new(", name, "$userdata_t");
+ for (int64_t i = 1; i <= Table_length(*fn_ctx.closed_vars); i++) {
+ struct { const char *name; binding_t *b; } *entry = Table_entry(*fn_ctx.closed_vars, i);
+ def = CORD_all(def, compile_declaration(entry->b->type, entry->name), "; ");
+ userdata = CORD_all(userdata, ", ", entry->b->code);
+ }
+ userdata = CORD_all(userdata, ")");
+ def = CORD_all(def, "} ", name, "$userdata_t;");
+ env->code->typedefs = CORD_cat(env->code->typedefs, def);
+ code = CORD_all(code, name, "$userdata_t *$userdata)");
+ }
CORD body = CORD_EMPTY;
for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) {
@@ -750,7 +817,7 @@ CORD compile(env_t *env, ast_t *ast)
body = CORD_all(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n");
}
env->code->funcs = CORD_all(env->code->funcs, code, " {\n", body, "\n}");
- return CORD_all("(closure_t){", name, ", NULL}");
+ return CORD_all("(closure_t){", name, ", ", userdata, "}");
}
case MethodCall: {
auto call = Match(ast, MethodCall);
@@ -848,13 +915,22 @@ CORD compile(env_t *env, ast_t *ast)
} else if (fn_t->tag == ClosureType) {
fn_t = Match(fn_t, ClosureType)->fn;
arg_t *type_args = Match(fn_t, FunctionType)->args;
- CORD fn_type_code = compile_type(fn_t);
+
+ arg_t *closure_fn_args = NULL;
+ for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next)
+ closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args);
+ closure_fn_args = new(arg_t, .name="$userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args);
+ REVERSE_LIST(closure_fn_args);
+ CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret));
+
CORD closure = compile(env, call->fn);
+ CORD arg_code = compile_arguments(env, ast, type_args, call->args);
+ if (arg_code) arg_code = CORD_cat(arg_code, ", ");
if (call->fn->tag == Var) {
- return CORD_all("((", fn_type_code, ")", closure, ".fn)(", compile_arguments(env, ast, type_args, call->args), ")");
+ return CORD_all("((", fn_type_code, ")", closure, ".fn)(", arg_code, closure, ".userdata)");
} else {
return CORD_all("({ closure_t $closure = ", closure, "; ((", fn_type_code, ")$closure.fn)(",
- compile_arguments(env, ast, type_args, call->args), "); })");
+ arg_code, "$closure.userdata); })");
}
} else {
code_err(call->fn, "This is not a function, it's a %T", fn_t);
@@ -1062,8 +1138,21 @@ CORD compile(env_t *env, ast_t *ast)
}
case Pass: return ";";
case Return: {
+ if (!env->fn_ctx) code_err(ast, "This return statement is not inside any function");
auto ret = Match(ast, Return)->value;
- return ret ? CORD_asprintf("return %r;", compile(env, ret)) : "return;";
+ assert(env->fn_ctx->return_type);
+ if (ret) {
+ type_t *ret_t = get_type(env, ret);
+ CORD value = compile(env, ret);
+ if (!promote(env, &value, ret_t, env->fn_ctx->return_type))
+ code_err(ast, "This function expects a return value of type %T, but this return has type %T",
+ env->fn_ctx->return_type, ret_t);
+ return CORD_all("return ", value, ";");
+ } else {
+ if (env->fn_ctx->return_type->tag != VoidType)
+ code_err(ast, "This function expects a return value of type %T", env->fn_ctx->return_type->tag);
+ return "return;";
+ }
}
// Extern,
case StructDef: {
@@ -1251,10 +1340,10 @@ CORD compile(env_t *env, ast_t *ast)
case TableType: {
type_t *key_t = Match(container_t, TableType)->key_type;
type_t *value_t = Match(container_t, TableType)->value_type;
- if (!can_promote(index_t, key_t))
- code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t);
CORD table = compile_to_pointer_depth(env, indexing->indexed, 0, false);
CORD key = compile(env, indexing->index);
+ if (!promote(env, &key, index_t, key_t))
+ code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t);
file_t *f = indexing->index->file;
return CORD_all("$Table_get(", table, ", ", compile_type(key_t), ", ", compile_type(value_t), ", ",
key, ", ", compile_type_info(env, container_t), ", ",