aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-07-04 18:27:08 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-07-04 18:27:08 -0400
commitf4dee58f03774d033c55d890356cd93c3e2462fb (patch)
tree55ae39786762084625051a17b94dd99b5970f6ab
parent2c89f3385f0863c83267b50400832a81de07538c (diff)
Check for functions that don't return when they need to, as well as a
fix for 'when' statement typing
-rw-r--r--compile.c9
-rw-r--r--typecheck.c23
2 files changed, 25 insertions, 7 deletions
diff --git a/compile.c b/compile.c
index e9648c3a..09314c82 100644
--- a/compile.c
+++ b/compile.c
@@ -597,6 +597,10 @@ CORD compile_statement(env_t *env, ast_t *ast)
};
body_scope->fn_ctx = &fn_ctx;
+
+ if (ret_t->tag != VoidType && ret_t->tag != AbortType && get_type(body_scope, fndef->body)->tag != AbortType)
+ code_err(ast, "This function can reach the end without returning a %T value!", ret_t);
+
CORD body = compile_statement(body_scope, fndef->body);
if (CORD_fetch(body, 0) != '{')
body = CORD_asprintf("{\n%r\n}", body);
@@ -733,6 +737,9 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
if (ret) {
+ if (env->fn_ctx->return_type->tag == VoidType || env->fn_ctx->return_type->tag == AbortType)
+ code_err(ast, "This function is not supposed to return any values, according to its type signature");
+
env = with_enum_scope(env, env->fn_ctx->return_type);
type_t *ret_t = get_type(env, ret);
CORD value = compile(env, ret);
@@ -742,7 +749,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
return CORD_all(code, "return ", value, ";");
} else {
if (env->fn_ctx->return_type->tag != VoidType)
- code_err(ast, "This function expects a return value of type %T", env->fn_ctx->return_type->tag);
+ code_err(ast, "This function expects you to return a %T value", env->fn_ctx->return_type);
return CORD_all(code, "return;");
}
}
diff --git a/typecheck.c b/typecheck.c
index 86d66252..53946dc8 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -949,6 +949,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case When: {
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
+ type_t *overall_t = NULL;
if (subject_t->tag == PointerType) {
if (!Match(subject_t, PointerType)->is_optional)
code_err(when->subject, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t);
@@ -961,12 +962,25 @@ type_t *get_type(env_t *env, ast_t *ast)
if (handled_at)
code_err(clause->tag_name, "This 'when' statement has already handled the case of non-null pointers!");
handled_at = true;
+
+ assert(clause->args);
+ env_t *scope = fresh_scope(env);
+ auto ptr = Match(subject_t, PointerType);
+ set_binding(scope, Match(clause->args->ast, Var)->name,
+ new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly)));
+
+ type_t *clause_type = get_type(scope, clause->body);
+ type_t *merged = type_or_type(overall_t, clause_type);
+ if (!merged)
+ code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
+ clause_type, overall_t);
+ overall_t = merged;
}
if (!handled_at)
code_err(ast, "This 'when' statement doesn't handle non-null pointers");
if (!when->else_body)
code_err(ast, "This 'when' statement doesn't handle null pointers");
- return Type(VoidType);
+ return overall_t;
}
if (subject_t->tag != EnumType)
@@ -984,7 +998,6 @@ type_t *get_type(env_t *env, ast_t *ast)
for (tag_t *tag = tags; tag; tag = tag->next)
matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches);
- type_t *overall_t = NULL;
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
const char *tag_name = Match(clause->tag_name, Var)->name;
type_t *tag_type = NULL;
@@ -1051,8 +1064,7 @@ type_t *get_type(env_t *env, ast_t *ast)
code_err(when->else_body,
"I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
overall_t, else_t);
- // return merged;
- return Type(VoidType);
+ return merged;
} else {
CORD unhandled = CORD_EMPTY;
for (match_t *m = matches; m; m = m->next) {
@@ -1061,8 +1073,7 @@ type_t *get_type(env_t *env, ast_t *ast)
}
if (unhandled)
code_err(ast, "This 'when' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled));
- // return overall_t;
- return Type(VoidType);
+ return overall_t;
}
}