Add ReturnType(ret) so we can more accurately track return values for

`if` statements
This commit is contained in:
Bruce Hill 2024-07-13 18:43:50 -04:00
parent 2e2f68e582
commit 39576466a7
5 changed files with 45 additions and 44 deletions

View File

@ -73,14 +73,16 @@ static table_t *get_closed_vars(env_t *env, ast_t *lambda_ast)
};
body_scope->fn_ctx = &fn_ctx;
body_scope->locals->fallback = env->globals;
type_t *ret_t = Match(Match(get_type(env, lambda_ast), ClosureType)->fn, FunctionType)->ret;
type_t *ret_t = get_type(body_scope, lambda->body);
if (ret_t->tag == ReturnType)
ret_t = Match(ret_t, ReturnType)->ret;
fn_ctx.return_type = ret_t;
// Find which variables are captured in the closure:
env_t *tmp_scope = fresh_scope(body_scope);
for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) {
type_t *stmt_type = get_type(tmp_scope, stmt->ast);
if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || stmt->ast->tag == Return))
if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || get_type(body_scope, stmt->ast)->tag == ReturnType))
(void)compile_statement(tmp_scope, stmt->ast);
else
(void)compile(tmp_scope, stmt->ast);
@ -108,6 +110,7 @@ CORD compile_declaration(type_t *t, CORD name)
CORD compile_type(type_t *t)
{
switch (t->tag) {
case ReturnType: errx(1, "Shouldn't be compiling ReturnType to a type");
case AbortType: return "void";
case VoidType: return "void";
case MemoryType: return "void";
@ -366,7 +369,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
compile(env, WrapAST(test->expr, TextLiteral, .cord=test->expr->file->filename)),
(int64_t)(test->expr->start - test->expr->file->text),
(int64_t)(test->expr->end - test->expr->file->text));
} else if (expr_t->tag == VoidType || expr_t->tag == AbortType) {
} else if (expr_t->tag == VoidType || expr_t->tag == AbortType || expr_t->tag == ReturnType) {
return CORD_asprintf(
"test(({ %r; NULL; }), NULL, NULL, %r, %ld, %ld);",
compile_statement(env, test->expr),
@ -391,7 +394,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
return compile_statement(env, decl->value);
} else {
type_t *t = get_type(env, decl->value);
if (t->tag == AbortType || t->tag == VoidType)
if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType)
code_err(ast, "You can't declare a variable with a %T value", t);
return CORD_all(compile_declaration(t, CORD_cat("$", Match(decl->var, Var)->name)), " = ", compile(env, decl->value), ";");
}
@ -598,8 +601,8 @@ CORD compile_statement(env_t *env, ast_t *ast)
};
body_scope->fn_ctx = &fn_ctx;
if (ret_t->tag != VoidType && ret_t->tag != AbortType && get_type(body_scope, fndef->body)->tag != AbortType)
type_t *body_type = get_type(body_scope, fndef->body);
if (ret_t->tag != VoidType && ret_t->tag != AbortType && body_type->tag != AbortType && body_type->tag != ReturnType)
code_err(ast, "This function can reach the end without returning a %T value!", ret_t);
CORD body = compile_statement(body_scope, fndef->body);
@ -1740,7 +1743,9 @@ CORD compile(env_t *env, ast_t *ast)
body_scope->fn_ctx = &fn_ctx;
body_scope->locals->fallback = env->globals;
body_scope->deferred = NULL;
type_t *ret_t = Match(Match(get_type(env, ast), ClosureType)->fn, FunctionType)->ret;
type_t *ret_t = get_type(body_scope, lambda->body);
if (ret_t->tag == ReturnType)
ret_t = Match(ret_t, ReturnType)->ret;
fn_ctx.return_type = ret_t;
CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "(");
@ -1771,7 +1776,7 @@ CORD compile(env_t *env, ast_t *ast)
CORD body = CORD_EMPTY;
for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) {
if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType || stmt->ast->tag == Return)
if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType || get_type(body_scope, stmt->ast)->tag == ReturnType)
body = CORD_all(body, compile_statement(body_scope, stmt->ast), "\n");
else
body = CORD_all(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n");
@ -2000,10 +2005,10 @@ CORD compile(env_t *env, ast_t *ast)
type_t *true_type = get_type(env, if_->body);
type_t *false_type = get_type(env, if_->else_body);
if (true_type->tag == AbortType)
if (true_type->tag == AbortType || true_type->tag == ReturnType)
return CORD_all("({ if (", compile(env, if_->condition), ") ", compile_statement(env, if_->body),
"\n", compile(env, if_->else_body), "; })");
else if (false_type->tag == AbortType)
else if (false_type->tag == AbortType || false_type->tag == ReturnType)
return CORD_all("({ if (!(", compile(env, if_->condition), ")) ", compile_statement(env, if_->else_body),
"\n", compile(env, if_->body), "; })");
else
@ -2023,7 +2028,7 @@ CORD compile(env_t *env, ast_t *ast)
ast_t *empty = NULL;
if (reduction->fallback) {
type_t *fallback_type = get_type(scope, reduction->fallback);
if (fallback_type->tag == AbortType) {
if (fallback_type->tag == AbortType || fallback_type->tag == ReturnType) {
empty = reduction->fallback;
} else {
empty = FakeAST(Assign, .targets=new(ast_list_t, .ast=result), .values=new(ast_list_t, .ast=reduction->fallback));
@ -2434,7 +2439,7 @@ CORD compile_file(env_t *env, ast_t *ast)
const char *decl_name = Match(decl->var, Var)->name;
bool is_private = (decl_name[0] == '_');
type_t *t = get_type(env, decl->value);
if (t->tag == AbortType || t->tag == VoidType)
if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType)
code_err(stmt->ast, "You can't declare a variable with a %T value", t);
if (!is_constant(env, decl->value))
code_err(decl->value, "This value is not a valid constant initializer.");
@ -2487,7 +2492,7 @@ CORD compile_statement_header(env_t *env, ast_t *ast)
}
type_t *t = get_type(env, decl->value);
assert(t->tag != ModuleType);
if (t->tag == AbortType || t->tag == VoidType)
if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType)
code_err(ast, "You can't declare a variable with a %T value", t);
const char *decl_name = Match(decl->var, Var)->name;
bool is_private = (decl_name[0] == '_');

1
repl.c
View File

@ -98,6 +98,7 @@ const TypeInfo *type_to_type_info(type_t *t)
{
switch (t->tag) {
case AbortType: return &$Abort;
case ReturnType: errx(1, "Shouldn't be getting a typeinfo for ReturnType");
case VoidType: return &$Void;
case MemoryType: return &$Memory;
case BoolType: return &$Bool;

View File

@ -737,7 +737,11 @@ type_t *get_type(env_t *env, ast_t *ast)
case Use: {
return Type(ModuleType, Match(ast, Use)->name);
}
case Return: case Stop: case Skip: case PrintStatement: {
case Return: {
ast_t *val = Match(ast, Return)->value;
return Type(ReturnType, .ret=(val ? get_type(env, val) : Type(VoidType)));
}
case Stop: case Skip: case PrintStatement: {
return Type(AbortType);
}
case Pass: case Defer: return Type(VoidType);
@ -800,9 +804,9 @@ type_t *get_type(env_t *env, ast_t *ast)
case BINOP_AND: {
if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
return lhs_t;
} else if (lhs_t->tag == BoolType && rhs_t->tag == AbortType) {
} else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) {
return lhs_t;
} else if (rhs_t->tag == AbortType) {
} else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) {
return lhs_t;
} else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) {
auto lhs_ptr = Match(lhs_t, PointerType);
@ -819,13 +823,13 @@ type_t *get_type(env_t *env, ast_t *ast)
case BINOP_OR: {
if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
return lhs_t;
} else if (lhs_t->tag == BoolType && rhs_t->tag == AbortType) {
} else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) {
return lhs_t;
} else if (lhs_t->tag == IntType && rhs_t->tag == IntType) {
return get_math_type(env, ast, lhs_t, rhs_t);
} else if (lhs_t->tag == PointerType) {
auto lhs_ptr = Match(lhs_t, PointerType);
if (rhs_t->tag == AbortType) {
if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) {
return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=false, .is_readonly=lhs_ptr->is_readonly);
} else if (rhs_t->tag == PointerType) {
auto rhs_ptr = Match(rhs_t, PointerType);
@ -886,7 +890,7 @@ type_t *get_type(env_t *env, ast_t *ast)
if (!reduction->fallback)
return t;
type_t *fallback_t = get_type(env, reduction->fallback);
if (fallback_t->tag == AbortType)
if (fallback_t->tag == AbortType || fallback_t->tag == ReturnType)
return t;
else if (can_promote(fallback_t, t))
return t;
@ -922,26 +926,10 @@ type_t *get_type(env_t *env, ast_t *ast)
}
REVERSE_LIST(args);
type_t *ret;
auto block = Match(lambda->body, Block);
if (!block->statements) {
ret = Type(VoidType);
} else {
ast_list_t *last = block->statements;
if (!last)
return Type(VoidType);
while (last->next)
last = last->next;
env_t *block_env = fresh_scope(env);
for (ast_list_t *stmt = block->statements; stmt; stmt = stmt->next)
bind_statement(block_env, stmt->ast);
if (last->ast->tag == Return && Match(last->ast, Return)->value)
ret = get_type(block_env, Match(last->ast, Return)->value);
else
ret = get_type(block_env, last->ast);
}
type_t *ret = get_type(scope, lambda->body);
assert(ret);
if (ret->tag == ReturnType)
ret = Match(ret, ReturnType)->ret;
if (has_stack_memory(ret))
code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame.");
@ -1118,7 +1106,7 @@ bool is_discardable(env_t *env, ast_t *ast)
default: break;
}
type_t *t = get_type(env, ast);
return (t->tag == VoidType || t->tag == AbortType);
return (t->tag == VoidType || t->tag == AbortType || t->tag == ReturnType);
}
type_t *get_file_type(env_t *env, const char *path)

