From 39576466a7bcc545c49a9f17b188cc307a0c9d9c Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sat, 13 Jul 2024 18:43:50 -0400 Subject: [PATCH] Add ReturnType(ret) so we can more accurately track return values for `if` statements --- compile.c | 31 ++++++++++++++++++------------- repl.c | 1 + typecheck.c | 42 +++++++++++++++--------------------------- types.c | 11 +++++++---- types.h | 4 ++++ 5 files changed, 45 insertions(+), 44 deletions(-) diff --git a/compile.c b/compile.c index f2e156f..cc848e8 100644 --- a/compile.c +++ b/compile.c @@ -73,14 +73,16 @@ static table_t *get_closed_vars(env_t *env, ast_t *lambda_ast) }; body_scope->fn_ctx = &fn_ctx; body_scope->locals->fallback = env->globals; - type_t *ret_t = Match(Match(get_type(env, lambda_ast), ClosureType)->fn, FunctionType)->ret; + type_t *ret_t = get_type(body_scope, lambda->body); + if (ret_t->tag == ReturnType) + ret_t = Match(ret_t, ReturnType)->ret; fn_ctx.return_type = ret_t; // Find which variables are captured in the closure: env_t *tmp_scope = fresh_scope(body_scope); for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { type_t *stmt_type = get_type(tmp_scope, stmt->ast); - if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || stmt->ast->tag == Return)) + if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || get_type(body_scope, stmt->ast)->tag == ReturnType)) (void)compile_statement(tmp_scope, stmt->ast); else (void)compile(tmp_scope, stmt->ast); @@ -108,6 +110,7 @@ CORD compile_declaration(type_t *t, CORD name) CORD compile_type(type_t *t) { switch (t->tag) { + case ReturnType: errx(1, "Shouldn't be compiling ReturnType to a type"); case AbortType: return "void"; case VoidType: return "void"; case MemoryType: return "void"; @@ -366,7 +369,7 @@ CORD compile_statement(env_t *env, ast_t *ast) compile(env, WrapAST(test->expr, TextLiteral, .cord=test->expr->file->filename)), (int64_t)(test->expr->start - test->expr->file->text), (int64_t)(test->expr->end - test->expr->file->text)); - } else if (expr_t->tag == VoidType || expr_t->tag == AbortType) { + } else if (expr_t->tag == VoidType || expr_t->tag == AbortType || expr_t->tag == ReturnType) { return CORD_asprintf( "test(({ %r; NULL; }), NULL, NULL, %r, %ld, %ld);", compile_statement(env, test->expr), @@ -391,7 +394,7 @@ CORD compile_statement(env_t *env, ast_t *ast) return compile_statement(env, decl->value); } else { type_t *t = get_type(env, decl->value); - if (t->tag == AbortType || t->tag == VoidType) + if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType) code_err(ast, "You can't declare a variable with a %T value", t); return CORD_all(compile_declaration(t, CORD_cat("$", Match(decl->var, Var)->name)), " = ", compile(env, decl->value), ";"); } @@ -598,8 +601,8 @@ CORD compile_statement(env_t *env, ast_t *ast) }; body_scope->fn_ctx = &fn_ctx; - - if (ret_t->tag != VoidType && ret_t->tag != AbortType && get_type(body_scope, fndef->body)->tag != AbortType) + type_t *body_type = get_type(body_scope, fndef->body); + if (ret_t->tag != VoidType && ret_t->tag != AbortType && body_type->tag != AbortType && body_type->tag != ReturnType) code_err(ast, "This function can reach the end without returning a %T value!", ret_t); CORD body = compile_statement(body_scope, fndef->body); @@ -1740,7 +1743,9 @@ CORD compile(env_t *env, ast_t *ast) body_scope->fn_ctx = &fn_ctx; body_scope->locals->fallback = env->globals; body_scope->deferred = NULL; - type_t *ret_t = Match(Match(get_type(env, ast), ClosureType)->fn, FunctionType)->ret; + type_t *ret_t = get_type(body_scope, lambda->body); + if (ret_t->tag == ReturnType) + ret_t = Match(ret_t, ReturnType)->ret; fn_ctx.return_type = ret_t; CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "("); @@ -1771,7 +1776,7 @@ CORD compile(env_t *env, ast_t *ast) CORD body = CORD_EMPTY; for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) { - if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType || stmt->ast->tag == Return) + if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType || get_type(body_scope, stmt->ast)->tag == ReturnType) body = CORD_all(body, compile_statement(body_scope, stmt->ast), "\n"); else body = CORD_all(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n"); @@ -2000,10 +2005,10 @@ CORD compile(env_t *env, ast_t *ast) type_t *true_type = get_type(env, if_->body); type_t *false_type = get_type(env, if_->else_body); - if (true_type->tag == AbortType) + if (true_type->tag == AbortType || true_type->tag == ReturnType) return CORD_all("({ if (", compile(env, if_->condition), ") ", compile_statement(env, if_->body), "\n", compile(env, if_->else_body), "; })"); - else if (false_type->tag == AbortType) + else if (false_type->tag == AbortType || false_type->tag == ReturnType) return CORD_all("({ if (!(", compile(env, if_->condition), ")) ", compile_statement(env, if_->else_body), "\n", compile(env, if_->body), "; })"); else @@ -2023,7 +2028,7 @@ CORD compile(env_t *env, ast_t *ast) ast_t *empty = NULL; if (reduction->fallback) { type_t *fallback_type = get_type(scope, reduction->fallback); - if (fallback_type->tag == AbortType) { + if (fallback_type->tag == AbortType || fallback_type->tag == ReturnType) { empty = reduction->fallback; } else { empty = FakeAST(Assign, .targets=new(ast_list_t, .ast=result), .values=new(ast_list_t, .ast=reduction->fallback)); @@ -2434,7 +2439,7 @@ CORD compile_file(env_t *env, ast_t *ast) const char *decl_name = Match(decl->var, Var)->name; bool is_private = (decl_name[0] == '_'); type_t *t = get_type(env, decl->value); - if (t->tag == AbortType || t->tag == VoidType) + if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType) code_err(stmt->ast, "You can't declare a variable with a %T value", t); if (!is_constant(env, decl->value)) code_err(decl->value, "This value is not a valid constant initializer."); @@ -2487,7 +2492,7 @@ CORD compile_statement_header(env_t *env, ast_t *ast) } type_t *t = get_type(env, decl->value); assert(t->tag != ModuleType); - if (t->tag == AbortType || t->tag == VoidType) + if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType) code_err(ast, "You can't declare a variable with a %T value", t); const char *decl_name = Match(decl->var, Var)->name; bool is_private = (decl_name[0] == '_'); diff --git a/repl.c b/repl.c index f405145..e8cccfb 100644 --- a/repl.c +++ b/repl.c @@ -98,6 +98,7 @@ const TypeInfo *type_to_type_info(type_t *t) { switch (t->tag) { case AbortType: return &$Abort; + case ReturnType: errx(1, "Shouldn't be getting a typeinfo for ReturnType"); case VoidType: return &$Void; case MemoryType: return &$Memory; case BoolType: return &$Bool; diff --git a/typecheck.c b/typecheck.c index 2200195..c41c083 100644 --- a/typecheck.c +++ b/typecheck.c @@ -737,7 +737,11 @@ type_t *get_type(env_t *env, ast_t *ast) case Use: { return Type(ModuleType, Match(ast, Use)->name); } - case Return: case Stop: case Skip: case PrintStatement: { + case Return: { + ast_t *val = Match(ast, Return)->value; + return Type(ReturnType, .ret=(val ? get_type(env, val) : Type(VoidType))); + } + case Stop: case Skip: case PrintStatement: { return Type(AbortType); } case Pass: case Defer: return Type(VoidType); @@ -800,9 +804,9 @@ type_t *get_type(env_t *env, ast_t *ast) case BINOP_AND: { if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { return lhs_t; - } else if (lhs_t->tag == BoolType && rhs_t->tag == AbortType) { + } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) { return lhs_t; - } else if (rhs_t->tag == AbortType) { + } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { return lhs_t; } else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) { auto lhs_ptr = Match(lhs_t, PointerType); @@ -819,13 +823,13 @@ type_t *get_type(env_t *env, ast_t *ast) case BINOP_OR: { if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { return lhs_t; - } else if (lhs_t->tag == BoolType && rhs_t->tag == AbortType) { + } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) { return lhs_t; } else if (lhs_t->tag == IntType && rhs_t->tag == IntType) { return get_math_type(env, ast, lhs_t, rhs_t); } else if (lhs_t->tag == PointerType) { auto lhs_ptr = Match(lhs_t, PointerType); - if (rhs_t->tag == AbortType) { + if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=false, .is_readonly=lhs_ptr->is_readonly); } else if (rhs_t->tag == PointerType) { auto rhs_ptr = Match(rhs_t, PointerType); @@ -886,7 +890,7 @@ type_t *get_type(env_t *env, ast_t *ast) if (!reduction->fallback) return t; type_t *fallback_t = get_type(env, reduction->fallback); - if (fallback_t->tag == AbortType) + if (fallback_t->tag == AbortType || fallback_t->tag == ReturnType) return t; else if (can_promote(fallback_t, t)) return t; @@ -922,26 +926,10 @@ type_t *get_type(env_t *env, ast_t *ast) } REVERSE_LIST(args); - type_t *ret; - auto block = Match(lambda->body, Block); - if (!block->statements) { - ret = Type(VoidType); - } else { - ast_list_t *last = block->statements; - if (!last) - return Type(VoidType); - while (last->next) - last = last->next; - - env_t *block_env = fresh_scope(env); - for (ast_list_t *stmt = block->statements; stmt; stmt = stmt->next) - bind_statement(block_env, stmt->ast); - - if (last->ast->tag == Return && Match(last->ast, Return)->value) - ret = get_type(block_env, Match(last->ast, Return)->value); - else - ret = get_type(block_env, last->ast); - } + type_t *ret = get_type(scope, lambda->body); + assert(ret); + if (ret->tag == ReturnType) + ret = Match(ret, ReturnType)->ret; if (has_stack_memory(ret)) code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame."); @@ -1118,7 +1106,7 @@ bool is_discardable(env_t *env, ast_t *ast) default: break; } type_t *t = get_type(env, ast); - return (t->tag == VoidType || t->tag == AbortType); + return (t->tag == VoidType || t->tag == AbortType || t->tag == ReturnType); } type_t *get_file_type(env_t *env, const char *path) diff --git a/types.c b/types.c index 9e887b2..85ae9a3 100644 --- a/types.c +++ b/types.c @@ -13,6 +13,7 @@ CORD type_to_cord(type_t *t) { switch (t->tag) { case UnknownType: return "???"; case AbortType: return "Abort"; + case ReturnType: return CORD_all("Return(", type_to_cord(Match(t, ReturnType)->ret), ")"); case VoidType: return "Void"; case MemoryType: return "Memory"; case BoolType: return "Bool"; @@ -126,8 +127,10 @@ type_t *type_or_type(type_t *a, type_t *b) if (!b) return a; if (type_is_a(b, a)) return a; if (type_is_a(a, b)) return b; - if (a->tag == AbortType) return non_optional(b); - if (b->tag == AbortType) return non_optional(a); + if (a->tag == ReturnType && b->tag == ReturnType) + return Type(ReturnType, .ret=type_or_type(Match(a, ReturnType)->ret, Match(b, ReturnType)->ret)); + if (a->tag == AbortType || a->tag == ReturnType) return non_optional(b); + if (b->tag == AbortType || b->tag == ReturnType) return non_optional(a); if ((a->tag == IntType || a->tag == NumType) && (b->tag == IntType || b->tag == NumType)) { switch (compare_precision(a, b)) { case NUM_PRECISION_EQUAL: case NUM_PRECISION_MORE: return a; @@ -396,7 +399,7 @@ type_t *replace_type(type_t *t, type_t *target, type_t *replacement) size_t type_size(type_t *t) { switch (t->tag) { - case UnknownType: case AbortType: case VoidType: return 0; + case UnknownType: case AbortType: case ReturnType: case VoidType: return 0; case MemoryType: errx(1, "Memory has undefined type size"); case BoolType: return sizeof(bool); case CStringType: return sizeof(char*); @@ -448,7 +451,7 @@ size_t type_size(type_t *t) size_t type_align(type_t *t) { switch (t->tag) { - case UnknownType: case AbortType: case VoidType: return 0; + case UnknownType: case AbortType: case ReturnType: case VoidType: return 0; case MemoryType: errx(1, "Memory has undefined type alignment"); case BoolType: return __alignof__(bool); case CStringType: return __alignof__(char*); diff --git a/types.h b/types.h index 6d667c3..3debb57 100644 --- a/types.h +++ b/types.h @@ -39,6 +39,7 @@ struct type_s { enum { UnknownType, AbortType, + ReturnType, VoidType, MemoryType, BoolType, @@ -60,6 +61,9 @@ struct type_s { union { struct { } UnknownType, AbortType, VoidType, MemoryType, BoolType; + struct { + type_t *ret; + } ReturnType; struct { int64_t bits; } IntType;