diff options
Diffstat (limited to 'typecheck.c')
| -rw-r--r-- | typecheck.c | 91 |
1 files changed, 61 insertions, 30 deletions
diff --git a/typecheck.c b/typecheck.c index 99a9a78a..1ab20177 100644 --- a/typecheck.c +++ b/typecheck.c @@ -484,13 +484,18 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name) return b->type; } -type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause) +env_t *when_clause_scope(env_t *env, type_t *subject_t, when_clause_t *clause) { - assert(subject_t->tag == EnumType); - tag_t * const tags = Match(subject_t, EnumType)->tags; + if (clause->pattern->tag == Var || subject_t->tag != EnumType) + return env; + + if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var) + code_err(clause->pattern, "I only support variables and constructors for pattern matching %T types in a 'when' block", subject_t); - const char *tag_name = Match(clause->tag_name, Var)->name; + auto fn = Match(clause->pattern, FunctionCall); + const char *tag_name = Match(fn->fn, Var)->name; type_t *tag_type = NULL; + tag_t * const tags = Match(subject_t, EnumType)->tags; for (tag_t *tag = tags; tag; tag = tag->next) { if (streq(tag->name, tag_name)) { tag_type = tag->type; @@ -499,29 +504,38 @@ type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause) } if (!tag_type) - code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t); + code_err(clause->pattern, "There is no tag '%s' for the type %T", tag_name, subject_t); - // Don't return early so we validate the tags - if (!clause->args) - return get_type(env, clause->body); + if (!fn->args) + return env; env_t *scope = fresh_scope(env); auto tag_struct = Match(tag_type, StructType); - if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) { - set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY); - } else { - ast_list_t *var = clause->args; - arg_t *field = tag_struct->fields; - while (var || field) { - if (!var) - code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); - if (!field) - code_err(var->ast, "This is one more field than %T has", subject_t); - set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY); - var = var->next; - field = field->next; - } - } + if (fn->args && !fn->args->next && tag_struct->fields && tag_struct->fields->next) { + if (fn->args->value->tag != Var) + code_err(fn->args->value, "I expected a variable here"); + set_binding(scope, Match(fn->args->value, Var)->name, tag_type, CORD_EMPTY); + return scope; + } + + arg_t *field = tag_struct->fields; + for (arg_ast_t *var = fn->args; var || field; var = var ? var->next : var) { + if (!var) + code_err(clause->pattern, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); + if (!field) + code_err(var->value, "This is one more field than %T has", subject_t); + if (var->value->tag != Var) + code_err(var->value, "I expected this to be a plain variable so I could bind it to a value"); + if (!streq(Match(var->value, Var)->name, "_")) + set_binding(scope, Match(var->value, Var)->name, field->type, CORD_EMPTY); + field = field->next; + } + return scope; +} + +type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause) +{ + env_t *scope = when_clause_scope(env, subject_t, clause); return get_type(scope, clause->body); } @@ -1258,10 +1272,19 @@ type_t *get_type(env_t *env, ast_t *ast) case When: { auto when = Match(ast, When); type_t *subject_t = get_type(env, when->subject); - type_t *overall_t = NULL; - if (subject_t->tag != EnumType) - code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t); + if (subject_t->tag != EnumType) { + type_t *t = NULL; + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + t = type_or_type(t, get_type(env, clause->body)); + } + if (when->else_body) + t = type_or_type(t, get_type(env, when->else_body)); + else if (t->tag != OptionalType) + t = Type(OptionalType, .type=t); + return t; + } + type_t *overall_t = NULL; tag_t * const tags = Match(subject_t, EnumType)->tags; typedef struct match_s { @@ -1274,12 +1297,19 @@ type_t *get_type(env_t *env, ast_t *ast) matches = new(match_t, .tag=tag, .handled=false, .next=matches); for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - const char *tag_name = Match(clause->tag_name, Var)->name; + const char *tag_name; + if (clause->pattern->tag == Var) + tag_name = Match(clause->pattern, Var)->name; + else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var) + tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name; + else + code_err(clause->pattern, "This is not a valid pattern for a %T enum", subject_t); + CORD valid_tags = CORD_EMPTY; for (match_t *m = matches; m; m = m->next) { if (streq(m->tag->name, tag_name)) { if (m->handled) - code_err(clause->tag_name, "This tag was already handled earlier"); + code_err(clause->pattern, "This tag was already handled earlier"); m->handled = true; goto found_matching_tag; } @@ -1287,13 +1317,14 @@ type_t *get_type(env_t *env, ast_t *ast) valid_tags = CORD_cat(valid_tags, m->tag->name); } - code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)", + code_err(clause->pattern, "There is no tag '%s' for the type %T (valid tags: %s)", tag_name, subject_t, CORD_to_char_star(valid_tags)); found_matching_tag:; } for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - type_t *clause_type = get_clause_type(env, subject_t, clause); + env_t *clause_scope = when_clause_scope(env, subject_t, clause); + type_t *clause_type = get_type(clause_scope, clause->body); type_t *merged = type_or_type(overall_t, clause_type); if (!merged) code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", |
