diff options
Diffstat (limited to 'typecheck.c')
| -rw-r--r-- | typecheck.c | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/typecheck.c b/typecheck.c index ee4d024e..781b6ffc 100644 --- a/typecheck.c +++ b/typecheck.c @@ -534,6 +534,86 @@ type_t *get_type(env_t *env, ast_t *ast) } } + case When: { + auto when = Match(ast, When); + type_t *subject_t = get_type(env, when->subject); + if (subject_t->tag != EnumType) + code_err(when->subject, "'when' statements are only for enum types, not %T", subject_t); + + tag_t * const tags = Match(subject_t, EnumType)->tags; + + typedef struct match_s { + const char *name; + type_t *type; + bool handled; + struct match_s *next; + } match_t; + match_t *matches = NULL; + for (tag_t *tag = tags; tag; tag = tag->next) + matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches); + + type_t *overall_t = NULL; + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + const char *tag_name = Match(clause->tag_name, Var)->name; + type_t *tag_type = NULL; + for (match_t *m = matches; m; m = m->next) { + if (streq(m->name, tag_name)) { + if (m->handled) + code_err(clause->tag_name, "This tag was already handled earlier"); + m->handled = true; + tag_type = m->type; + break; + } + } + + if (!tag_type) + code_err(clause->tag_name, "This is not a valid tag for the type %T", subject_t); + + env_t *scope = env; + if (clause->var) { + scope = fresh_scope(scope); + set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type)); + } + type_t *clause_type = get_type(scope, clause->body); + type_t *merged = type_or_type(overall_t, clause_type); + if (!merged) + code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", + clause_type, overall_t); + overall_t = merged; + } + + if (when->else_body) { + bool any_unhandled = false; + for (match_t *m = matches; m; m = m->next) { + if (!m->handled) { + any_unhandled = true; + break; + } + } + if (!any_unhandled) + code_err(when->else_body, "This 'else' block will never run because every tag is handled"); + + type_t *else_t = get_type(env, when->else_body); + type_t *merged = type_or_type(overall_t, else_t); + if (!merged) + code_err(when->else_body, + "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.", + overall_t, else_t); + // return merged; + return Type(VoidType); + } else { + CORD unhandled = CORD_EMPTY; + for (match_t *m = matches; m; m = m->next) { + if (!m->handled) + unhandled = unhandled ? CORD_all(unhandled, ", ", m->name) : m->name; + } + if (unhandled) + code_err(ast, "This 'while' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled)); + // return overall_t; + return Type(VoidType); + } + } + case While: case For: return Type(VoidType); case Unknown: code_err(ast, "I can't figure out the type of: %W", ast); |
