aboutsummaryrefslogtreecommitdiff
path: root/typecheck.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-09-11 01:31:31 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-09-11 01:31:31 -0400
commit7126755275f12e6278031e78ff33f65801b133dd (patch)
tree7f43f3449eb7bb69b0879dd41eb174e89fdc34cc /typecheck.c
parent89234e34e292861fccb8e5bdbefc695a7e443eea (diff)
Add optional types
Diffstat (limited to 'typecheck.c')
-rw-r--r--typecheck.c298
1 files changed, 128 insertions, 170 deletions
diff --git a/typecheck.c b/typecheck.c
index 5d3ffb48..d82ddcbf 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -43,7 +43,7 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast)
type_t *pointed_t = parse_type_ast(env, ptr->pointed);
if (pointed_t->tag == VoidType)
code_err(ast, "Void pointers are not supported. You probably meant 'Memory' instead of 'Void'");
- return Type(PointerType, .is_optional=ptr->is_optional, .pointed=pointed_t, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
+ return Type(PointerType, .pointed=pointed_t, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
}
case ArrayTypeAST: {
type_ast_t *item_type = Match(ast, ArrayTypeAST)->item;
@@ -109,6 +109,13 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast)
REVERSE_LIST(type_args);
return Type(ClosureType, Type(FunctionType, .args=type_args, .ret=ret_t));
}
+ case OptionalTypeAST: {
+ auto opt = Match(ast, OptionalTypeAST);
+ type_t *t = parse_type_ast(env, opt->type);
+ if (t->tag == VoidType || t->tag == AbortType || t->tag == ReturnType)
+ code_err(ast, "Optional %T types are not supported.", t);
+ return Type(OptionalType, .type=t);
+ }
case UnknownTypeAST: code_err(ast, "I don't know how to get this type");
}
errx(1, "Unreachable");
@@ -438,60 +445,43 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name)
type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause)
{
- if (subject_t->tag == PointerType) {
- if (!Match(subject_t, PointerType)->is_optional)
- code_err(clause->body, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t);
-
- const char *tag_name = Match(clause->tag_name, Var)->name;
- if (!streq(tag_name, "@"))
- code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name);
-
- assert(clause->args);
- env_t *scope = fresh_scope(env);
- auto ptr = Match(subject_t, PointerType);
- set_binding(scope, Match(clause->args->ast, Var)->name,
- new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly)));
-
- return get_type(scope, clause->body);
- } else {
- assert(subject_t->tag == EnumType);
- tag_t * const tags = Match(subject_t, EnumType)->tags;
-
- const char *tag_name = Match(clause->tag_name, Var)->name;
- type_t *tag_type = NULL;
- for (tag_t *tag = tags; tag; tag = tag->next) {
- if (streq(tag->name, tag_name)) {
- tag_type = tag->type;
- break;
- }
+ assert(subject_t->tag == EnumType);
+ tag_t * const tags = Match(subject_t, EnumType)->tags;
+
+ const char *tag_name = Match(clause->tag_name, Var)->name;
+ type_t *tag_type = NULL;
+ for (tag_t *tag = tags; tag; tag = tag->next) {
+ if (streq(tag->name, tag_name)) {
+ tag_type = tag->type;
+ break;
}
+ }
- if (!tag_type)
- code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t);
+ if (!tag_type)
+ code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t);
- // Don't return early so we validate the tags
- if (!clause->args)
- return get_type(env, clause->body);
+ // Don't return early so we validate the tags
+ if (!clause->args)
+ return get_type(env, clause->body);
- env_t *scope = fresh_scope(env);
- auto tag_struct = Match(tag_type, StructType);
- if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) {
- set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
- } else {
- ast_list_t *var = clause->args;
- arg_t *field = tag_struct->fields;
- while (var || field) {
- if (!var)
- code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
- if (!field)
- code_err(var->ast, "This is one more field than %T has", subject_t);
- set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
- var = var->next;
- field = field->next;
- }
- }
- return get_type(scope, clause->body);
- }
+ env_t *scope = fresh_scope(env);
+ auto tag_struct = Match(tag_type, StructType);
+ if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) {
+ set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
+ } else {
+ ast_list_t *var = clause->args;
+ arg_t *field = tag_struct->fields;
+ while (var || field) {
+ if (!var)
+ code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
+ if (!field)
+ code_err(var->ast, "This is one more field than %T has", subject_t);
+ set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
+ var = var->next;
+ field = field->next;
+ }
+ }
+ return get_type(scope, clause->body);
}
type_t *get_type(env_t *env, ast_t *ast)
@@ -500,10 +490,7 @@ type_t *get_type(env_t *env, ast_t *ast)
switch (ast->tag) {
case Nil: {
type_t *t = parse_type_ast(env, Match(ast, Nil)->type);
- if (t->tag != PointerType)
- code_err(ast, "This type is not a pointer type, so it doesn't work with a '!' nil expression");
- auto ptr = Match(t, PointerType);
- return Type(PointerType, .is_optional=true, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
+ return Type(OptionalType, .type=t);
}
case Bool: {
return Type(BoolType);
@@ -533,7 +520,7 @@ type_t *get_type(env_t *env, ast_t *ast)
type_t *pointed = get_type(env, Match(ast, HeapAllocate)->value);
if (has_stack_memory(pointed))
code_err(ast, "Stack references cannot be moved to the heap because they may outlive the stack frame they were created in.");
- return Type(PointerType, .is_optional=false, .pointed=pointed);
+ return Type(PointerType, .pointed=pointed);
}
case StackReference: {
// Supported:
@@ -558,10 +545,10 @@ type_t *get_type(env_t *env, ast_t *ast)
type_t *ref_type = get_type(env, value);
type_t *base_type = get_type(env, base);
- if (base_type->tag == PointerType) {
+ if (base_type->tag == OptionalType) {
+ code_err(base, "This value might be null, so it can't be safely dereferenced");
+ } else if (base_type->tag == PointerType) {
auto ptr = Match(base_type, PointerType);
- if (ptr->is_optional)
- code_err(base, "This value might be null, so it can't be safely dereferenced");
return Type(PointerType, .pointed=ref_type, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
} else if (base->tag == Var) {
return Type(PointerType, .pointed=ref_type, .is_stack=true);
@@ -573,12 +560,9 @@ type_t *get_type(env_t *env, ast_t *ast)
case Optional: {
ast_t *value = Match(ast, Optional)->value;
type_t *t = get_type(env, value);
- if (t->tag != PointerType)
- code_err(ast, "This value is not a pointer, it has type %T, so it can't be optional", t);
- auto ptr = Match(t, PointerType);
- if (ptr->is_optional)
+ if (t->tag == OptionalType)
code_err(ast, "This value is already optional, it can't be converted to optional");
- return Type(PointerType, .pointed=ptr->pointed, .is_optional=true, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
+ return Type(OptionalType, .type=t);
}
case TextLiteral: return TEXT_TYPE;
case TextJoin: {
@@ -740,12 +724,11 @@ type_t *get_type(env_t *env, ast_t *ast)
case Index: {
auto indexing = Match(ast, Index);
type_t *indexed_t = get_type(env, indexing->indexed);
- if (indexed_t->tag == PointerType && !indexing->index) {
- auto ptr = Match(indexed_t, PointerType);
- if (ptr->is_optional)
- code_err(ast, "You're attempting to dereference a pointer whose type indicates it could be nil");
- return ptr->pointed;
- }
+ if (indexed_t->tag == OptionalType && !indexing->index)
+ code_err(ast, "You're attempting to dereference a value whose type indicates it could be nil");
+
+ if (indexed_t->tag == PointerType && !indexing->index)
+ return Match(indexed_t, PointerType)->pointed;
type_t *value_t = value_type(indexed_t);
if (value_t->tag == ArrayType) {
@@ -791,7 +774,7 @@ type_t *get_type(env_t *env, ast_t *ast)
else if (streq(call->name, "clear")) return Type(VoidType);
else if (streq(call->name, "counts")) return Type(TableType, .key_type=item_type, .value_type=INT_TYPE);
else if (streq(call->name, "find")) return INT_TYPE;
- else if (streq(call->name, "first")) return Type(PointerType, .pointed=item_type, .is_optional=true, .is_readonly=true);
+ else if (streq(call->name, "first")) return Type(OptionalType, .type=Type(PointerType, .pointed=item_type, .is_readonly=true));
else if (streq(call->name, "from")) return self_value_t;
else if (streq(call->name, "has")) return Type(BoolType);
else if (streq(call->name, "heap_pop")) return item_type;
@@ -845,8 +828,7 @@ type_t *get_type(env_t *env, ast_t *ast)
code_err(ast, "The table method :get_or_null() is only supported for tables whose value type is a pointer, not %T",
table->value_type);
auto ptr = Match(table->value_type, PointerType);
- return Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack,
- .is_readonly=ptr->is_readonly, .is_optional=true);
+ return Type(OptionalType, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly));
} else if (streq(call->name, "has")) return Type(BoolType);
else if (streq(call->name, "remove")) return Type(VoidType);
else if (streq(call->name, "set")) return Type(VoidType);
@@ -939,7 +921,7 @@ type_t *get_type(env_t *env, ast_t *ast)
type_t *t = get_type(env, Match(ast, Not)->value);
if (t->tag == IntType || t->tag == NumType || t->tag == BoolType)
return t;
- if (t->tag == PointerType && Match(t, PointerType)->is_optional)
+ if (t->tag == OptionalType)
return Type(BoolType);
ast_t *value = Match(ast, Not)->value;
@@ -1005,12 +987,14 @@ type_t *get_type(env_t *env, ast_t *ast)
return lhs_t;
} else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) {
return lhs_t;
+ } else if (rhs_t->tag == OptionalType) {
+ if (can_promote(lhs_t, rhs_t))
+ return rhs_t;
} else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) {
auto lhs_ptr = Match(lhs_t, PointerType);
auto rhs_ptr = Match(rhs_t, PointerType);
if (type_eq(lhs_ptr->pointed, rhs_ptr->pointed))
- return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=lhs_ptr->is_optional || rhs_ptr->is_optional,
- .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly);
+ return Type(PointerType, .pointed=lhs_ptr->pointed, .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly);
} else if (is_int_type(lhs_t) && is_int_type(rhs_t)) {
return get_math_type(env, ast, lhs_t, rhs_t);
}
@@ -1024,15 +1008,17 @@ type_t *get_type(env_t *env, ast_t *ast)
return lhs_t;
} else if (is_int_type(lhs_t) && is_int_type(rhs_t)) {
return get_math_type(env, ast, lhs_t, rhs_t);
+ } else if (lhs_t->tag == OptionalType) {
+ if (can_promote(rhs_t, lhs_t))
+ return rhs_t;
} else if (lhs_t->tag == PointerType) {
auto lhs_ptr = Match(lhs_t, PointerType);
if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) {
- return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=false, .is_readonly=lhs_ptr->is_readonly);
+ return Type(PointerType, .pointed=lhs_ptr->pointed, .is_readonly=lhs_ptr->is_readonly);
} else if (rhs_t->tag == PointerType) {
auto rhs_ptr = Match(rhs_t, PointerType);
if (type_eq(rhs_ptr->pointed, lhs_ptr->pointed))
- return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=lhs_ptr->is_optional && rhs_ptr->is_optional,
- .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly);
+ return Type(PointerType, .pointed=lhs_ptr->pointed, .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly);
}
}
code_err(ast, "I can't figure out the type of this `or` expression because the left side is a %T, but the right side is a %T",
@@ -1157,7 +1143,6 @@ type_t *get_type(env_t *env, ast_t *ast)
REVERSE_LIST(args);
type_t *ret = get_type(scope, lambda->body);
- assert(ret);
if (ret->tag == ReturnType)
ret = Match(ret, ReturnType)->ret;
if (ret->tag == AbortType)
@@ -1192,104 +1177,77 @@ type_t *get_type(env_t *env, ast_t *ast)
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
type_t *overall_t = NULL;
- if (subject_t->tag == PointerType) {
- if (!Match(subject_t, PointerType)->is_optional)
- code_err(when->subject, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t);
-
- bool handled_at = false;
- for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- const char *tag_name = Match(clause->tag_name, Var)->name;
- if (!streq(tag_name, "@"))
- code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name);
- if (handled_at)
- code_err(clause->tag_name, "This 'when' statement has already handled the case of non-null pointers!");
- handled_at = true;
-
- assert(clause->args);
- type_t *clause_type = get_clause_type(env, subject_t, clause);
- type_t *merged = type_or_type(overall_t, clause_type);
- if (!merged)
- code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
- clause_type, overall_t);
- overall_t = merged;
- }
- if (!handled_at)
- code_err(ast, "This 'when' statement doesn't handle non-null pointers");
- if (!when->else_body)
- code_err(ast, "This 'when' statement doesn't handle null pointers");
- return overall_t;
- } else if (subject_t->tag == EnumType) {
- tag_t * const tags = Match(subject_t, EnumType)->tags;
-
- typedef struct match_s {
- tag_t *tag;
- bool handled;
- struct match_s *next;
- } match_t;
- match_t *matches = NULL;
- for (tag_t *tag = tags; tag; tag = tag->next)
- matches = new(match_t, .tag=tag, .handled=false, .next=matches);
-
- for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- const char *tag_name = Match(clause->tag_name, Var)->name;
- CORD valid_tags = CORD_EMPTY;
- for (match_t *m = matches; m; m = m->next) {
- if (streq(m->tag->name, tag_name)) {
- if (m->handled)
- code_err(clause->tag_name, "This tag was already handled earlier");
- m->handled = true;
- goto found_matching_tag;
- }
- if (valid_tags) valid_tags = CORD_cat(valid_tags, ", ");
- valid_tags = CORD_cat(valid_tags, m->tag->name);
- }
+ if (subject_t->tag != EnumType)
+ code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
- code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)",
- tag_name, subject_t, CORD_to_char_star(valid_tags));
- found_matching_tag:;
- }
+ tag_t * const tags = Match(subject_t, EnumType)->tags;
- for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- type_t *clause_type = get_clause_type(env, subject_t, clause);
- type_t *merged = type_or_type(overall_t, clause_type);
- if (!merged)
- code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
- clause_type, overall_t);
- overall_t = merged;
+ typedef struct match_s {
+ tag_t *tag;
+ bool handled;
+ struct match_s *next;
+ } match_t;
+ match_t *matches = NULL;
+ for (tag_t *tag = tags; tag; tag = tag->next)
+ matches = new(match_t, .tag=tag, .handled=false, .next=matches);
+
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ const char *tag_name = Match(clause->tag_name, Var)->name;
+ CORD valid_tags = CORD_EMPTY;
+ for (match_t *m = matches; m; m = m->next) {
+ if (streq(m->tag->name, tag_name)) {
+ if (m->handled)
+ code_err(clause->tag_name, "This tag was already handled earlier");
+ m->handled = true;
+ goto found_matching_tag;
+ }
+ if (valid_tags) valid_tags = CORD_cat(valid_tags, ", ");
+ valid_tags = CORD_cat(valid_tags, m->tag->name);
}
- if (when->else_body) {
- bool any_unhandled = false;
- for (match_t *m = matches; m; m = m->next) {
- if (!m->handled) {
- any_unhandled = true;
- break;
- }
- }
- // HACK: `while when ...` is handled by the parser adding an implicit
- // `else: stop`, which has an empty source code span.
- if (!any_unhandled && when->else_body->end > when->else_body->start)
- code_err(when->else_body, "This 'else' block will never run because every tag is handled");
+ code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)",
+ tag_name, subject_t, CORD_to_char_star(valid_tags));
+ found_matching_tag:;
+ }
- type_t *else_t = get_type(env, when->else_body);
- type_t *merged = type_or_type(overall_t, else_t);
- if (!merged)
- code_err(when->else_body,
- "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
- overall_t, else_t);
- return merged;
- } else {
- CORD unhandled = CORD_EMPTY;
- for (match_t *m = matches; m; m = m->next) {
- if (!m->handled)
- unhandled = unhandled ? CORD_all(unhandled, ", ", m->tag->name) : m->tag->name;
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ type_t *clause_type = get_clause_type(env, subject_t, clause);
+ type_t *merged = type_or_type(overall_t, clause_type);
+ if (!merged)
+ code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
+ clause_type, overall_t);
+ overall_t = merged;
+ }
+
+ if (when->else_body) {
+ bool any_unhandled = false;
+ for (match_t *m = matches; m; m = m->next) {
+ if (!m->handled) {
+ any_unhandled = true;
+ break;
}
- if (unhandled)
- code_err(ast, "This 'when' statement doesn't handle the tags: %s", CORD_to_const_char_star(unhandled));
- return overall_t;
}
+ // HACK: `while when ...` is handled by the parser adding an implicit
+ // `else: stop`, which has an empty source code span.
+ if (!any_unhandled && when->else_body->end > when->else_body->start)
+ code_err(when->else_body, "This 'else' block will never run because every tag is handled");
+
+ type_t *else_t = get_type(env, when->else_body);
+ type_t *merged = type_or_type(overall_t, else_t);
+ if (!merged)
+ code_err(when->else_body,
+ "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
+ overall_t, else_t);
+ return merged;
} else {
- code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
+ CORD unhandled = CORD_EMPTY;
+ for (match_t *m = matches; m; m = m->next) {
+ if (!m->handled)
+ unhandled = unhandled ? CORD_all(unhandled, ", ", m->tag->name) : m->tag->name;
+ }
+ if (unhandled)
+ code_err(ast, "This 'when' statement doesn't handle the tags: %s", CORD_to_const_char_star(unhandled));
+ return overall_t;
}
}