aboutsummaryrefslogtreecommitdiff
path: root/typecheck.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-09-06 14:15:55 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-09-06 14:15:55 -0400
commit11fa4f548ca71baa96a9dba4aa9a24051de265d3 (patch)
tree950a5142db5c5ec211229109d1fd7a1f04f14f4f /typecheck.c
parent7239ec4083128cc002ad8bd16824824d71b20116 (diff)
Support 'when' statements as expressions
Diffstat (limited to 'typecheck.c')
-rw-r--r--typecheck.c210
1 files changed, 121 insertions, 89 deletions
diff --git a/typecheck.c b/typecheck.c
index 519d83d1..bde73ea5 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -433,6 +433,64 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name)
return b->type;
}
+type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause)
+{
+ if (subject_t->tag == PointerType) {
+ if (!Match(subject_t, PointerType)->is_optional)
+ code_err(clause->body, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t);
+
+ const char *tag_name = Match(clause->tag_name, Var)->name;
+ if (!streq(tag_name, "@"))
+ code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name);
+
+ assert(clause->args);
+ env_t *scope = fresh_scope(env);
+ auto ptr = Match(subject_t, PointerType);
+ set_binding(scope, Match(clause->args->ast, Var)->name,
+ new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly)));
+
+ return get_type(scope, clause->body);
+ } else {
+ assert(subject_t->tag == EnumType);
+ tag_t * const tags = Match(subject_t, EnumType)->tags;
+
+ const char *tag_name = Match(clause->tag_name, Var)->name;
+ type_t *tag_type = NULL;
+ for (tag_t *tag = tags; tag; tag = tag->next) {
+ if (streq(tag->name, tag_name)) {
+ tag_type = tag->type;
+ break;
+ }
+ }
+
+ if (!tag_type)
+ code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t);
+
+ // Don't return early so we validate the tags
+ if (!clause->args)
+ return get_type(env, clause->body);
+
+ env_t *scope = fresh_scope(env);
+ auto tag_struct = Match(tag_type, StructType);
+ if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) {
+ set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
+ } else {
+ ast_list_t *var = clause->args;
+ arg_t *field = tag_struct->fields;
+ while (var || field) {
+ if (!var)
+ code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
+ if (!field)
+ code_err(var->ast, "This is one more field than %T has", subject_t);
+ set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
+ var = var->next;
+ field = field->next;
+ }
+ }
+ return get_type(scope, clause->body);
+ }
+}
+
type_t *get_type(env_t *env, ast_t *ast)
{
if (!ast) return NULL;
@@ -1145,12 +1203,7 @@ type_t *get_type(env_t *env, ast_t *ast)
handled_at = true;
assert(clause->args);
- env_t *scope = fresh_scope(env);
- auto ptr = Match(subject_t, PointerType);
- set_binding(scope, Match(clause->args->ast, Var)->name,
- new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly)));
-
- type_t *clause_type = get_type(scope, clause->body);
+ type_t *clause_type = get_clause_type(env, subject_t, clause);
type_t *merged = type_or_type(overall_t, clause_type);
if (!merged)
code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
@@ -1162,99 +1215,78 @@ type_t *get_type(env_t *env, ast_t *ast)
if (!when->else_body)
code_err(ast, "This 'when' statement doesn't handle null pointers");
return overall_t;
- }
-
- if (subject_t->tag != EnumType)
- code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
+ } else if (subject_t->tag == EnumType) {
+ tag_t * const tags = Match(subject_t, EnumType)->tags;
+
+ typedef struct match_s {
+ tag_t *tag;
+ bool handled;
+ struct match_s *next;
+ } match_t;
+ match_t *matches = NULL;
+ for (tag_t *tag = tags; tag; tag = tag->next)
+ matches = new(match_t, .tag=tag, .handled=false, .next=matches);
- tag_t * const tags = Match(subject_t, EnumType)->tags;
-
- typedef struct match_s {
- const char *name;
- type_t *type;
- bool handled;
- struct match_s *next;
- } match_t;
- match_t *matches = NULL;
- for (tag_t *tag = tags; tag; tag = tag->next)
- matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches);
-
- for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- const char *tag_name = Match(clause->tag_name, Var)->name;
- type_t *tag_type = NULL;
- CORD valid_tags = CORD_EMPTY;
- for (match_t *m = matches; m; m = m->next) {
- if (streq(m->name, tag_name)) {
- if (m->handled)
- code_err(clause->tag_name, "This tag was already handled earlier");
- m->handled = true;
- tag_type = m->type;
- break;
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ const char *tag_name = Match(clause->tag_name, Var)->name;
+ CORD valid_tags = CORD_EMPTY;
+ for (match_t *m = matches; m; m = m->next) {
+ if (streq(m->tag->name, tag_name)) {
+ if (m->handled)
+ code_err(clause->tag_name, "This tag was already handled earlier");
+ m->handled = true;
+ goto found_matching_tag;
+ }
+ if (valid_tags) valid_tags = CORD_cat(valid_tags, ", ");
+ valid_tags = CORD_cat(valid_tags, m->tag->name);
}
- if (valid_tags) valid_tags = CORD_cat(valid_tags, ", ");
- valid_tags = CORD_cat(valid_tags, m->name);
- }
- if (!tag_type)
code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)",
tag_name, subject_t, CORD_to_char_star(valid_tags));
+ found_matching_tag:;
+ }
- env_t *scope = env;
- auto tag_struct = Match(tag_type, StructType);
- if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
- scope = fresh_scope(scope);
- set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
- } else if (clause->args) {
- scope = fresh_scope(scope);
- ast_list_t *var = clause->args;
- arg_t *field = tag_struct->fields;
- while (var || field) {
- if (!var)
- code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
- if (!field)
- code_err(var->ast, "This is one more field than %T has", subject_t);
- set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
- var = var->next;
- field = field->next;
- }
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ type_t *clause_type = get_clause_type(env, subject_t, clause);
+ type_t *merged = type_or_type(overall_t, clause_type);
+ if (!merged)
+ code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
+ clause_type, overall_t);
+ overall_t = merged;
}
- type_t *clause_type = get_type(scope, clause->body);
- type_t *merged = type_or_type(overall_t, clause_type);
- if (!merged)
- code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
- clause_type, overall_t);
- overall_t = merged;
- }
-
- if (when->else_body) {
- bool any_unhandled = false;
- for (match_t *m = matches; m; m = m->next) {
- if (!m->handled) {
- any_unhandled = true;
- break;
+
+ if (when->else_body) {
+ bool any_unhandled = false;
+ for (match_t *m = matches; m; m = m->next) {
+ if (!m->handled) {
+ any_unhandled = true;
+ break;
+ }
+ }
+ // HACK: `while when ...` is handled by the parser adding an implicit
+ // `else: stop`, which has an empty source code span.
+ if (!any_unhandled && when->else_body->end > when->else_body->start)
+ code_err(when->else_body, "This 'else' block will never run because every tag is handled");
+
+ type_t *else_t = get_type(env, when->else_body);
+ type_t *merged = type_or_type(overall_t, else_t);
+ if (!merged)
+ code_err(when->else_body,
+ "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
+ overall_t, else_t);
+ return merged;
+ } else {
+ CORD unhandled = CORD_EMPTY;
+ for (match_t *m = matches; m; m = m->next) {
+ if (!m->handled)
+ unhandled = unhandled ? CORD_all(unhandled, ", ", m->tag->name) : m->tag->name;
}
+ if (unhandled)
+ code_err(ast, "This 'when' statement doesn't handle the tags: %s", CORD_to_const_char_star(unhandled));
+ return overall_t;
}
- // HACK: `while when ...` is handled by the parser adding an implicit
- // `else: stop`, which has an empty source code span.
- if (!any_unhandled && when->else_body->end > when->else_body->start)
- code_err(when->else_body, "This 'else' block will never run because every tag is handled");
-
- type_t *else_t = get_type(env, when->else_body);
- type_t *merged = type_or_type(overall_t, else_t);
- if (!merged)
- code_err(when->else_body,
- "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
- overall_t, else_t);
- return merged;
} else {
- CORD unhandled = CORD_EMPTY;
- for (match_t *m = matches; m; m = m->next) {
- if (!m->handled)
- unhandled = unhandled ? CORD_all(unhandled, ", ", m->name) : m->name;
- }
- if (unhandled)
- code_err(ast, "This 'when' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled));
- return overall_t;
+ code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
}
}