aboutsummaryrefslogtreecommitdiff
path: root/typecheck.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2025-03-06 18:37:08 -0500
committerBruce Hill <bruce@bruce-hill.com>2025-03-06 18:37:08 -0500
commit73e559fbe4182828742ac1b1d108bcdc42bc46d6 (patch)
tree22c4d6245114cf3ba72701dafe18b84ce72e3f66 /typecheck.c
parent09423f6d42d86c62beefa4607fba41e3698f1850 (diff)
Support 'when' for literal values with equality checking
Diffstat (limited to 'typecheck.c')
-rw-r--r--typecheck.c91
1 files changed, 61 insertions, 30 deletions
diff --git a/typecheck.c b/typecheck.c
index 99a9a78a..1ab20177 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -484,13 +484,18 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name)
return b->type;
}
-type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause)
+env_t *when_clause_scope(env_t *env, type_t *subject_t, when_clause_t *clause)
{
- assert(subject_t->tag == EnumType);
- tag_t * const tags = Match(subject_t, EnumType)->tags;
+ if (clause->pattern->tag == Var || subject_t->tag != EnumType)
+ return env;
+
+ if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var)
+ code_err(clause->pattern, "I only support variables and constructors for pattern matching %T types in a 'when' block", subject_t);
- const char *tag_name = Match(clause->tag_name, Var)->name;
+ auto fn = Match(clause->pattern, FunctionCall);
+ const char *tag_name = Match(fn->fn, Var)->name;
type_t *tag_type = NULL;
+ tag_t * const tags = Match(subject_t, EnumType)->tags;
for (tag_t *tag = tags; tag; tag = tag->next) {
if (streq(tag->name, tag_name)) {
tag_type = tag->type;
@@ -499,29 +504,38 @@ type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause)
}
if (!tag_type)
- code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t);
+ code_err(clause->pattern, "There is no tag '%s' for the type %T", tag_name, subject_t);
- // Don't return early so we validate the tags
- if (!clause->args)
- return get_type(env, clause->body);
+ if (!fn->args)
+ return env;
env_t *scope = fresh_scope(env);
auto tag_struct = Match(tag_type, StructType);
- if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) {
- set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY);
- } else {
- ast_list_t *var = clause->args;
- arg_t *field = tag_struct->fields;
- while (var || field) {
- if (!var)
- code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
- if (!field)
- code_err(var->ast, "This is one more field than %T has", subject_t);
- set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY);
- var = var->next;
- field = field->next;
- }
- }
+ if (fn->args && !fn->args->next && tag_struct->fields && tag_struct->fields->next) {
+ if (fn->args->value->tag != Var)
+ code_err(fn->args->value, "I expected a variable here");
+ set_binding(scope, Match(fn->args->value, Var)->name, tag_type, CORD_EMPTY);
+ return scope;
+ }
+
+ arg_t *field = tag_struct->fields;
+ for (arg_ast_t *var = fn->args; var || field; var = var ? var->next : var) {
+ if (!var)
+ code_err(clause->pattern, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
+ if (!field)
+ code_err(var->value, "This is one more field than %T has", subject_t);
+ if (var->value->tag != Var)
+ code_err(var->value, "I expected this to be a plain variable so I could bind it to a value");
+ if (!streq(Match(var->value, Var)->name, "_"))
+ set_binding(scope, Match(var->value, Var)->name, field->type, CORD_EMPTY);
+ field = field->next;
+ }
+ return scope;
+}
+
+type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause)
+{
+ env_t *scope = when_clause_scope(env, subject_t, clause);
return get_type(scope, clause->body);
}
@@ -1258,10 +1272,19 @@ type_t *get_type(env_t *env, ast_t *ast)
case When: {
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
- type_t *overall_t = NULL;
- if (subject_t->tag != EnumType)
- code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
+ if (subject_t->tag != EnumType) {
+ type_t *t = NULL;
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ t = type_or_type(t, get_type(env, clause->body));
+ }
+ if (when->else_body)
+ t = type_or_type(t, get_type(env, when->else_body));
+ else if (t->tag != OptionalType)
+ t = Type(OptionalType, .type=t);
+ return t;
+ }
+ type_t *overall_t = NULL;
tag_t * const tags = Match(subject_t, EnumType)->tags;
typedef struct match_s {
@@ -1274,12 +1297,19 @@ type_t *get_type(env_t *env, ast_t *ast)
matches = new(match_t, .tag=tag, .handled=false, .next=matches);
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- const char *tag_name = Match(clause->tag_name, Var)->name;
+ const char *tag_name;
+ if (clause->pattern->tag == Var)
+ tag_name = Match(clause->pattern, Var)->name;
+ else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var)
+ tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name;
+ else
+ code_err(clause->pattern, "This is not a valid pattern for a %T enum", subject_t);
+
CORD valid_tags = CORD_EMPTY;
for (match_t *m = matches; m; m = m->next) {
if (streq(m->tag->name, tag_name)) {
if (m->handled)
- code_err(clause->tag_name, "This tag was already handled earlier");
+ code_err(clause->pattern, "This tag was already handled earlier");
m->handled = true;
goto found_matching_tag;
}
@@ -1287,13 +1317,14 @@ type_t *get_type(env_t *env, ast_t *ast)
valid_tags = CORD_cat(valid_tags, m->tag->name);
}
- code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)",
+ code_err(clause->pattern, "There is no tag '%s' for the type %T (valid tags: %s)",
tag_name, subject_t, CORD_to_char_star(valid_tags));
found_matching_tag:;
}
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- type_t *clause_type = get_clause_type(env, subject_t, clause);
+ env_t *clause_scope = when_clause_scope(env, subject_t, clause);
+ type_t *clause_type = get_type(clause_scope, clause->body);
type_t *merged = type_or_type(overall_t, clause_type);
if (!merged)
code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",