diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-09-06 14:15:55 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-09-06 14:15:55 -0400 |
| commit | 11fa4f548ca71baa96a9dba4aa9a24051de265d3 (patch) | |
| tree | 950a5142db5c5ec211229109d1fd7a1f04f14f4f /typecheck.c | |
| parent | 7239ec4083128cc002ad8bd16824824d71b20116 (diff) | |
Support 'when' statements as expressions
Diffstat (limited to 'typecheck.c')
| -rw-r--r-- | typecheck.c | 210 |
1 files changed, 121 insertions, 89 deletions
diff --git a/typecheck.c b/typecheck.c index 519d83d1..bde73ea5 100644 --- a/typecheck.c +++ b/typecheck.c @@ -433,6 +433,64 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name) return b->type; } +type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause) +{ + if (subject_t->tag == PointerType) { + if (!Match(subject_t, PointerType)->is_optional) + code_err(clause->body, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t); + + const char *tag_name = Match(clause->tag_name, Var)->name; + if (!streq(tag_name, "@")) + code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name); + + assert(clause->args); + env_t *scope = fresh_scope(env); + auto ptr = Match(subject_t, PointerType); + set_binding(scope, Match(clause->args->ast, Var)->name, + new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly))); + + return get_type(scope, clause->body); + } else { + assert(subject_t->tag == EnumType); + tag_t * const tags = Match(subject_t, EnumType)->tags; + + const char *tag_name = Match(clause->tag_name, Var)->name; + type_t *tag_type = NULL; + for (tag_t *tag = tags; tag; tag = tag->next) { + if (streq(tag->name, tag_name)) { + tag_type = tag->type; + break; + } + } + + if (!tag_type) + code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t); + + // Don't return early so we validate the tags + if (!clause->args) + return get_type(env, clause->body); + + env_t *scope = fresh_scope(env); + auto tag_struct = Match(tag_type, StructType); + if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) { + set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type)); + } else { + ast_list_t *var = clause->args; + arg_t *field = tag_struct->fields; + while (var || field) { + if (!var) + code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); + if (!field) + code_err(var->ast, "This is one more field than %T has", subject_t); + set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type)); + var = var->next; + field = field->next; + } + } + return get_type(scope, clause->body); + } +} + type_t *get_type(env_t *env, ast_t *ast) { if (!ast) return NULL; @@ -1145,12 +1203,7 @@ type_t *get_type(env_t *env, ast_t *ast) handled_at = true; assert(clause->args); - env_t *scope = fresh_scope(env); - auto ptr = Match(subject_t, PointerType); - set_binding(scope, Match(clause->args->ast, Var)->name, - new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly))); - - type_t *clause_type = get_type(scope, clause->body); + type_t *clause_type = get_clause_type(env, subject_t, clause); type_t *merged = type_or_type(overall_t, clause_type); if (!merged) code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", @@ -1162,99 +1215,78 @@ type_t *get_type(env_t *env, ast_t *ast) if (!when->else_body) code_err(ast, "This 'when' statement doesn't handle null pointers"); return overall_t; - } - - if (subject_t->tag != EnumType) - code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t); + } else if (subject_t->tag == EnumType) { + tag_t * const tags = Match(subject_t, EnumType)->tags; + + typedef struct match_s { + tag_t *tag; + bool handled; + struct match_s *next; + } match_t; + match_t *matches = NULL; + for (tag_t *tag = tags; tag; tag = tag->next) + matches = new(match_t, .tag=tag, .handled=false, .next=matches); - tag_t * const tags = Match(subject_t, EnumType)->tags; - - typedef struct match_s { - const char *name; - type_t *type; - bool handled; - struct match_s *next; - } match_t; - match_t *matches = NULL; - for (tag_t *tag = tags; tag; tag = tag->next) - matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches); - - for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - const char *tag_name = Match(clause->tag_name, Var)->name; - type_t *tag_type = NULL; - CORD valid_tags = CORD_EMPTY; - for (match_t *m = matches; m; m = m->next) { - if (streq(m->name, tag_name)) { - if (m->handled) - code_err(clause->tag_name, "This tag was already handled earlier"); - m->handled = true; - tag_type = m->type; - break; + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + const char *tag_name = Match(clause->tag_name, Var)->name; + CORD valid_tags = CORD_EMPTY; + for (match_t *m = matches; m; m = m->next) { + if (streq(m->tag->name, tag_name)) { + if (m->handled) + code_err(clause->tag_name, "This tag was already handled earlier"); + m->handled = true; + goto found_matching_tag; + } + if (valid_tags) valid_tags = CORD_cat(valid_tags, ", "); + valid_tags = CORD_cat(valid_tags, m->tag->name); } - if (valid_tags) valid_tags = CORD_cat(valid_tags, ", "); - valid_tags = CORD_cat(valid_tags, m->name); - } - if (!tag_type) code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)", tag_name, subject_t, CORD_to_char_star(valid_tags)); + found_matching_tag:; + } - env_t *scope = env; - auto tag_struct = Match(tag_type, StructType); - if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) { - scope = fresh_scope(scope); - set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type)); - } else if (clause->args) { - scope = fresh_scope(scope); - ast_list_t *var = clause->args; - arg_t *field = tag_struct->fields; - while (var || field) { - if (!var) - code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); - if (!field) - code_err(var->ast, "This is one more field than %T has", subject_t); - set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type)); - var = var->next; - field = field->next; - } + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + type_t *clause_type = get_clause_type(env, subject_t, clause); + type_t *merged = type_or_type(overall_t, clause_type); + if (!merged) + code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", + clause_type, overall_t); + overall_t = merged; } - type_t *clause_type = get_type(scope, clause->body); - type_t *merged = type_or_type(overall_t, clause_type); - if (!merged) - code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", - clause_type, overall_t); - overall_t = merged; - } - - if (when->else_body) { - bool any_unhandled = false; - for (match_t *m = matches; m; m = m->next) { - if (!m->handled) { - any_unhandled = true; - break; + + if (when->else_body) { + bool any_unhandled = false; + for (match_t *m = matches; m; m = m->next) { + if (!m->handled) { + any_unhandled = true; + break; + } + } + // HACK: `while when ...` is handled by the parser adding an implicit + // `else: stop`, which has an empty source code span. + if (!any_unhandled && when->else_body->end > when->else_body->start) + code_err(when->else_body, "This 'else' block will never run because every tag is handled"); + + type_t *else_t = get_type(env, when->else_body); + type_t *merged = type_or_type(overall_t, else_t); + if (!merged) + code_err(when->else_body, + "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.", + overall_t, else_t); + return merged; + } else { + CORD unhandled = CORD_EMPTY; + for (match_t *m = matches; m; m = m->next) { + if (!m->handled) + unhandled = unhandled ? CORD_all(unhandled, ", ", m->tag->name) : m->tag->name; } + if (unhandled) + code_err(ast, "This 'when' statement doesn't handle the tags: %s", CORD_to_const_char_star(unhandled)); + return overall_t; } - // HACK: `while when ...` is handled by the parser adding an implicit - // `else: stop`, which has an empty source code span. - if (!any_unhandled && when->else_body->end > when->else_body->start) - code_err(when->else_body, "This 'else' block will never run because every tag is handled"); - - type_t *else_t = get_type(env, when->else_body); - type_t *merged = type_or_type(overall_t, else_t); - if (!merged) - code_err(when->else_body, - "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.", - overall_t, else_t); - return merged; } else { - CORD unhandled = CORD_EMPTY; - for (match_t *m = matches; m; m = m->next) { - if (!m->handled) - unhandled = unhandled ? CORD_all(unhandled, ", ", m->name) : m->name; - } - if (unhandled) - code_err(ast, "This 'when' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled)); - return overall_t; + code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t); } } |
