aboutsummaryrefslogtreecommitdiff
path: root/typecheck.c
diff options
context:
space:
mode:
Diffstat (limited to 'typecheck.c')
-rw-r--r--typecheck.c23
1 files changed, 17 insertions, 6 deletions
diff --git a/typecheck.c b/typecheck.c
index 86d66252..53946dc8 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -949,6 +949,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case When: {
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
+ type_t *overall_t = NULL;
if (subject_t->tag == PointerType) {
if (!Match(subject_t, PointerType)->is_optional)
code_err(when->subject, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t);
@@ -961,12 +962,25 @@ type_t *get_type(env_t *env, ast_t *ast)
if (handled_at)
code_err(clause->tag_name, "This 'when' statement has already handled the case of non-null pointers!");
handled_at = true;
+
+ assert(clause->args);
+ env_t *scope = fresh_scope(env);
+ auto ptr = Match(subject_t, PointerType);
+ set_binding(scope, Match(clause->args->ast, Var)->name,
+ new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly)));
+
+ type_t *clause_type = get_type(scope, clause->body);
+ type_t *merged = type_or_type(overall_t, clause_type);
+ if (!merged)
+ code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
+ clause_type, overall_t);
+ overall_t = merged;
}
if (!handled_at)
code_err(ast, "This 'when' statement doesn't handle non-null pointers");
if (!when->else_body)
code_err(ast, "This 'when' statement doesn't handle null pointers");
- return Type(VoidType);
+ return overall_t;
}
if (subject_t->tag != EnumType)
@@ -984,7 +998,6 @@ type_t *get_type(env_t *env, ast_t *ast)
for (tag_t *tag = tags; tag; tag = tag->next)
matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches);
- type_t *overall_t = NULL;
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
const char *tag_name = Match(clause->tag_name, Var)->name;
type_t *tag_type = NULL;
@@ -1051,8 +1064,7 @@ type_t *get_type(env_t *env, ast_t *ast)
code_err(when->else_body,
"I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
overall_t, else_t);
- // return merged;
- return Type(VoidType);
+ return merged;
} else {
CORD unhandled = CORD_EMPTY;
for (match_t *m = matches; m; m = m->next) {
@@ -1061,8 +1073,7 @@ type_t *get_type(env_t *env, ast_t *ast)
}
if (unhandled)
code_err(ast, "This 'when' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled));
- // return overall_t;
- return Type(VoidType);
+ return overall_t;
}
}