diff options
Diffstat (limited to 'compile.c')
| -rw-r--r-- | compile.c | 137 |
1 files changed, 113 insertions, 24 deletions
@@ -26,6 +26,37 @@ CORD compile_type_ast(type_ast_t *t) } } +static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed) +{ + if (type_eq(actual, needed)) + return true; + + if (!can_promote(actual, needed)) + return false; + + if (actual->tag == IntType || actual->tag == NumType) + return true; + + // Automatic dereferencing: + if (actual->tag == PointerType && !Match(actual, PointerType)->is_optional + && can_promote(Match(actual, PointerType)->pointed, needed)) { + *code = CORD_all("*(", *code, ")"); + return promote(env, code, Match(actual, PointerType)->pointed, needed); + } + + // Optional promotion: + if (actual->tag == PointerType && needed->tag == PointerType) + return true; + + if (needed->tag == ClosureType && actual->tag == FunctionType && type_eq(actual, Match(needed, ClosureType)->fn)) { + *code = CORD_all("(closure_t){", *code, ", NULL}"); + return true; + } + + return false; +} + + CORD compile_declaration(type_t *t, const char *name) { if (t->tag == FunctionType) { @@ -175,19 +206,19 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg { table_t used_args = {}; CORD code = CORD_EMPTY; - env_t *default_scope = fresh_scope(env); - default_scope->locals->fallback = env->globals; + env_t *default_scope = global_scope(env); for (arg_t *spec_arg = spec_args; spec_arg; spec_arg = spec_arg->next) { // Find keyword: if (spec_arg->name) { for (arg_ast_t *call_arg = call_args; call_arg; call_arg = call_arg->next) { if (call_arg->name && streq(call_arg->name, spec_arg->name)) { type_t *actual_t = get_type(env, call_arg->value); - if (!can_promote(actual_t, spec_arg->type)) + CORD value = compile(env, call_arg->value); + if (!promote(env, &value, actual_t, spec_arg->type)) code_err(call_arg->value, "This argument is supposed to be a %T, but this value is a %T", spec_arg->type, actual_t); Table_str_set(&used_args, call_arg->name, call_arg); if (code) code = CORD_cat(code, ", "); - code = CORD_cat(code, compile(env, call_arg->value)); + code = CORD_cat(code, value); goto found_it; } } @@ -199,11 +230,12 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg const char *pseudoname = heap_strf("%ld", i++); if (!Table_str_get(used_args, pseudoname)) { type_t *actual_t = get_type(env, call_arg->value); - if (!can_promote(actual_t, spec_arg->type)) + CORD value = compile(env, call_arg->value); + if (!promote(env, &value, actual_t, spec_arg->type)) code_err(call_arg->value, "This argument is supposed to be a %T, but this value is a %T", spec_arg->type, actual_t); Table_str_set(&used_args, pseudoname, call_arg); if (code) code = CORD_cat(code, ", "); - code = CORD_cat(code, compile(env, call_arg->value)); + code = CORD_cat(code, value); goto found_it; } } @@ -299,9 +331,9 @@ CORD compile(env_t *env, ast_t *ast) type_t *lhs_t = get_type(env, binop->lhs); type_t *rhs_t = get_type(env, binop->rhs); type_t *operand_t; - if (can_promote(rhs_t, lhs_t)) + if (promote(env, &rhs, rhs_t, lhs_t)) operand_t = lhs_t; - else if (can_promote(lhs_t, rhs_t)) + else if (promote(env, &lhs, lhs_t, rhs_t)) operand_t = rhs_t; else code_err(ast, "I can't do operations between %T and %T", lhs_t, rhs_t); @@ -450,11 +482,11 @@ CORD compile(env_t *env, ast_t *ast) type_t *lhs_t = get_type(env, update->lhs); type_t *rhs_t = get_type(env, update->rhs); type_t *operand_t; - if (can_promote(rhs_t, lhs_t)) + if (promote(env, &rhs, rhs_t, lhs_t)) operand_t = lhs_t; - else if (can_promote(lhs_t, rhs_t)) + else if (promote(env, &lhs, lhs_t, rhs_t)) operand_t = rhs_t; - else if (lhs_t->tag == ArrayType && can_promote(rhs_t, Match(lhs_t, ArrayType)->item_type)) + else if (lhs_t->tag == ArrayType && promote(env, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type)) operand_t = lhs_t; else code_err(ast, "I can't do operations between %T and %T", lhs_t, rhs_t); @@ -497,7 +529,7 @@ CORD compile(env_t *env, ast_t *ast) if (operand_t->tag == TextType) { return CORD_asprintf("%r = CORD_cat(%r, %r);", lhs, lhs, rhs); } else if (operand_t->tag == ArrayType) { - if (can_promote(rhs_t, Match(lhs_t, ArrayType)->item_type)) { + if (promote(env, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type)) { // arr ++= item if (update->lhs->tag == Var) return CORD_all("Array__insert(&", lhs, ", $stack(", rhs, "), 0, ", compile_type_info(env, operand_t), ")"); @@ -691,7 +723,8 @@ CORD compile(env_t *env, ast_t *ast) case FunctionDef: { auto fndef = Match(ast, FunctionDef); CORD name = compile(env, fndef->name); - CORD signature = CORD_all(fndef->ret_type ? compile_type_ast(fndef->ret_type) : "void", " ", name, "("); + type_t *ret_t = fndef->ret_type ? parse_type_ast(env, fndef->ret_type) : Type(VoidType); + CORD signature = CORD_all(compile_type(ret_t), " ", name, "("); for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); signature = CORD_cat(signature, compile_declaration(arg_type, arg->name)); @@ -709,13 +742,19 @@ CORD compile(env_t *env, ast_t *ast) if (!fndef->is_private) code = CORD_cat("public ", code); - env_t *body_scope = fresh_scope(env); - body_scope->locals->fallback = env->globals; + env_t *body_scope = global_scope(env); for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name)); } + fn_context_t fn_ctx = (fn_context_t){ + .return_type=ret_t, + .closure_scope=NULL, + .closed_vars=NULL, + }; + body_scope->fn_ctx = &fn_ctx; + CORD body = compile(body_scope, fndef->body); if (CORD_fetch(body, 0) != '{') body = CORD_asprintf("{\n%r\n}", body); @@ -728,19 +767,47 @@ CORD compile(env_t *env, ast_t *ast) CORD name = CORD_asprintf("lambda$%ld", lambda_number++); env_t *body_scope = fresh_scope(env); - body_scope->locals->fallback = env->globals; for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name)); } + type_t *ret_t = get_type(body_scope, lambda->body); + fn_context_t fn_ctx = (fn_context_t){ + .return_type=ret_t, + .closure_scope=body_scope->locals->fallback, + .closed_vars=new(table_t), + }; + body_scope->fn_ctx = &fn_ctx; + body_scope->locals->fallback = env->globals; + CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "("); for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); code = CORD_all(code, compile_type(arg_type), " ", arg->name, ", "); } - code = CORD_cat(code, "void *$userdata)"); + + for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) + (void)compile_statement(body_scope, stmt->ast); + + CORD userdata; + if (Table_length(*fn_ctx.closed_vars) == 0) { + code = CORD_cat(code, "void *$userdata)"); + userdata = "NULL"; + } else { + CORD def = "typedef struct {"; + userdata = CORD_all("new(", name, "$userdata_t"); + for (int64_t i = 1; i <= Table_length(*fn_ctx.closed_vars); i++) { + struct { const char *name; binding_t *b; } *entry = Table_entry(*fn_ctx.closed_vars, i); + def = CORD_all(def, compile_declaration(entry->b->type, entry->name), "; "); + userdata = CORD_all(userdata, ", ", entry->b->code); + } + userdata = CORD_all(userdata, ")"); + def = CORD_all(def, "} ", name, "$userdata_t;"); + env->code->typedefs = CORD_cat(env->code->typedefs, def); + code = CORD_all(code, name, "$userdata_t *$userdata)"); + } CORD body = CORD_EMPTY; for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { @@ -750,7 +817,7 @@ CORD compile(env_t *env, ast_t *ast) body = CORD_all(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n"); } env->code->funcs = CORD_all(env->code->funcs, code, " {\n", body, "\n}"); - return CORD_all("(closure_t){", name, ", NULL}"); + return CORD_all("(closure_t){", name, ", ", userdata, "}"); } case MethodCall: { auto call = Match(ast, MethodCall); @@ -848,13 +915,22 @@ CORD compile(env_t *env, ast_t *ast) } else if (fn_t->tag == ClosureType) { fn_t = Match(fn_t, ClosureType)->fn; arg_t *type_args = Match(fn_t, FunctionType)->args; - CORD fn_type_code = compile_type(fn_t); + + arg_t *closure_fn_args = NULL; + for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next) + closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args); + closure_fn_args = new(arg_t, .name="$userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args); + REVERSE_LIST(closure_fn_args); + CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret)); + CORD closure = compile(env, call->fn); + CORD arg_code = compile_arguments(env, ast, type_args, call->args); + if (arg_code) arg_code = CORD_cat(arg_code, ", "); if (call->fn->tag == Var) { - return CORD_all("((", fn_type_code, ")", closure, ".fn)(", compile_arguments(env, ast, type_args, call->args), ")"); + return CORD_all("((", fn_type_code, ")", closure, ".fn)(", arg_code, closure, ".userdata)"); } else { return CORD_all("({ closure_t $closure = ", closure, "; ((", fn_type_code, ")$closure.fn)(", - compile_arguments(env, ast, type_args, call->args), "); })"); + arg_code, "$closure.userdata); })"); } } else { code_err(call->fn, "This is not a function, it's a %T", fn_t); @@ -1062,8 +1138,21 @@ CORD compile(env_t *env, ast_t *ast) } case Pass: return ";"; case Return: { + if (!env->fn_ctx) code_err(ast, "This return statement is not inside any function"); auto ret = Match(ast, Return)->value; - return ret ? CORD_asprintf("return %r;", compile(env, ret)) : "return;"; + assert(env->fn_ctx->return_type); + if (ret) { + type_t *ret_t = get_type(env, ret); + CORD value = compile(env, ret); + if (!promote(env, &value, ret_t, env->fn_ctx->return_type)) + code_err(ast, "This function expects a return value of type %T, but this return has type %T", + env->fn_ctx->return_type, ret_t); + return CORD_all("return ", value, ";"); + } else { + if (env->fn_ctx->return_type->tag != VoidType) + code_err(ast, "This function expects a return value of type %T", env->fn_ctx->return_type->tag); + return "return;"; + } } // Extern, case StructDef: { @@ -1251,10 +1340,10 @@ CORD compile(env_t *env, ast_t *ast) case TableType: { type_t *key_t = Match(container_t, TableType)->key_type; type_t *value_t = Match(container_t, TableType)->value_type; - if (!can_promote(index_t, key_t)) - code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t); CORD table = compile_to_pointer_depth(env, indexing->indexed, 0, false); CORD key = compile(env, indexing->index); + if (!promote(env, &key, index_t, key_t)) + code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t); file_t *f = indexing->index->file; return CORD_all("$Table_get(", table, ", ", compile_type(key_t), ", ", compile_type(value_t), ", ", key, ", ", compile_type_info(env, container_t), ", ", |
