Add ReturnType(ret) so we can more accurately track return values for
`if` statements
This commit is contained in:
parent
2e2f68e582
commit
39576466a7
31
compile.c
31
compile.c
@ -73,14 +73,16 @@ static table_t *get_closed_vars(env_t *env, ast_t *lambda_ast)
|
||||
};
|
||||
body_scope->fn_ctx = &fn_ctx;
|
||||
body_scope->locals->fallback = env->globals;
|
||||
type_t *ret_t = Match(Match(get_type(env, lambda_ast), ClosureType)->fn, FunctionType)->ret;
|
||||
type_t *ret_t = get_type(body_scope, lambda->body);
|
||||
if (ret_t->tag == ReturnType)
|
||||
ret_t = Match(ret_t, ReturnType)->ret;
|
||||
fn_ctx.return_type = ret_t;
|
||||
|
||||
// Find which variables are captured in the closure:
|
||||
env_t *tmp_scope = fresh_scope(body_scope);
|
||||
for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) {
|
||||
type_t *stmt_type = get_type(tmp_scope, stmt->ast);
|
||||
if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || stmt->ast->tag == Return))
|
||||
if (stmt->next || (stmt_type->tag == VoidType || stmt_type->tag == AbortType || get_type(body_scope, stmt->ast)->tag == ReturnType))
|
||||
(void)compile_statement(tmp_scope, stmt->ast);
|
||||
else
|
||||
(void)compile(tmp_scope, stmt->ast);
|
||||
@ -108,6 +110,7 @@ CORD compile_declaration(type_t *t, CORD name)
|
||||
CORD compile_type(type_t *t)
|
||||
{
|
||||
switch (t->tag) {
|
||||
case ReturnType: errx(1, "Shouldn't be compiling ReturnType to a type");
|
||||
case AbortType: return "void";
|
||||
case VoidType: return "void";
|
||||
case MemoryType: return "void";
|
||||
@ -366,7 +369,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
|
||||
compile(env, WrapAST(test->expr, TextLiteral, .cord=test->expr->file->filename)),
|
||||
(int64_t)(test->expr->start - test->expr->file->text),
|
||||
(int64_t)(test->expr->end - test->expr->file->text));
|
||||
} else if (expr_t->tag == VoidType || expr_t->tag == AbortType) {
|
||||
} else if (expr_t->tag == VoidType || expr_t->tag == AbortType || expr_t->tag == ReturnType) {
|
||||
return CORD_asprintf(
|
||||
"test(({ %r; NULL; }), NULL, NULL, %r, %ld, %ld);",
|
||||
compile_statement(env, test->expr),
|
||||
@ -391,7 +394,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
|
||||
return compile_statement(env, decl->value);
|
||||
} else {
|
||||
type_t *t = get_type(env, decl->value);
|
||||
if (t->tag == AbortType || t->tag == VoidType)
|
||||
if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType)
|
||||
code_err(ast, "You can't declare a variable with a %T value", t);
|
||||
return CORD_all(compile_declaration(t, CORD_cat("$", Match(decl->var, Var)->name)), " = ", compile(env, decl->value), ";");
|
||||
}
|
||||
@ -598,8 +601,8 @@ CORD compile_statement(env_t *env, ast_t *ast)
|
||||
};
|
||||
body_scope->fn_ctx = &fn_ctx;
|
||||
|
||||
|
||||
if (ret_t->tag != VoidType && ret_t->tag != AbortType && get_type(body_scope, fndef->body)->tag != AbortType)
|
||||
type_t *body_type = get_type(body_scope, fndef->body);
|
||||
if (ret_t->tag != VoidType && ret_t->tag != AbortType && body_type->tag != AbortType && body_type->tag != ReturnType)
|
||||
code_err(ast, "This function can reach the end without returning a %T value!", ret_t);
|
||||
|
||||
CORD body = compile_statement(body_scope, fndef->body);
|
||||
@ -1740,7 +1743,9 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
body_scope->fn_ctx = &fn_ctx;
|
||||
body_scope->locals->fallback = env->globals;
|
||||
body_scope->deferred = NULL;
|
||||
type_t *ret_t = Match(Match(get_type(env, ast), ClosureType)->fn, FunctionType)->ret;
|
||||
type_t *ret_t = get_type(body_scope, lambda->body);
|
||||
if (ret_t->tag == ReturnType)
|
||||
ret_t = Match(ret_t, ReturnType)->ret;
|
||||
fn_ctx.return_type = ret_t;
|
||||
|
||||
CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "(");
|
||||
@ -1771,7 +1776,7 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
|
||||
CORD body = CORD_EMPTY;
|
||||
for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) {
|
||||
if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType || stmt->ast->tag == Return)
|
||||
if (stmt->next || ret_t->tag == VoidType || ret_t->tag == AbortType || get_type(body_scope, stmt->ast)->tag == ReturnType)
|
||||
body = CORD_all(body, compile_statement(body_scope, stmt->ast), "\n");
|
||||
else
|
||||
body = CORD_all(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n");
|
||||
@ -2000,10 +2005,10 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
|
||||
type_t *true_type = get_type(env, if_->body);
|
||||
type_t *false_type = get_type(env, if_->else_body);
|
||||
if (true_type->tag == AbortType)
|
||||
if (true_type->tag == AbortType || true_type->tag == ReturnType)
|
||||
return CORD_all("({ if (", compile(env, if_->condition), ") ", compile_statement(env, if_->body),
|
||||
"\n", compile(env, if_->else_body), "; })");
|
||||
else if (false_type->tag == AbortType)
|
||||
else if (false_type->tag == AbortType || false_type->tag == ReturnType)
|
||||
return CORD_all("({ if (!(", compile(env, if_->condition), ")) ", compile_statement(env, if_->else_body),
|
||||
"\n", compile(env, if_->body), "; })");
|
||||
else
|
||||
@ -2023,7 +2028,7 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
ast_t *empty = NULL;
|
||||
if (reduction->fallback) {
|
||||
type_t *fallback_type = get_type(scope, reduction->fallback);
|
||||
if (fallback_type->tag == AbortType) {
|
||||
if (fallback_type->tag == AbortType || fallback_type->tag == ReturnType) {
|
||||
empty = reduction->fallback;
|
||||
} else {
|
||||
empty = FakeAST(Assign, .targets=new(ast_list_t, .ast=result), .values=new(ast_list_t, .ast=reduction->fallback));
|
||||
@ -2434,7 +2439,7 @@ CORD compile_file(env_t *env, ast_t *ast)
|
||||
const char *decl_name = Match(decl->var, Var)->name;
|
||||
bool is_private = (decl_name[0] == '_');
|
||||
type_t *t = get_type(env, decl->value);
|
||||
if (t->tag == AbortType || t->tag == VoidType)
|
||||
if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType)
|
||||
code_err(stmt->ast, "You can't declare a variable with a %T value", t);
|
||||
if (!is_constant(env, decl->value))
|
||||
code_err(decl->value, "This value is not a valid constant initializer.");
|
||||
@ -2487,7 +2492,7 @@ CORD compile_statement_header(env_t *env, ast_t *ast)
|
||||
}
|
||||
type_t *t = get_type(env, decl->value);
|
||||
assert(t->tag != ModuleType);
|
||||
if (t->tag == AbortType || t->tag == VoidType)
|
||||
if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType)
|
||||
code_err(ast, "You can't declare a variable with a %T value", t);
|
||||
const char *decl_name = Match(decl->var, Var)->name;
|
||||
bool is_private = (decl_name[0] == '_');
|
||||
|
1
repl.c
1
repl.c
@ -98,6 +98,7 @@ const TypeInfo *type_to_type_info(type_t *t)
|
||||
{
|
||||
switch (t->tag) {
|
||||
case AbortType: return &$Abort;
|
||||
case ReturnType: errx(1, "Shouldn't be getting a typeinfo for ReturnType");
|
||||
case VoidType: return &$Void;
|
||||
case MemoryType: return &$Memory;
|
||||
case BoolType: return &$Bool;
|
||||
|
42
typecheck.c
42
typecheck.c
@ -737,7 +737,11 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
case Use: {
|
||||
return Type(ModuleType, Match(ast, Use)->name);
|
||||
}
|
||||
case Return: case Stop: case Skip: case PrintStatement: {
|
||||
case Return: {
|
||||
ast_t *val = Match(ast, Return)->value;
|
||||
return Type(ReturnType, .ret=(val ? get_type(env, val) : Type(VoidType)));
|
||||
}
|
||||
case Stop: case Skip: case PrintStatement: {
|
||||
return Type(AbortType);
|
||||
}
|
||||
case Pass: case Defer: return Type(VoidType);
|
||||
@ -800,9 +804,9 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
case BINOP_AND: {
|
||||
if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
|
||||
return lhs_t;
|
||||
} else if (lhs_t->tag == BoolType && rhs_t->tag == AbortType) {
|
||||
} else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) {
|
||||
return lhs_t;
|
||||
} else if (rhs_t->tag == AbortType) {
|
||||
} else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) {
|
||||
return lhs_t;
|
||||
} else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) {
|
||||
auto lhs_ptr = Match(lhs_t, PointerType);
|
||||
@ -819,13 +823,13 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
case BINOP_OR: {
|
||||
if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
|
||||
return lhs_t;
|
||||
} else if (lhs_t->tag == BoolType && rhs_t->tag == AbortType) {
|
||||
} else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) {
|
||||
return lhs_t;
|
||||
} else if (lhs_t->tag == IntType && rhs_t->tag == IntType) {
|
||||
return get_math_type(env, ast, lhs_t, rhs_t);
|
||||
} else if (lhs_t->tag == PointerType) {
|
||||
auto lhs_ptr = Match(lhs_t, PointerType);
|
||||
if (rhs_t->tag == AbortType) {
|
||||
if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) {
|
||||
return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=false, .is_readonly=lhs_ptr->is_readonly);
|
||||
} else if (rhs_t->tag == PointerType) {
|
||||
auto rhs_ptr = Match(rhs_t, PointerType);
|
||||
@ -886,7 +890,7 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
if (!reduction->fallback)
|
||||
return t;
|
||||
type_t *fallback_t = get_type(env, reduction->fallback);
|
||||
if (fallback_t->tag == AbortType)
|
||||
if (fallback_t->tag == AbortType || fallback_t->tag == ReturnType)
|
||||
return t;
|
||||
else if (can_promote(fallback_t, t))
|
||||
return t;
|
||||
@ -922,26 +926,10 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
}
|
||||
REVERSE_LIST(args);
|
||||
|
||||
type_t *ret;
|
||||
auto block = Match(lambda->body, Block);
|
||||
if (!block->statements) {
|
||||
ret = Type(VoidType);
|
||||
} else {
|
||||
ast_list_t *last = block->statements;
|
||||
if (!last)
|
||||
return Type(VoidType);
|
||||
while (last->next)
|
||||
last = last->next;
|
||||
|
||||
env_t *block_env = fresh_scope(env);
|
||||
for (ast_list_t *stmt = block->statements; stmt; stmt = stmt->next)
|
||||
bind_statement(block_env, stmt->ast);
|
||||
|
||||
if (last->ast->tag == Return && Match(last->ast, Return)->value)
|
||||
ret = get_type(block_env, Match(last->ast, Return)->value);
|
||||
else
|
||||
ret = get_type(block_env, last->ast);
|
||||
}
|
||||
type_t *ret = get_type(scope, lambda->body);
|
||||
assert(ret);
|
||||
if (ret->tag == ReturnType)
|
||||
ret = Match(ret, ReturnType)->ret;
|
||||
|
||||
if (has_stack_memory(ret))
|
||||
code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame.");
|
||||
@ -1118,7 +1106,7 @@ bool is_discardable(env_t *env, ast_t *ast)
|
||||
default: break;
|
||||
}
|
||||
type_t *t = get_type(env, ast);
|
||||
return (t->tag == VoidType || t->tag == AbortType);
|
||||
return (t->tag == VoidType || t->tag == AbortType || t->tag == ReturnType);
|
||||
}
|
||||
|
||||
type_t *get_file_type(env_t *env, const char *path)
|
||||
|
11
types.c
11
types.c
@ -13,6 +13,7 @@ CORD type_to_cord(type_t *t) {
|
||||
switch (t->tag) {
|
||||
case UnknownType: return "???";
|
||||
case AbortType: return "Abort";
|
||||
case ReturnType: return CORD_all("Return(", type_to_cord(Match(t, ReturnType)->ret), ")");
|
||||
case VoidType: return "Void";
|
||||
case MemoryType: return "Memory";
|
||||
case BoolType: return "Bool";
|
||||
@ -126,8 +127,10 @@ type_t *type_or_type(type_t *a, type_t *b)
|
||||
if (!b) return a;
|
||||
if (type_is_a(b, a)) return a;
|
||||
if (type_is_a(a, b)) return b;
|
||||
if (a->tag == AbortType) return non_optional(b);
|
||||
if (b->tag == AbortType) return non_optional(a);
|
||||
if (a->tag == ReturnType && b->tag == ReturnType)
|
||||
return Type(ReturnType, .ret=type_or_type(Match(a, ReturnType)->ret, Match(b, ReturnType)->ret));
|
||||
if (a->tag == AbortType || a->tag == ReturnType) return non_optional(b);
|
||||
if (b->tag == AbortType || b->tag == ReturnType) return non_optional(a);
|
||||
if ((a->tag == IntType || a->tag == NumType) && (b->tag == IntType || b->tag == NumType)) {
|
||||
switch (compare_precision(a, b)) {
|
||||
case NUM_PRECISION_EQUAL: case NUM_PRECISION_MORE: return a;
|
||||
@ -396,7 +399,7 @@ type_t *replace_type(type_t *t, type_t *target, type_t *replacement)
|
||||
size_t type_size(type_t *t)
|
||||
{
|
||||
switch (t->tag) {
|
||||
case UnknownType: case AbortType: case VoidType: return 0;
|
||||
case UnknownType: case AbortType: case ReturnType: case VoidType: return 0;
|
||||
case MemoryType: errx(1, "Memory has undefined type size");
|
||||
case BoolType: return sizeof(bool);
|
||||
case CStringType: return sizeof(char*);
|
||||
@ -448,7 +451,7 @@ size_t type_size(type_t *t)
|
||||
size_t type_align(type_t *t)
|
||||
{
|
||||
switch (t->tag) {
|
||||
case UnknownType: case AbortType: case VoidType: return 0;
|
||||
case UnknownType: case AbortType: case ReturnType: case VoidType: return 0;
|
||||
case MemoryType: errx(1, "Memory has undefined type alignment");
|
||||
case BoolType: return __alignof__(bool);
|
||||
case CStringType: return __alignof__(char*);
|
||||
|
4
types.h
4
types.h
@ -39,6 +39,7 @@ struct type_s {
|
||||
enum {
|
||||
UnknownType,
|
||||
AbortType,
|
||||
ReturnType,
|
||||
VoidType,
|
||||
MemoryType,
|
||||
BoolType,
|
||||
@ -60,6 +61,9 @@ struct type_s {
|
||||
union {
|
||||
struct {
|
||||
} UnknownType, AbortType, VoidType, MemoryType, BoolType;
|
||||
struct {
|
||||
type_t *ret;
|
||||
} ReturnType;
|
||||
struct {
|
||||
int64_t bits;
|
||||
} IntType;
|
||||
|
Loading…
Reference in New Issue
Block a user