diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-09-11 01:31:31 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-09-11 01:31:31 -0400 |
| commit | 7126755275f12e6278031e78ff33f65801b133dd (patch) | |
| tree | 7f43f3449eb7bb69b0879dd41eb174e89fdc34cc /typecheck.c | |
| parent | 89234e34e292861fccb8e5bdbefc695a7e443eea (diff) | |
Add optional types
Diffstat (limited to 'typecheck.c')
| -rw-r--r-- | typecheck.c | 298 |
1 files changed, 128 insertions, 170 deletions
diff --git a/typecheck.c b/typecheck.c index 5d3ffb48..d82ddcbf 100644 --- a/typecheck.c +++ b/typecheck.c @@ -43,7 +43,7 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast) type_t *pointed_t = parse_type_ast(env, ptr->pointed); if (pointed_t->tag == VoidType) code_err(ast, "Void pointers are not supported. You probably meant 'Memory' instead of 'Void'"); - return Type(PointerType, .is_optional=ptr->is_optional, .pointed=pointed_t, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly); + return Type(PointerType, .pointed=pointed_t, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly); } case ArrayTypeAST: { type_ast_t *item_type = Match(ast, ArrayTypeAST)->item; @@ -109,6 +109,13 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast) REVERSE_LIST(type_args); return Type(ClosureType, Type(FunctionType, .args=type_args, .ret=ret_t)); } + case OptionalTypeAST: { + auto opt = Match(ast, OptionalTypeAST); + type_t *t = parse_type_ast(env, opt->type); + if (t->tag == VoidType || t->tag == AbortType || t->tag == ReturnType) + code_err(ast, "Optional %T types are not supported.", t); + return Type(OptionalType, .type=t); + } case UnknownTypeAST: code_err(ast, "I don't know how to get this type"); } errx(1, "Unreachable"); @@ -438,60 +445,43 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name) type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause) { - if (subject_t->tag == PointerType) { - if (!Match(subject_t, PointerType)->is_optional) - code_err(clause->body, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t); - - const char *tag_name = Match(clause->tag_name, Var)->name; - if (!streq(tag_name, "@")) - code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name); - - assert(clause->args); - env_t *scope = fresh_scope(env); - auto ptr = Match(subject_t, PointerType); - set_binding(scope, Match(clause->args->ast, Var)->name, - new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly))); - - return get_type(scope, clause->body); - } else { - assert(subject_t->tag == EnumType); - tag_t * const tags = Match(subject_t, EnumType)->tags; - - const char *tag_name = Match(clause->tag_name, Var)->name; - type_t *tag_type = NULL; - for (tag_t *tag = tags; tag; tag = tag->next) { - if (streq(tag->name, tag_name)) { - tag_type = tag->type; - break; - } + assert(subject_t->tag == EnumType); + tag_t * const tags = Match(subject_t, EnumType)->tags; + + const char *tag_name = Match(clause->tag_name, Var)->name; + type_t *tag_type = NULL; + for (tag_t *tag = tags; tag; tag = tag->next) { + if (streq(tag->name, tag_name)) { + tag_type = tag->type; + break; } + } - if (!tag_type) - code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t); + if (!tag_type) + code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t); - // Don't return early so we validate the tags - if (!clause->args) - return get_type(env, clause->body); + // Don't return early so we validate the tags + if (!clause->args) + return get_type(env, clause->body); - env_t *scope = fresh_scope(env); - auto tag_struct = Match(tag_type, StructType); - if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) { - set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type)); - } else { - ast_list_t *var = clause->args; - arg_t *field = tag_struct->fields; - while (var || field) { - if (!var) - code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); - if (!field) - code_err(var->ast, "This is one more field than %T has", subject_t); - set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type)); - var = var->next; - field = field->next; - } - } - return get_type(scope, clause->body); - } + env_t *scope = fresh_scope(env); + auto tag_struct = Match(tag_type, StructType); + if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) { + set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type)); + } else { + ast_list_t *var = clause->args; + arg_t *field = tag_struct->fields; + while (var || field) { + if (!var) + code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); + if (!field) + code_err(var->ast, "This is one more field than %T has", subject_t); + set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type)); + var = var->next; + field = field->next; + } + } + return get_type(scope, clause->body); } type_t *get_type(env_t *env, ast_t *ast) @@ -500,10 +490,7 @@ type_t *get_type(env_t *env, ast_t *ast) switch (ast->tag) { case Nil: { type_t *t = parse_type_ast(env, Match(ast, Nil)->type); - if (t->tag != PointerType) - code_err(ast, "This type is not a pointer type, so it doesn't work with a '!' nil expression"); - auto ptr = Match(t, PointerType); - return Type(PointerType, .is_optional=true, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly); + return Type(OptionalType, .type=t); } case Bool: { return Type(BoolType); @@ -533,7 +520,7 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *pointed = get_type(env, Match(ast, HeapAllocate)->value); if (has_stack_memory(pointed)) code_err(ast, "Stack references cannot be moved to the heap because they may outlive the stack frame they were created in."); - return Type(PointerType, .is_optional=false, .pointed=pointed); + return Type(PointerType, .pointed=pointed); } case StackReference: { // Supported: @@ -558,10 +545,10 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *ref_type = get_type(env, value); type_t *base_type = get_type(env, base); - if (base_type->tag == PointerType) { + if (base_type->tag == OptionalType) { + code_err(base, "This value might be null, so it can't be safely dereferenced"); + } else if (base_type->tag == PointerType) { auto ptr = Match(base_type, PointerType); - if (ptr->is_optional) - code_err(base, "This value might be null, so it can't be safely dereferenced"); return Type(PointerType, .pointed=ref_type, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly); } else if (base->tag == Var) { return Type(PointerType, .pointed=ref_type, .is_stack=true); @@ -573,12 +560,9 @@ type_t *get_type(env_t *env, ast_t *ast) case Optional: { ast_t *value = Match(ast, Optional)->value; type_t *t = get_type(env, value); - if (t->tag != PointerType) - code_err(ast, "This value is not a pointer, it has type %T, so it can't be optional", t); - auto ptr = Match(t, PointerType); - if (ptr->is_optional) + if (t->tag == OptionalType) code_err(ast, "This value is already optional, it can't be converted to optional"); - return Type(PointerType, .pointed=ptr->pointed, .is_optional=true, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly); + return Type(OptionalType, .type=t); } case TextLiteral: return TEXT_TYPE; case TextJoin: { @@ -740,12 +724,11 @@ type_t *get_type(env_t *env, ast_t *ast) case Index: { auto indexing = Match(ast, Index); type_t *indexed_t = get_type(env, indexing->indexed); - if (indexed_t->tag == PointerType && !indexing->index) { - auto ptr = Match(indexed_t, PointerType); - if (ptr->is_optional) - code_err(ast, "You're attempting to dereference a pointer whose type indicates it could be nil"); - return ptr->pointed; - } + if (indexed_t->tag == OptionalType && !indexing->index) + code_err(ast, "You're attempting to dereference a value whose type indicates it could be nil"); + + if (indexed_t->tag == PointerType && !indexing->index) + return Match(indexed_t, PointerType)->pointed; type_t *value_t = value_type(indexed_t); if (value_t->tag == ArrayType) { @@ -791,7 +774,7 @@ type_t *get_type(env_t *env, ast_t *ast) else if (streq(call->name, "clear")) return Type(VoidType); else if (streq(call->name, "counts")) return Type(TableType, .key_type=item_type, .value_type=INT_TYPE); else if (streq(call->name, "find")) return INT_TYPE; - else if (streq(call->name, "first")) return Type(PointerType, .pointed=item_type, .is_optional=true, .is_readonly=true); + else if (streq(call->name, "first")) return Type(OptionalType, .type=Type(PointerType, .pointed=item_type, .is_readonly=true)); else if (streq(call->name, "from")) return self_value_t; else if (streq(call->name, "has")) return Type(BoolType); else if (streq(call->name, "heap_pop")) return item_type; @@ -845,8 +828,7 @@ type_t *get_type(env_t *env, ast_t *ast) code_err(ast, "The table method :get_or_null() is only supported for tables whose value type is a pointer, not %T", table->value_type); auto ptr = Match(table->value_type, PointerType); - return Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, - .is_readonly=ptr->is_readonly, .is_optional=true); + return Type(OptionalType, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly)); } else if (streq(call->name, "has")) return Type(BoolType); else if (streq(call->name, "remove")) return Type(VoidType); else if (streq(call->name, "set")) return Type(VoidType); @@ -939,7 +921,7 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *t = get_type(env, Match(ast, Not)->value); if (t->tag == IntType || t->tag == NumType || t->tag == BoolType) return t; - if (t->tag == PointerType && Match(t, PointerType)->is_optional) + if (t->tag == OptionalType) return Type(BoolType); ast_t *value = Match(ast, Not)->value; @@ -1005,12 +987,14 @@ type_t *get_type(env_t *env, ast_t *ast) return lhs_t; } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { return lhs_t; + } else if (rhs_t->tag == OptionalType) { + if (can_promote(lhs_t, rhs_t)) + return rhs_t; } else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) { auto lhs_ptr = Match(lhs_t, PointerType); auto rhs_ptr = Match(rhs_t, PointerType); if (type_eq(lhs_ptr->pointed, rhs_ptr->pointed)) - return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=lhs_ptr->is_optional || rhs_ptr->is_optional, - .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly); + return Type(PointerType, .pointed=lhs_ptr->pointed, .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly); } else if (is_int_type(lhs_t) && is_int_type(rhs_t)) { return get_math_type(env, ast, lhs_t, rhs_t); } @@ -1024,15 +1008,17 @@ type_t *get_type(env_t *env, ast_t *ast) return lhs_t; } else if (is_int_type(lhs_t) && is_int_type(rhs_t)) { return get_math_type(env, ast, lhs_t, rhs_t); + } else if (lhs_t->tag == OptionalType) { + if (can_promote(rhs_t, lhs_t)) + return rhs_t; } else if (lhs_t->tag == PointerType) { auto lhs_ptr = Match(lhs_t, PointerType); if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=false, .is_readonly=lhs_ptr->is_readonly); + return Type(PointerType, .pointed=lhs_ptr->pointed, .is_readonly=lhs_ptr->is_readonly); } else if (rhs_t->tag == PointerType) { auto rhs_ptr = Match(rhs_t, PointerType); if (type_eq(rhs_ptr->pointed, lhs_ptr->pointed)) - return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=lhs_ptr->is_optional && rhs_ptr->is_optional, - .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly); + return Type(PointerType, .pointed=lhs_ptr->pointed, .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly); } } code_err(ast, "I can't figure out the type of this `or` expression because the left side is a %T, but the right side is a %T", @@ -1157,7 +1143,6 @@ type_t *get_type(env_t *env, ast_t *ast) REVERSE_LIST(args); type_t *ret = get_type(scope, lambda->body); - assert(ret); if (ret->tag == ReturnType) ret = Match(ret, ReturnType)->ret; if (ret->tag == AbortType) @@ -1192,104 +1177,77 @@ type_t *get_type(env_t *env, ast_t *ast) auto when = Match(ast, When); type_t *subject_t = get_type(env, when->subject); type_t *overall_t = NULL; - if (subject_t->tag == PointerType) { - if (!Match(subject_t, PointerType)->is_optional) - code_err(when->subject, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t); - - bool handled_at = false; - for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - const char *tag_name = Match(clause->tag_name, Var)->name; - if (!streq(tag_name, "@")) - code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name); - if (handled_at) - code_err(clause->tag_name, "This 'when' statement has already handled the case of non-null pointers!"); - handled_at = true; - - assert(clause->args); - type_t *clause_type = get_clause_type(env, subject_t, clause); - type_t *merged = type_or_type(overall_t, clause_type); - if (!merged) - code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", - clause_type, overall_t); - overall_t = merged; - } - if (!handled_at) - code_err(ast, "This 'when' statement doesn't handle non-null pointers"); - if (!when->else_body) - code_err(ast, "This 'when' statement doesn't handle null pointers"); - return overall_t; - } else if (subject_t->tag == EnumType) { - tag_t * const tags = Match(subject_t, EnumType)->tags; - - typedef struct match_s { - tag_t *tag; - bool handled; - struct match_s *next; - } match_t; - match_t *matches = NULL; - for (tag_t *tag = tags; tag; tag = tag->next) - matches = new(match_t, .tag=tag, .handled=false, .next=matches); - - for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - const char *tag_name = Match(clause->tag_name, Var)->name; - CORD valid_tags = CORD_EMPTY; - for (match_t *m = matches; m; m = m->next) { - if (streq(m->tag->name, tag_name)) { - if (m->handled) - code_err(clause->tag_name, "This tag was already handled earlier"); - m->handled = true; - goto found_matching_tag; - } - if (valid_tags) valid_tags = CORD_cat(valid_tags, ", "); - valid_tags = CORD_cat(valid_tags, m->tag->name); - } + if (subject_t->tag != EnumType) + code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t); - code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)", - tag_name, subject_t, CORD_to_char_star(valid_tags)); - found_matching_tag:; - } + tag_t * const tags = Match(subject_t, EnumType)->tags; - for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - type_t *clause_type = get_clause_type(env, subject_t, clause); - type_t *merged = type_or_type(overall_t, clause_type); - if (!merged) - code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", - clause_type, overall_t); - overall_t = merged; + typedef struct match_s { + tag_t *tag; + bool handled; + struct match_s *next; + } match_t; + match_t *matches = NULL; + for (tag_t *tag = tags; tag; tag = tag->next) + matches = new(match_t, .tag=tag, .handled=false, .next=matches); + + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + const char *tag_name = Match(clause->tag_name, Var)->name; + CORD valid_tags = CORD_EMPTY; + for (match_t *m = matches; m; m = m->next) { + if (streq(m->tag->name, tag_name)) { + if (m->handled) + code_err(clause->tag_name, "This tag was already handled earlier"); + m->handled = true; + goto found_matching_tag; + } + if (valid_tags) valid_tags = CORD_cat(valid_tags, ", "); + valid_tags = CORD_cat(valid_tags, m->tag->name); } - if (when->else_body) { - bool any_unhandled = false; - for (match_t *m = matches; m; m = m->next) { - if (!m->handled) { - any_unhandled = true; - break; - } - } - // HACK: `while when ...` is handled by the parser adding an implicit - // `else: stop`, which has an empty source code span. - if (!any_unhandled && when->else_body->end > when->else_body->start) - code_err(when->else_body, "This 'else' block will never run because every tag is handled"); + code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)", + tag_name, subject_t, CORD_to_char_star(valid_tags)); + found_matching_tag:; + } - type_t *else_t = get_type(env, when->else_body); - type_t *merged = type_or_type(overall_t, else_t); - if (!merged) - code_err(when->else_body, - "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.", - overall_t, else_t); - return merged; - } else { - CORD unhandled = CORD_EMPTY; - for (match_t *m = matches; m; m = m->next) { - if (!m->handled) - unhandled = unhandled ? CORD_all(unhandled, ", ", m->tag->name) : m->tag->name; + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + type_t *clause_type = get_clause_type(env, subject_t, clause); + type_t *merged = type_or_type(overall_t, clause_type); + if (!merged) + code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", + clause_type, overall_t); + overall_t = merged; + } + + if (when->else_body) { + bool any_unhandled = false; + for (match_t *m = matches; m; m = m->next) { + if (!m->handled) { + any_unhandled = true; + break; } - if (unhandled) - code_err(ast, "This 'when' statement doesn't handle the tags: %s", CORD_to_const_char_star(unhandled)); - return overall_t; } + // HACK: `while when ...` is handled by the parser adding an implicit + // `else: stop`, which has an empty source code span. + if (!any_unhandled && when->else_body->end > when->else_body->start) + code_err(when->else_body, "This 'else' block will never run because every tag is handled"); + + type_t *else_t = get_type(env, when->else_body); + type_t *merged = type_or_type(overall_t, else_t); + if (!merged) + code_err(when->else_body, + "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.", + overall_t, else_t); + return merged; } else { - code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t); + CORD unhandled = CORD_EMPTY; + for (match_t *m = matches; m; m = m->next) { + if (!m->handled) + unhandled = unhandled ? CORD_all(unhandled, ", ", m->tag->name) : m->tag->name; + } + if (unhandled) + code_err(ast, "This 'when' statement doesn't handle the tags: %s", CORD_to_const_char_star(unhandled)); + return overall_t; } } |
