diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-03-09 16:03:38 -0500 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-03-09 16:03:38 -0500 |
| commit | b04a1b30903636bf97a8e04efe88e8a177855bcb (patch) | |
| tree | d85a765c07c3cc686485d30ef7a4216f4d830cad | |
| parent | 42da91936e50ab652d140677689b519fe9deb3fe (diff) | |
Implement lambdas and closures
| -rw-r--r-- | compile.c | 137 | ||||
| -rw-r--r-- | environment.c | 19 | ||||
| -rw-r--r-- | environment.h | 8 | ||||
| -rw-r--r-- | test/lambdas.tm | 18 | ||||
| -rw-r--r-- | typecheck.c | 4 | ||||
| -rw-r--r-- | types.c | 21 |
6 files changed, 159 insertions, 48 deletions
@@ -26,6 +26,37 @@ CORD compile_type_ast(type_ast_t *t) } } +static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed) +{ + if (type_eq(actual, needed)) + return true; + + if (!can_promote(actual, needed)) + return false; + + if (actual->tag == IntType || actual->tag == NumType) + return true; + + // Automatic dereferencing: + if (actual->tag == PointerType && !Match(actual, PointerType)->is_optional + && can_promote(Match(actual, PointerType)->pointed, needed)) { + *code = CORD_all("*(", *code, ")"); + return promote(env, code, Match(actual, PointerType)->pointed, needed); + } + + // Optional promotion: + if (actual->tag == PointerType && needed->tag == PointerType) + return true; + + if (needed->tag == ClosureType && actual->tag == FunctionType && type_eq(actual, Match(needed, ClosureType)->fn)) { + *code = CORD_all("(closure_t){", *code, ", NULL}"); + return true; + } + + return false; +} + + CORD compile_declaration(type_t *t, const char *name) { if (t->tag == FunctionType) { @@ -175,19 +206,19 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg { table_t used_args = {}; CORD code = CORD_EMPTY; - env_t *default_scope = fresh_scope(env); - default_scope->locals->fallback = env->globals; + env_t *default_scope = global_scope(env); for (arg_t *spec_arg = spec_args; spec_arg; spec_arg = spec_arg->next) { // Find keyword: if (spec_arg->name) { for (arg_ast_t *call_arg = call_args; call_arg; call_arg = call_arg->next) { if (call_arg->name && streq(call_arg->name, spec_arg->name)) { type_t *actual_t = get_type(env, call_arg->value); - if (!can_promote(actual_t, spec_arg->type)) + CORD value = compile(env, call_arg->value); + if (!promote(env, &value, actual_t, spec_arg->type)) code_err(call_arg->value, "This argument is supposed to be a %T, but this value is a %T", spec_arg->type, actual_t); Table_str_set(&used_args, call_arg->name, call_arg); if (code) code = CORD_cat(code, ", "); - code = CORD_cat(code, compile(env, call_arg->value)); + code = CORD_cat(code, value); goto found_it; } } @@ -199,11 +230,12 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg const char *pseudoname = heap_strf("%ld", i++); if (!Table_str_get(used_args, pseudoname)) { type_t *actual_t = get_type(env, call_arg->value); - if (!can_promote(actual_t, spec_arg->type)) + CORD value = compile(env, call_arg->value); + if (!promote(env, &value, actual_t, spec_arg->type)) code_err(call_arg->value, "This argument is supposed to be a %T, but this value is a %T", spec_arg->type, actual_t); Table_str_set(&used_args, pseudoname, call_arg); if (code) code = CORD_cat(code, ", "); - code = CORD_cat(code, compile(env, call_arg->value)); + code = CORD_cat(code, value); goto found_it; } } @@ -299,9 +331,9 @@ CORD compile(env_t *env, ast_t *ast) type_t *lhs_t = get_type(env, binop->lhs); type_t *rhs_t = get_type(env, binop->rhs); type_t *operand_t; - if (can_promote(rhs_t, lhs_t)) + if (promote(env, &rhs, rhs_t, lhs_t)) operand_t = lhs_t; - else if (can_promote(lhs_t, rhs_t)) + else if (promote(env, &lhs, lhs_t, rhs_t)) operand_t = rhs_t; else code_err(ast, "I can't do operations between %T and %T", lhs_t, rhs_t); @@ -450,11 +482,11 @@ CORD compile(env_t *env, ast_t *ast) type_t *lhs_t = get_type(env, update->lhs); type_t *rhs_t = get_type(env, update->rhs); type_t *operand_t; - if (can_promote(rhs_t, lhs_t)) + if (promote(env, &rhs, rhs_t, lhs_t)) operand_t = lhs_t; - else if (can_promote(lhs_t, rhs_t)) + else if (promote(env, &lhs, lhs_t, rhs_t)) operand_t = rhs_t; - else if (lhs_t->tag == ArrayType && can_promote(rhs_t, Match(lhs_t, ArrayType)->item_type)) + else if (lhs_t->tag == ArrayType && promote(env, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type)) operand_t = lhs_t; else code_err(ast, "I can't do operations between %T and %T", lhs_t, rhs_t); @@ -497,7 +529,7 @@ CORD compile(env_t *env, ast_t *ast) if (operand_t->tag == TextType) { return CORD_asprintf("%r = CORD_cat(%r, %r);", lhs, lhs, rhs); } else if (operand_t->tag == ArrayType) { - if (can_promote(rhs_t, Match(lhs_t, ArrayType)->item_type)) { + if (promote(env, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type)) { // arr ++= item if (update->lhs->tag == Var) return CORD_all("Array__insert(&", lhs, ", $stack(", rhs, "), 0, ", compile_type_info(env, operand_t), ")"); @@ -691,7 +723,8 @@ CORD compile(env_t *env, ast_t *ast) case FunctionDef: { auto fndef = Match(ast, FunctionDef); CORD name = compile(env, fndef->name); - CORD signature = CORD_all(fndef->ret_type ? compile_type_ast(fndef->ret_type) : "void", " ", name, "("); + type_t *ret_t = fndef->ret_type ? parse_type_ast(env, fndef->ret_type) : Type(VoidType); + CORD signature = CORD_all(compile_type(ret_t), " ", name, "("); for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); signature = CORD_cat(signature, compile_declaration(arg_type, arg->name)); @@ -709,13 +742,19 @@ CORD compile(env_t *env, ast_t *ast) if (!fndef->is_private) code = CORD_cat("public ", code); - env_t *body_scope = fresh_scope(env); - body_scope->locals->fallback = env->globals; + env_t *body_scope = global_scope(env); for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name)); } + fn_context_t fn_ctx = (fn_context_t){ + .return_type=ret_t, + .closure_scope=NULL, + .closed_vars=NULL, + }; + body_scope->fn_ctx = &fn_ctx; + CORD body = compile(body_scope, fndef->body); if (CORD_fetch(body, 0) != '{') body = CORD_asprintf("{\n%r\n}", body); @@ -728,19 +767,47 @@ CORD compile(env_t *env, ast_t *ast) CORD name = CORD_asprintf("lambda$%ld", lambda_number++); env_t *body_scope = fresh_scope(env); - body_scope->locals->fallback = env->globals; for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name)); } + type_t *ret_t = get_type(body_scope, lambda->body); + fn_context_t fn_ctx = (fn_context_t){ + .return_type=ret_t, + .closure_scope=body_scope->locals->fallback, + .closed_vars=new(table_t), + }; + body_scope->fn_ctx = &fn_ctx; + body_scope->locals->fallback = env->globals; + CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "("); for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); code = CORD_all(code, compile_type(arg_type), " ", arg->name, ", "); } - code = CORD_cat(code, "void *$userdata)"); + + for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) + (void)compile_statement(body_scope, stmt->ast); + + CORD userdata; + if (Table_length(*fn_ctx.closed_vars) == 0) { + code = CORD_cat(code, "void *$userdata)"); + userdata = "NULL"; + } else { + CORD def = "typedef struct {"; + userdata = CORD_all("new(", name, "$userdata_t"); + for (int64_t i = 1; i <= Table_length(*fn_ctx.closed_vars); i++) { + struct { const char *name; binding_t *b; } *entry = Table_entry(*fn_ctx.closed_vars, i); + def = CORD_all(def, compile_declaration(entry->b->type, entry->name), "; "); + userdata = CORD_all(userdata, ", ", entry->b->code); + } + userdata = CORD_all(userdata, ")"); + def = CORD_all(def, "} ", name, "$userdata_t;"); + env->code->typedefs = CORD_cat(env->code->typedefs, def); + code = CORD_all(code, name, "$userdata_t *$userdata)"); + } CORD body = CORD_EMPTY; for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { @@ -750,7 +817,7 @@ CORD compile(env_t *env, ast_t *ast) body = CORD_all(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n"); } env->code->funcs = CORD_all(env->code->funcs, code, " {\n", body, "\n}"); - return CORD_all("(closure_t){", name, ", NULL}"); + return CORD_all("(closure_t){", name, ", ", userdata, "}"); } case MethodCall: { auto call = Match(ast, MethodCall); @@ -848,13 +915,22 @@ CORD compile(env_t *env, ast_t *ast) } else if (fn_t->tag == ClosureType) { fn_t = Match(fn_t, ClosureType)->fn; arg_t *type_args = Match(fn_t, FunctionType)->args; - CORD fn_type_code = compile_type(fn_t); + + arg_t *closure_fn_args = NULL; + for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next) + closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args); + closure_fn_args = new(arg_t, .name="$userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args); + REVERSE_LIST(closure_fn_args); + CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret)); + CORD closure = compile(env, call->fn); + CORD arg_code = compile_arguments(env, ast, type_args, call->args); + if (arg_code) arg_code = CORD_cat(arg_code, ", "); if (call->fn->tag == Var) { - return CORD_all("((", fn_type_code, ")", closure, ".fn)(", compile_arguments(env, ast, type_args, call->args), ")"); + return CORD_all("((", fn_type_code, ")", closure, ".fn)(", arg_code, closure, ".userdata)"); } else { return CORD_all("({ closure_t $closure = ", closure, "; ((", fn_type_code, ")$closure.fn)(", - compile_arguments(env, ast, type_args, call->args), "); })"); + arg_code, "$closure.userdata); })"); } } else { code_err(call->fn, "This is not a function, it's a %T", fn_t); @@ -1062,8 +1138,21 @@ CORD compile(env_t *env, ast_t *ast) } case Pass: return ";"; case Return: { + if (!env->fn_ctx) code_err(ast, "This return statement is not inside any function"); auto ret = Match(ast, Return)->value; - return ret ? CORD_asprintf("return %r;", compile(env, ret)) : "return;"; + assert(env->fn_ctx->return_type); + if (ret) { + type_t *ret_t = get_type(env, ret); + CORD value = compile(env, ret); + if (!promote(env, &value, ret_t, env->fn_ctx->return_type)) + code_err(ast, "This function expects a return value of type %T, but this return has type %T", + env->fn_ctx->return_type, ret_t); + return CORD_all("return ", value, ";"); + } else { + if (env->fn_ctx->return_type->tag != VoidType) + code_err(ast, "This function expects a return value of type %T", env->fn_ctx->return_type->tag); + return "return;"; + } } // Extern, case StructDef: { @@ -1251,10 +1340,10 @@ CORD compile(env_t *env, ast_t *ast) case TableType: { type_t *key_t = Match(container_t, TableType)->key_type; type_t *value_t = Match(container_t, TableType)->value_type; - if (!can_promote(index_t, key_t)) - code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t); CORD table = compile_to_pointer_depth(env, indexing->indexed, 0, false); CORD key = compile(env, indexing->index); + if (!promote(env, &key, index_t, key_t)) + code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t); file_t *f = indexing->index->file; return CORD_all("$Table_get(", table, ", ", compile_type(key_t), ", ", compile_type(value_t), ", ", key, ", ", compile_type_info(env, container_t), ", ", diff --git a/environment.c b/environment.c index 54b25dbd..d110a7ef 100644 --- a/environment.c +++ b/environment.c @@ -176,6 +176,7 @@ env_t *new_compilation_unit(void) env_t *ns_env = namespace_env(env, global_types[i].name); $ARRAY_FOREACH(global_types[i].namespace, j, ns_entry_t, entry, { type_t *type = parse_type_string(ns_env, entry.type_str); + if (type->tag == ClosureType) type = Match(type, ClosureType)->fn; binding_t *b = new(binding_t, .code=entry.code, .type=type); Table_str_set(namespace, entry.name, b); }, {}) @@ -184,6 +185,14 @@ env_t *new_compilation_unit(void) return env; } +env_t *global_scope(env_t *env) +{ + env_t *scope = new(env_t); + *scope = *env; + scope->locals = new(table_t, .fallback=env->globals); + return scope; +} + env_t *fresh_scope(env_t *env) { env_t *scope = new(env_t); @@ -207,7 +216,15 @@ env_t *namespace_env(env_t *env, const char *namespace_name) binding_t *get_binding(env_t *env, const char *name) { - return Table_str_get(*env->locals, name); + binding_t *b = Table_str_get(*env->locals, name); + if (!b && env->fn_ctx && env->fn_ctx->closure_scope) { + b = Table_str_get(*env->fn_ctx->closure_scope, name); + if (b) { + Table_str_set(env->fn_ctx->closed_vars, name, b); + return new(binding_t, .type=b->type, .code=CORD_all("$userdata->", name)); + } + } + return b; } binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name) diff --git a/environment.h b/environment.h index 8660fe68..2ab7a9d2 100644 --- a/environment.h +++ b/environment.h @@ -17,9 +17,16 @@ typedef struct { } compilation_unit_t; typedef struct { + type_t *return_type; + table_t *closure_scope; + table_t *closed_vars; +} fn_context_t; + +typedef struct { table_t *types, *globals, *locals; table_t *type_namespaces; // Map of type name -> namespace table compilation_unit_t *code; + fn_context_t *fn_ctx; CORD scope_prefix; } env_t; @@ -29,6 +36,7 @@ typedef struct { } binding_t; env_t *new_compilation_unit(void); +env_t *global_scope(env_t *env); env_t *fresh_scope(env_t *env); env_t *namespace_env(env_t *env, const char *namespace_name); __attribute__((noreturn)) diff --git a/test/lambdas.tm b/test/lambdas.tm index 86b4b263..d896ea4b 100644 --- a/test/lambdas.tm +++ b/test/lambdas.tm @@ -1,5 +1,3 @@ - - >> add_one := func(x:Int) x + 1 >> add_one(10) = 11 @@ -10,3 +8,19 @@ >> asdf := add_one >> asdf(99) = 100 + + +func make_adder(x:Int)-> func(y:Int)->Int + return func(y:Int) x + y + +>> add_100 := make_adder(100) +>> add_100(5) += 105 + + +func suffix_fn(fn:func(t:Text)->Text, suffix:Text)->func(t:Text)->Text + return func(t:Text) fn(t)++suffix + +>> shout2 := suffix_fn(Text.upper, "!") +>> shout2("hello") += "HELLO!" diff --git a/typecheck.c b/typecheck.c index 88b9d892..671326d0 100644 --- a/typecheck.c +++ b/typecheck.c @@ -67,7 +67,7 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast) } } REVERSE_LIST(type_args); - return Type(FunctionType, .args=type_args, .ret=ret_t); + return Type(ClosureType, Type(FunctionType, .args=type_args, .ret=ret_t)); } case UnknownTypeAST: code_err(ast, "I don't know how to get this type"); } @@ -610,7 +610,7 @@ type_t *get_type(env_t *env, ast_t *ast) case Lambda: { auto lambda = Match(ast, Lambda); arg_t *args = NULL; - env_t *scope = fresh_scope(env); + env_t *scope = fresh_scope(env); // For now, just use closed variables in scope normally for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *t = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->value); args = new(arg_t, .name=arg->name, .type=t, .next=args); @@ -28,7 +28,7 @@ CORD type_to_cord(type_t *t) { return CORD_asprintf("{%r=>%r}", type_to_cord(table->key_type), type_to_cord(table->value_type)); } case ClosureType: { - return CORD_all("~", type_to_cord(Match(t, ClosureType)->fn)); + return type_to_cord(Match(t, ClosureType)->fn); } case FunctionType: { CORD c = "func("; @@ -261,24 +261,7 @@ bool can_promote(type_t *actual, type_t *needed) } if (needed->tag == ClosureType && actual->tag == FunctionType) - return can_promote(actual, Match(needed, ClosureType)->fn); - - // Function promotion: - if (needed->tag == FunctionType && actual->tag == FunctionType) { - auto needed_fn = Match(needed, FunctionType); - auto actual_fn = Match(actual, FunctionType); - for (arg_t *needed_arg = needed_fn->args, *actual_arg = actual_fn->args; - needed_arg || actual_arg; - needed_arg = needed_arg->next, actual_arg = actual_arg->next) { - - if (!needed_arg || !actual_arg) - return false; - - if (!type_eq(needed_arg->type, actual_arg->type)) - return false; - } - return true; - } + return type_eq(actual, Match(needed, ClosureType)->fn); return false; } |