11
types.c
View File

@ -13,6 +13,7 @@ CORD type_to_cord(type_t *t) {
switch (t->tag) {
case UnknownType: return "???";
case AbortType: return "Abort";
case ReturnType: return CORD_all("Return(", type_to_cord(Match(t, ReturnType)->ret), ")");
case VoidType: return "Void";
case MemoryType: return "Memory";
case BoolType: return "Bool";
@ -126,8 +127,10 @@ type_t *type_or_type(type_t *a, type_t *b)
if (!b) return a;
if (type_is_a(b, a)) return a;
if (type_is_a(a, b)) return b;
if (a->tag == AbortType) return non_optional(b);
if (b->tag == AbortType) return non_optional(a);
if (a->tag == ReturnType && b->tag == ReturnType)
return Type(ReturnType, .ret=type_or_type(Match(a, ReturnType)->ret, Match(b, ReturnType)->ret));
if (a->tag == AbortType || a->tag == ReturnType) return non_optional(b);
if (b->tag == AbortType || b->tag == ReturnType) return non_optional(a);
if ((a->tag == IntType || a->tag == NumType) && (b->tag == IntType || b->tag == NumType)) {
switch (compare_precision(a, b)) {
case NUM_PRECISION_EQUAL: case NUM_PRECISION_MORE: return a;
@ -396,7 +399,7 @@ type_t *replace_type(type_t *t, type_t *target, type_t *replacement)
size_t type_size(type_t *t)
{
switch (t->tag) {
case UnknownType: case AbortType: case VoidType: return 0;
case UnknownType: case AbortType: case ReturnType: case VoidType: return 0;
case MemoryType: errx(1, "Memory has undefined type size");
case BoolType: return sizeof(bool);
case CStringType: return sizeof(char*);
@ -448,7 +451,7 @@ size_t type_size(type_t *t)
size_t type_align(type_t *t)
{
switch (t->tag) {
case UnknownType: case AbortType: case VoidType: return 0;
case UnknownType: case AbortType: case ReturnType: case VoidType: return 0;
case MemoryType: errx(1, "Memory has undefined type alignment");
case BoolType: return __alignof__(bool);
case CStringType: return __alignof__(char*);

View File

@ -39,6 +39,7 @@ struct type_s {
enum {
UnknownType,
AbortType,
ReturnType,
VoidType,
MemoryType,
BoolType,
@ -60,6 +61,9 @@ struct type_s {
union {
struct {
} UnknownType, AbortType, VoidType, MemoryType, BoolType;
struct {
type_t *ret;
} ReturnType;
struct {
int64_t bits;
} IntType;