diff options
| -rw-r--r-- | compile.c | 37 | ||||
| -rw-r--r-- | test/enums.tm | 8 | ||||
| -rw-r--r-- | typecheck.c | 210 | ||||
| -rw-r--r-- | typecheck.h | 1 |
4 files changed, 165 insertions, 91 deletions
@@ -2657,8 +2657,41 @@ CORD compile(env_t *env, ast_t *ast) code_err(call->fn, "This is not a function, it's a %T", fn_t); } } - case When: - code_err(ast, "'when' expressions are not yet implemented"); + case When: { + auto original = Match(ast, When); + ast_t *when_var = WrapAST(ast, Var, .name="when"); + when_clause_t *new_clauses = NULL; + type_t *subject_t = get_type(env, original->subject); + for (when_clause_t *clause = original->clauses; clause; clause = clause->next) { + type_t *clause_type = get_clause_type(env, subject_t, clause); + if (clause_type->tag == AbortType || clause_type->tag == ReturnType) { + new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=clause->body, .next=new_clauses); + } else { + ast_t *assign = WrapAST(clause->body, Assign, + .targets=new(ast_list_t, .ast=when_var), + .values=new(ast_list_t, .ast=clause->body)); + new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=assign, .next=new_clauses); + } + } + REVERSE_LIST(new_clauses); + ast_t *else_body = original->else_body; + if (else_body) { + type_t *clause_type = get_type(env, else_body); + if (clause_type->tag != AbortType && clause_type->tag != ReturnType) { + else_body = WrapAST(else_body, Assign, + .targets=new(ast_list_t, .ast=when_var), + .values=new(ast_list_t, .ast=else_body)); + } + } + + type_t *t = get_type(env, ast); + env_t *when_env = fresh_scope(env); + set_binding(when_env, "when", new(binding_t, .type=t, .code="when")); + return CORD_all( + "({ ", compile_declaration(t, "when"), ";\n", + compile_statement(when_env, WrapAST(ast, When, .subject=original->subject, .clauses=new_clauses, .else_body=else_body)), + "when; })"); + } case If: { auto if_ = Match(ast, If); if (!if_->else_body) diff --git a/test/enums.tm b/test/enums.tm index f4af342d..98811408 100644 --- a/test/enums.tm +++ b/test/enums.tm @@ -66,4 +66,12 @@ func main(): while when cases[i] is One(x): >> x i += 1 + + >> expr := when cases[1] is One(y): + y + 1 + else: + -1 + = 2 + + diff --git a/typecheck.c b/typecheck.c index 519d83d1..bde73ea5 100644 --- a/typecheck.c +++ b/typecheck.c @@ -433,6 +433,64 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name) return b->type; } +type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause) +{ + if (subject_t->tag == PointerType) { + if (!Match(subject_t, PointerType)->is_optional) + code_err(clause->body, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t); + + const char *tag_name = Match(clause->tag_name, Var)->name; + if (!streq(tag_name, "@")) + code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name); + + assert(clause->args); + env_t *scope = fresh_scope(env); + auto ptr = Match(subject_t, PointerType); + set_binding(scope, Match(clause->args->ast, Var)->name, + new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly))); + + return get_type(scope, clause->body); + } else { + assert(subject_t->tag == EnumType); + tag_t * const tags = Match(subject_t, EnumType)->tags; + + const char *tag_name = Match(clause->tag_name, Var)->name; + type_t *tag_type = NULL; + for (tag_t *tag = tags; tag; tag = tag->next) { + if (streq(tag->name, tag_name)) { + tag_type = tag->type; + break; + } + } + + if (!tag_type) + code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t); + + // Don't return early so we validate the tags + if (!clause->args) + return get_type(env, clause->body); + + env_t *scope = fresh_scope(env); + auto tag_struct = Match(tag_type, StructType); + if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) { + set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type)); + } else { + ast_list_t *var = clause->args; + arg_t *field = tag_struct->fields; + while (var || field) { + if (!var) + code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); + if (!field) + code_err(var->ast, "This is one more field than %T has", subject_t); + set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type)); + var = var->next; + field = field->next; + } + } + return get_type(scope, clause->body); + } +} + type_t *get_type(env_t *env, ast_t *ast) { if (!ast) return NULL; @@ -1145,12 +1203,7 @@ type_t *get_type(env_t *env, ast_t *ast) handled_at = true; assert(clause->args); - env_t *scope = fresh_scope(env); - auto ptr = Match(subject_t, PointerType); - set_binding(scope, Match(clause->args->ast, Var)->name, - new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly))); - - type_t *clause_type = get_type(scope, clause->body); + type_t *clause_type = get_clause_type(env, subject_t, clause); type_t *merged = type_or_type(overall_t, clause_type); if (!merged) code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", @@ -1162,99 +1215,78 @@ type_t *get_type(env_t *env, ast_t *ast) if (!when->else_body) code_err(ast, "This 'when' statement doesn't handle null pointers"); return overall_t; - } - - if (subject_t->tag != EnumType) - code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t); + } else if (subject_t->tag == EnumType) { + tag_t * const tags = Match(subject_t, EnumType)->tags; + + typedef struct match_s { + tag_t *tag; + bool handled; + struct match_s *next; + } match_t; + match_t *matches = NULL; + for (tag_t *tag = tags; tag; tag = tag->next) + matches = new(match_t, .tag=tag, .handled=false, .next=matches); - tag_t * const tags = Match(subject_t, EnumType)->tags; - - typedef struct match_s { - const char *name; - type_t *type; - bool handled; - struct match_s *next; - } match_t; - match_t *matches = NULL; - for (tag_t *tag = tags; tag; tag = tag->next) - matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches); - - for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - const char *tag_name = Match(clause->tag_name, Var)->name; - type_t *tag_type = NULL; - CORD valid_tags = CORD_EMPTY; - for (match_t *m = matches; m; m = m->next) { - if (streq(m->name, tag_name)) { - if (m->handled) - code_err(clause->tag_name, "This tag was already handled earlier"); - m->handled = true; - tag_type = m->type; - break; + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + const char *tag_name = Match(clause->tag_name, Var)->name; + CORD valid_tags = CORD_EMPTY; + for (match_t *m = matches; m; m = m->next) { + if (streq(m->tag->name, tag_name)) { + if (m->handled) + code_err(clause->tag_name, "This tag was already handled earlier"); + m->handled = true; + goto found_matching_tag; + } + if (valid_tags) valid_tags = CORD_cat(valid_tags, ", "); + valid_tags = CORD_cat(valid_tags, m->tag->name); } - if (valid_tags) valid_tags = CORD_cat(valid_tags, ", "); - valid_tags = CORD_cat(valid_tags, m->name); - } - if (!tag_type) code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)", tag_name, subject_t, CORD_to_char_star(valid_tags)); + found_matching_tag:; + } - env_t *scope = env; - auto tag_struct = Match(tag_type, StructType); - if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) { - scope = fresh_scope(scope); - set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type)); - } else if (clause->args) { - scope = fresh_scope(scope); - ast_list_t *var = clause->args; - arg_t *field = tag_struct->fields; - while (var || field) { - if (!var) - code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); - if (!field) - code_err(var->ast, "This is one more field than %T has", subject_t); - set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type)); - var = var->next; - field = field->next; - } + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + type_t *clause_type = get_clause_type(env, subject_t, clause); + type_t *merged = type_or_type(overall_t, clause_type); + if (!merged) + code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", + clause_type, overall_t); + overall_t = merged; } - type_t *clause_type = get_type(scope, clause->body); - type_t *merged = type_or_type(overall_t, clause_type); - if (!merged) - code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", - clause_type, overall_t); - overall_t = merged; - } - - if (when->else_body) { - bool any_unhandled = false; - for (match_t *m = matches; m; m = m->next) { - if (!m->handled) { - any_unhandled = true; - break; + + if (when->else_body) { + bool any_unhandled = false; + for (match_t *m = matches; m; m = m->next) { + if (!m->handled) { + any_unhandled = true; + break; + } + } + // HACK: `while when ...` is handled by the parser adding an implicit + // `else: stop`, which has an empty source code span. + if (!any_unhandled && when->else_body->end > when->else_body->start) + code_err(when->else_body, "This 'else' block will never run because every tag is handled"); + + type_t *else_t = get_type(env, when->else_body); + type_t *merged = type_or_type(overall_t, else_t); + if (!merged) + code_err(when->else_body, + "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.", + overall_t, else_t); + return merged; + } else { + CORD unhandled = CORD_EMPTY; + for (match_t *m = matches; m; m = m->next) { + if (!m->handled) + unhandled = unhandled ? CORD_all(unhandled, ", ", m->tag->name) : m->tag->name; } + if (unhandled) + code_err(ast, "This 'when' statement doesn't handle the tags: %s", CORD_to_const_char_star(unhandled)); + return overall_t; } - // HACK: `while when ...` is handled by the parser adding an implicit - // `else: stop`, which has an empty source code span. - if (!any_unhandled && when->else_body->end > when->else_body->start) - code_err(when->else_body, "This 'else' block will never run because every tag is handled"); - - type_t *else_t = get_type(env, when->else_body); - type_t *merged = type_or_type(overall_t, else_t); - if (!merged) - code_err(when->else_body, - "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.", - overall_t, else_t); - return merged; } else { - CORD unhandled = CORD_EMPTY; - for (match_t *m = matches; m; m = m->next) { - if (!m->handled) - unhandled = unhandled ? CORD_all(unhandled, ", ", m->name) : m->name; - } - if (unhandled) - code_err(ast, "This 'when' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled)); - return overall_t; + code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t); } } diff --git a/typecheck.h b/typecheck.h index ea4fe0e3..bca21ba8 100644 --- a/typecheck.h +++ b/typecheck.h @@ -21,6 +21,7 @@ type_t *get_file_type(env_t *env, const char *path); type_t *get_function_def_type(env_t *env, ast_t *ast); type_t *get_arg_type(env_t *env, arg_t *arg); type_t *get_arg_ast_type(env_t *env, arg_ast_t *arg); +type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause); bool can_be_mutated(env_t *env, ast_t *ast); type_t *parse_type_string(env_t *env, const char *str); type_t *get_method_type(env_t *env, ast_t *self, const char *name); |
