From 2e2f68e5823cd3ad057993e0d4504107c6974fa4 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sat, 13 Jul 2024 18:26:41 -0400 Subject: Allow lambdas to have a return statement as the last statement --- compile.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'compile.c') diff --git a/compile.c b/compile.c index 7c69ace3..f2e156f7 100644 --- a/compile.c +++ b/compile.c @@ -73,14 +73,14 @@ static table_t *get_closed_vars(env_t *env, ast_t *lambda_ast) }; body_scope->fn_ctx = &fn_ctx; body_scope->locals->fallback = env->globals; - type_t *ret_t = get_type(body_scope, lambda->body); + type_t *ret_t = Match(Match(get_type(env, lambda_ast), ClosureType)->fn, FunctionType)->ret; fn_ctx.return_type = ret_t; // Find which variables are captured in the closure: env_t *tmp_scope = fresh_scope(body_scope); for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { type_t *stmt_type = get_type(tmp_scope, stmt->ast); - if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType)) + if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || stmt->ast->tag == Return)) (void)compile_statement(tmp_scope, stmt->ast); else (void)compile(tmp_scope, stmt->ast); @@ -1740,7 +1740,7 @@ CORD compile(env_t *env, ast_t *ast) body_scope->fn_ctx = &fn_ctx; body_scope->locals->fallback = env->globals; body_scope->deferred = NULL; - type_t *ret_t = get_type(body_scope, lambda->body); + type_t *ret_t = Match(Match(get_type(env, ast), ClosureType)->fn, FunctionType)->ret; fn_ctx.return_type = ret_t; CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "("); @@ -1771,7 +1771,7 @@ CORD compile(env_t *env, ast_t *ast) CORD body = CORD_EMPTY; for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { - if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType) + if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType || stmt->ast->tag == Return) body = CORD_all(body, compile_statement(body_scope, stmt->ast), "\n"); else body = CORD_all(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n"); -- cgit v1.2.3