aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compile.c37
-rw-r--r--test/enums.tm8
-rw-r--r--typecheck.c210
-rw-r--r--typecheck.h1
4 files changed, 165 insertions, 91 deletions
diff --git a/compile.c b/compile.c
index 0be3aa89..9de6c979 100644
--- a/compile.c
+++ b/compile.c
@@ -2657,8 +2657,41 @@ CORD compile(env_t *env, ast_t *ast)
code_err(call->fn, "This is not a function, it's a %T", fn_t);
}
}
- case When:
- code_err(ast, "'when' expressions are not yet implemented");
+ case When: {
+ auto original = Match(ast, When);
+ ast_t *when_var = WrapAST(ast, Var, .name="when");
+ when_clause_t *new_clauses = NULL;
+ type_t *subject_t = get_type(env, original->subject);
+ for (when_clause_t *clause = original->clauses; clause; clause = clause->next) {
+ type_t *clause_type = get_clause_type(env, subject_t, clause);
+ if (clause_type->tag == AbortType || clause_type->tag == ReturnType) {
+ new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=clause->body, .next=new_clauses);
+ } else {
+ ast_t *assign = WrapAST(clause->body, Assign,
+ .targets=new(ast_list_t, .ast=when_var),
+ .values=new(ast_list_t, .ast=clause->body));
+ new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=assign, .next=new_clauses);
+ }
+ }
+ REVERSE_LIST(new_clauses);
+ ast_t *else_body = original->else_body;
+ if (else_body) {
+ type_t *clause_type = get_type(env, else_body);
+ if (clause_type->tag != AbortType && clause_type->tag != ReturnType) {
+ else_body = WrapAST(else_body, Assign,
+ .targets=new(ast_list_t, .ast=when_var),
+ .values=new(ast_list_t, .ast=else_body));
+ }
+ }
+
+ type_t *t = get_type(env, ast);
+ env_t *when_env = fresh_scope(env);
+ set_binding(when_env, "when", new(binding_t, .type=t, .code="when"));
+ return CORD_all(
+ "({ ", compile_declaration(t, "when"), ";\n",
+ compile_statement(when_env, WrapAST(ast, When, .subject=original->subject, .clauses=new_clauses, .else_body=else_body)),
+ "when; })");
+ }
case If: {
auto if_ = Match(ast, If);
if (!if_->else_body)
diff --git a/test/enums.tm b/test/enums.tm
index f4af342d..98811408 100644
--- a/test/enums.tm
+++ b/test/enums.tm
@@ -66,4 +66,12 @@ func main():
while when cases[i] is One(x):
>> x
i += 1
+
+ >> expr := when cases[1] is One(y):
+ y + 1
+ else:
+ -1
+ = 2
+
+
diff --git a/typecheck.c b/typecheck.c
index 519d83d1..bde73ea5 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -433,6 +433,64 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name)
return b->type;
}
+type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause)
+{
+ if (subject_t->tag == PointerType) {
+ if (!Match(subject_t, PointerType)->is_optional)
+ code_err(clause->body, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t);
+
+ const char *tag_name = Match(clause->tag_name, Var)->name;
+ if (!streq(tag_name, "@"))
+ code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name);
+
+ assert(clause->args);
+ env_t *scope = fresh_scope(env);
+ auto ptr = Match(subject_t, PointerType);
+ set_binding(scope, Match(clause->args->ast, Var)->name,
+ new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly)));
+
+ return get_type(scope, clause->body);
+ } else {
+ assert(subject_t->tag == EnumType);
+ tag_t * const tags = Match(subject_t, EnumType)->tags;
+
+ const char *tag_name = Match(clause->tag_name, Var)->name;
+ type_t *tag_type = NULL;
+ for (tag_t *tag = tags; tag; tag = tag->next) {
+ if (streq(tag->name, tag_name)) {
+ tag_type = tag->type;
+ break;
+ }
+ }
+
+ if (!tag_type)
+ code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t);
+
+ // Don't return early so we validate the tags
+ if (!clause->args)
+ return get_type(env, clause->body);
+
+ env_t *scope = fresh_scope(env);
+ auto tag_struct = Match(tag_type, StructType);
+ if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) {
+ set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
+ } else {
+ ast_list_t *var = clause->args;
+ arg_t *field = tag_struct->fields;
+ while (var || field) {
+ if (!var)
+ code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
+ if (!field)
+ code_err(var->ast, "This is one more field than %T has", subject_t);
+ set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
+ var = var->next;
+ field = field->next;
+ }
+ }
+ return get_type(scope, clause->body);
+ }
+}
+
type_t *get_type(env_t *env, ast_t *ast)
{
if (!ast) return NULL;
@@ -1145,12 +1203,7 @@ type_t *get_type(env_t *env, ast_t *ast)
handled_at = true;
assert(clause->args);
- env_t *scope = fresh_scope(env);
- auto ptr = Match(subject_t, PointerType);
- set_binding(scope, Match(clause->args->ast, Var)->name,
- new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly)));
-
- type_t *clause_type = get_type(scope, clause->body);
+ type_t *clause_type = get_clause_type(env, subject_t, clause);
type_t *merged = type_or_type(overall_t, clause_type);
if (!merged)
code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
@@ -1162,99 +1215,78 @@ type_t *get_type(env_t *env, ast_t *ast)
if (!when->else_body)
code_err(ast, "This 'when' statement doesn't handle null pointers");
return overall_t;
- }
-
- if (subject_t->tag != EnumType)
- code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
+ } else if (subject_t->tag == EnumType) {
+ tag_t * const tags = Match(subject_t, EnumType)->tags;
+
+ typedef struct match_s {
+ tag_t *tag;
+ bool handled;
+ struct match_s *next;
+ } match_t;
+ match_t *matches = NULL;
+ for (tag_t *tag = tags; tag; tag = tag->next)
+ matches = new(match_t, .tag=tag, .handled=false, .next=matches);
- tag_t * const tags = Match(subject_t, EnumType)->tags;
-
- typedef struct match_s {
- const char *name;
- type_t *type;
- bool handled;
- struct match_s *next;
- } match_t;
- match_t *matches = NULL;
- for (tag_t *tag = tags; tag; tag = tag->next)
- matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches);
-
- for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- const char *tag_name = Match(clause->tag_name, Var)->name;
- type_t *tag_type = NULL;
- CORD valid_tags = CORD_EMPTY;
- for (match_t *m = matches; m; m = m->next) {
- if (streq(m->name, tag_name)) {
- if (m->handled)
- code_err(clause->tag_name, "This tag was already handled earlier");
- m->handled = true;
- tag_type = m->type;
- break;
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ const char *tag_name = Match(clause->tag_name, Var)->name;
+ CORD valid_tags = CORD_EMPTY;
+ for (match_t *m = matches; m; m = m->next) {
+ if (streq(m->tag->name, tag_name)) {
+ if (m->handled)
+ code_err(clause->tag_name, "This tag was already handled earlier");
+ m->handled = true;
+ goto found_matching_tag;
+ }
+ if (valid_tags) valid_tags = CORD_cat(valid_tags, ", ");
+ valid_tags = CORD_cat(valid_tags, m->tag->name);
}
- if (valid_tags) valid_tags = CORD_cat(valid_tags, ", ");
- valid_tags = CORD_cat(valid_tags, m->name);
- }
- if (!tag_type)
code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)",
tag_name, subject_t, CORD_to_char_star(valid_tags));
+ found_matching_tag:;
+ }
- env_t *scope = env;
- auto tag_struct = Match(tag_type, StructType);
- if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
- scope = fresh_scope(scope);
- set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
- } else if (clause->args) {
- scope = fresh_scope(scope);
- ast_list_t *var = clause->args;
- arg_t *field = tag_struct->fields;
- while (var || field) {
- if (!var)
- code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
- if (!field)
- code_err(var->ast, "This is one more field than %T has", subject_t);
- set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
- var = var->next;
- field = field->next;
- }
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ type_t *clause_type = get_clause_type(env, subject_t, clause);
+ type_t *merged = type_or_type(overall_t, clause_type);
+ if (!merged)
+ code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
+ clause_type, overall_t);
+ overall_t = merged;
}
- type_t *clause_type = get_type(scope, clause->body);
- type_t *merged = type_or_type(overall_t, clause_type);
- if (!merged)
- code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
- clause_type, overall_t);
- overall_t = merged;
- }
-
- if (when->else_body) {
- bool any_unhandled = false;
- for (match_t *m = matches; m; m = m->next) {
- if (!m->handled) {
- any_unhandled = true;
- break;
+
+ if (when->else_body) {
+ bool any_unhandled = false;
+ for (match_t *m = matches; m; m = m->next) {
+ if (!m->handled) {
+ any_unhandled = true;
+ break;
+ }
+ }
+ // HACK: `while when ...` is handled by the parser adding an implicit
+ // `else: stop`, which has an empty source code span.
+ if (!any_unhandled && when->else_body->end > when->else_body->start)
+ code_err(when->else_body, "This 'else' block will never run because every tag is handled");
+
+ type_t *else_t = get_type(env, when->else_body);
+ type_t *merged = type_or_type(overall_t, else_t);
+ if (!merged)
+ code_err(when->else_body,
+ "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
+ overall_t, else_t);
+ return merged;
+ } else {
+ CORD unhandled = CORD_EMPTY;
+ for (match_t *m = matches; m; m = m->next) {
+ if (!m->handled)
+ unhandled = unhandled ? CORD_all(unhandled, ", ", m->tag->name) : m->tag->name;
}
+ if (unhandled)
+ code_err(ast, "This 'when' statement doesn't handle the tags: %s", CORD_to_const_char_star(unhandled));
+ return overall_t;
}
- // HACK: `while when ...` is handled by the parser adding an implicit
- // `else: stop`, which has an empty source code span.
- if (!any_unhandled && when->else_body->end > when->else_body->start)
- code_err(when->else_body, "This 'else' block will never run because every tag is handled");
-
- type_t *else_t = get_type(env, when->else_body);
- type_t *merged = type_or_type(overall_t, else_t);
- if (!merged)
- code_err(when->else_body,
- "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
- overall_t, else_t);
- return merged;
} else {
- CORD unhandled = CORD_EMPTY;
- for (match_t *m = matches; m; m = m->next) {
- if (!m->handled)
- unhandled = unhandled ? CORD_all(unhandled, ", ", m->name) : m->name;
- }
- if (unhandled)
- code_err(ast, "This 'when' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled));
- return overall_t;
+ code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
}
}
diff --git a/typecheck.h b/typecheck.h
index ea4fe0e3..bca21ba8 100644
--- a/typecheck.h
+++ b/typecheck.h
@@ -21,6 +21,7 @@ type_t *get_file_type(env_t *env, const char *path);
type_t *get_function_def_type(env_t *env, ast_t *ast);
type_t *get_arg_type(env_t *env, arg_t *arg);
type_t *get_arg_ast_type(env_t *env, arg_ast_t *arg);
+type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause);
bool can_be_mutated(env_t *env, ast_t *ast);
type_t *parse_type_string(env_t *env, const char *str);
type_t *get_method_type(env_t *env, ast_t *self, const char *name);