diff options
| -rw-r--r-- | compile.c | 8 | ||||
| -rw-r--r-- | typecheck.c | 22 |
2 files changed, 25 insertions, 5 deletions
@@ -73,14 +73,14 @@ static table_t *get_closed_vars(env_t *env, ast_t *lambda_ast) }; body_scope->fn_ctx = &fn_ctx; body_scope->locals->fallback = env->globals; - type_t *ret_t = get_type(body_scope, lambda->body); + type_t *ret_t = Match(Match(get_type(env, lambda_ast), ClosureType)->fn, FunctionType)->ret; fn_ctx.return_type = ret_t; // Find which variables are captured in the closure: env_t *tmp_scope = fresh_scope(body_scope); for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { type_t *stmt_type = get_type(tmp_scope, stmt->ast); - if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType)) + if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || stmt->ast->tag == Return)) (void)compile_statement(tmp_scope, stmt->ast); else (void)compile(tmp_scope, stmt->ast); @@ -1740,7 +1740,7 @@ CORD compile(env_t *env, ast_t *ast) body_scope->fn_ctx = &fn_ctx; body_scope->locals->fallback = env->globals; body_scope->deferred = NULL; - type_t *ret_t = get_type(body_scope, lambda->body); + type_t *ret_t = Match(Match(get_type(env, ast), ClosureType)->fn, FunctionType)->ret; fn_ctx.return_type = ret_t; CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "("); @@ -1771,7 +1771,7 @@ CORD compile(env_t *env, ast_t *ast) CORD body = CORD_EMPTY; for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { - if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType) + if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType || stmt->ast->tag == Return) body = CORD_all(body, compile_statement(body_scope, stmt->ast), "\n"); else body = CORD_all(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n"); diff --git a/typecheck.c b/typecheck.c index 0fda02e2..22001957 100644 --- a/typecheck.c +++ b/typecheck.c @@ -922,7 +922,27 @@ type_t *get_type(env_t *env, ast_t *ast) } REVERSE_LIST(args); - type_t *ret = get_type(scope, lambda->body); + type_t *ret; + auto block = Match(lambda->body, Block); + if (!block->statements) { + ret = Type(VoidType); + } else { + ast_list_t *last = block->statements; + if (!last) + return Type(VoidType); + while (last->next) + last = last->next; + + env_t *block_env = fresh_scope(env); + for (ast_list_t *stmt = block->statements; stmt; stmt = stmt->next) + bind_statement(block_env, stmt->ast); + + if (last->ast->tag == Return && Match(last->ast, Return)->value) + ret = get_type(block_env, Match(last->ast, Return)->value); + else + ret = get_type(block_env, last->ast); + } + if (has_stack_memory(ret)) code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame."); return Type(ClosureType, Type(FunctionType, .args=args, .ret=ret)); |
