aboutsummaryrefslogtreecommitdiff
path: root/src/typecheck.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/typecheck.c')
-rw-r--r--src/typecheck.c539
1 files changed, 272 insertions, 267 deletions
diff --git a/src/typecheck.c b/src/typecheck.c
index 8a2ee32b..cd6ff1c2 100644
--- a/src/typecheck.c
+++ b/src/typecheck.c
@@ -135,27 +135,27 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast)
errx(1, "Unreachable");
}
-static PUREFUNC bool risks_zero_or_inf(ast_t *ast)
-{
- switch (ast->tag) {
- case Int: {
- const char *str = Match(ast, Int)->str;
- OptionalInt_t int_val = Int$from_str(str);
- return (int_val.small == 0x1); // zero
- }
- case Num: {
- return Match(ast, Num)->n == 0.0;
- }
- case BinaryOp: {
- auto binop = Match(ast, BinaryOp);
- if (binop->op == BINOP_MULT || binop->op == BINOP_DIVIDE || binop->op == BINOP_MIN || binop->op == BINOP_MAX)
- return risks_zero_or_inf(binop->lhs) || risks_zero_or_inf(binop->rhs);
- else
- return true;
- }
- default: return true;
- }
-}
+// static PUREFUNC bool risks_zero_or_inf(ast_t *ast)
+// {
+// switch (ast->tag) {
+// case Int: {
+// const char *str = Match(ast, Int)->str;
+// OptionalInt_t int_val = Int$from_str(str);
+// return (int_val.small == 0x1); // zero
+// }
+// case Num: {
+// return Match(ast, Num)->n == 0.0;
+// }
+// case BINOP_CASES: {
+// binary_operands_t binop = BINARY_OPERANDS(ast);
+// if (ast->tag == Multiply || ast->tag == Divide || ast->tag == Min || ast->tag == Max)
+// return risks_zero_or_inf(binop.lhs) || risks_zero_or_inf(binop.rhs);
+// else
+// return true;
+// }
+// default: return true;
+// }
+// }
PUREFUNC type_t *get_math_type(env_t *env, ast_t *ast, type_t *lhs_t, type_t *rhs_t)
{
@@ -312,7 +312,7 @@ void bind_statement(env_t *env, ast_t *statement)
if (get_binding(env, name))
code_err(decl->var, "A ", type_to_str(get_binding(env, name)->type), " called ", quoted(name), " has already been defined");
bind_statement(env, decl->value);
- type_t *type = get_type(env, decl->value);
+ type_t *type = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value);
if (!type)
code_err(decl->value, "I couldn't figure out the type of this value");
if (type->tag == FunctionType)
@@ -617,12 +617,7 @@ type_t *get_type(env_t *env, ast_t *ast)
#endif
switch (ast->tag) {
case None: {
- if (!Match(ast, None)->type)
- return Type(OptionalType, .type=NULL);
- type_t *t = parse_type_ast(env, Match(ast, None)->type);
- if (t->tag == OptionalType)
- code_err(ast, "Nested optional types are not supported. This should be: `none:", type_to_str(Match(t, OptionalType)->type), "`");
- return Type(OptionalType, .type=t);
+ return Type(OptionalType, .type=NULL);
}
case Bool: {
return Type(BoolType);
@@ -714,115 +709,91 @@ type_t *get_type(env_t *env, ast_t *ast)
case Array: {
auto array = Match(ast, Array);
type_t *item_type = NULL;
- if (array->item_type) {
- item_type = parse_type_ast(env, array->item_type);
- } else if (array->items) {
- for (ast_list_t *item = array->items; item; item = item->next) {
- ast_t *item_ast = item->ast;
- env_t *scope = env;
- while (item_ast->tag == Comprehension) {
- auto comp = Match(item_ast, Comprehension);
- scope = for_scope(
- scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
- item_ast = comp->expr;
- }
- type_t *t2 = get_type(scope, item_ast);
- type_t *merged = item_type ? type_or_type(item_type, t2) : t2;
- if (!merged)
- code_err(item->ast,
- "This array item has type ", type_to_str(t2),
- ", which is different from earlier array items which have type ", type_to_str(item_type));
- item_type = merged;
+ for (ast_list_t *item = array->items; item; item = item->next) {
+ ast_t *item_ast = item->ast;
+ env_t *scope = env;
+ while (item_ast->tag == Comprehension) {
+ auto comp = Match(item_ast, Comprehension);
+ scope = for_scope(
+ scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
+ item_ast = comp->expr;
}
- } else {
- code_err(ast, "I can't figure out what type this array has because it has no members or explicit type");
+ type_t *t2 = get_type(scope, item_ast);
+ type_t *merged = item_type ? type_or_type(item_type, t2) : t2;
+ if (!merged)
+ code_err(item->ast,
+ "This array item has type ", type_to_str(t2),
+ ", which is different from earlier array items which have type ", type_to_str(item_type));
+ item_type = merged;
}
- if (has_stack_memory(item_type))
- code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in.");
- if (!item_type)
- code_err(ast, "I couldn't figure out the item type for this array!");
+ if (item_type && has_stack_memory(item_type))
+ code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in.");
return Type(ArrayType, .item_type=item_type);
}
case Set: {
auto set = Match(ast, Set);
type_t *item_type = NULL;
- if (set->item_type) {
- item_type = parse_type_ast(env, set->item_type);
- } else {
- for (ast_list_t *item = set->items; item; item = item->next) {
- ast_t *item_ast = item->ast;
- env_t *scope = env;
- while (item_ast->tag == Comprehension) {
- auto comp = Match(item_ast, Comprehension);
- scope = for_scope(
- scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
- item_ast = comp->expr;
- }
-
- type_t *this_item_type = get_type(scope, item_ast);
- type_t *item_merged = type_or_type(item_type, this_item_type);
- if (!item_merged)
- code_err(item_ast,
- "This set item has type ", type_to_str(this_item_type),
- ", which is different from earlier set items which have type ", type_to_str(item_type));
- item_type = item_merged;
+ for (ast_list_t *item = set->items; item; item = item->next) {
+ ast_t *item_ast = item->ast;
+ env_t *scope = env;
+ while (item_ast->tag == Comprehension) {
+ auto comp = Match(item_ast, Comprehension);
+ scope = for_scope(
+ scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
+ item_ast = comp->expr;
}
- }
- if (!item_type)
- code_err(ast, "I couldn't figure out the item type for this set!");
+ type_t *this_item_type = get_type(scope, item_ast);
+ type_t *item_merged = type_or_type(item_type, this_item_type);
+ if (!item_merged)
+ code_err(item_ast,
+ "This set item has type ", type_to_str(this_item_type),
+ ", which is different from earlier set items which have type ", type_to_str(item_type));
+ item_type = item_merged;
+ }
- if (has_stack_memory(item_type))
+ if (item_type && has_stack_memory(item_type))
code_err(ast, "Sets cannot hold stack references because the set may outlive the reference's stack frame.");
+
return Type(SetType, .item_type=item_type);
}
case Table: {
auto table = Match(ast, Table);
type_t *key_type = NULL, *value_type = NULL;
- if (table->key_type && table->value_type) {
- key_type = parse_type_ast(env, table->key_type);
- value_type = parse_type_ast(env, table->value_type);
- } else if (table->key_type && table->default_value) {
- key_type = parse_type_ast(env, table->key_type);
- value_type = get_type(env, table->default_value);
- } else {
- for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
- ast_t *entry_ast = entry->ast;
- env_t *scope = env;
- while (entry_ast->tag == Comprehension) {
- auto comp = Match(entry_ast, Comprehension);
- scope = for_scope(
- scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
- entry_ast = comp->expr;
- }
-
- auto e = Match(entry_ast, TableEntry);
- type_t *key_t = get_type(scope, e->key);
- type_t *value_t = get_type(scope, e->value);
-
- type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t;
- if (!key_merged)
- code_err(entry->ast,
- "This table entry has type ", type_to_str(key_t),
- ", which is different from earlier table entries which have type ", type_to_str(key_type));
- key_type = key_merged;
-
- type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t;
- if (!val_merged)
- code_err(entry->ast,
- "This table entry has type ", type_to_str(value_t),
- ", which is different from earlier table entries which have type ", type_to_str(value_type));
- value_type = val_merged;
+ for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
+ ast_t *entry_ast = entry->ast;
+ env_t *scope = env;
+ while (entry_ast->tag == Comprehension) {
+ auto comp = Match(entry_ast, Comprehension);
+ scope = for_scope(
+ scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
+ entry_ast = comp->expr;
}
- }
- if (!key_type || !value_type)
- code_err(ast, "I couldn't figure out the key and value types for this table!");
+ auto e = Match(entry_ast, TableEntry);
+ type_t *key_t = get_type(scope, e->key);
+ type_t *value_t = get_type(scope, e->value);
+
+ type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t;
+ if (!key_merged)
+ code_err(entry->ast,
+ "This table entry has type ", type_to_str(key_t),
+ ", which is different from earlier table entries which have type ", type_to_str(key_type));
+ key_type = key_merged;
- if (has_stack_memory(key_type) || has_stack_memory(value_type))
+ type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t;
+ if (!val_merged)
+ code_err(entry->ast,
+ "This table entry has type ", type_to_str(value_t),
+ ", which is different from earlier table entries which have type ", type_to_str(value_type));
+ value_type = val_merged;
+ }
+
+ if ((key_type && has_stack_memory(key_type)) || (value_type && has_stack_memory(value_type)))
code_err(ast, "Tables cannot hold stack references because the table may outlive the reference's stack frame.");
+
return Type(TableType, .key_type=key_type, .value_type=value_type, .default_value=table->default_value, .env=env);
}
case TableEntry: {
@@ -998,7 +969,7 @@ type_t *get_type(env_t *env, ast_t *ast)
// Early out if the type is knowable without any context from the block:
switch (last->ast->tag) {
- case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Extend:
+ case UPDATE_CASES: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Extend:
return Type(VoidType);
default: break;
}
@@ -1022,7 +993,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case Extern: {
return parse_type_ast(env, Match(ast, Extern)->type);
}
- case Declare: case Assign: case DocTest: {
+ case Declare: case Assign: case UPDATE_CASES: case DocTest: {
return Type(VoidType);
}
case Use: {
@@ -1078,169 +1049,160 @@ type_t *get_type(env_t *env, ast_t *ast)
}
code_err(ast, "I only know how to get 'not' of boolean, numeric, and optional pointer types, not ", type_to_str(t));
}
- case BinaryOp: {
- auto binop = Match(ast, BinaryOp);
- type_t *lhs_t = get_type(env, binop->lhs),
- *rhs_t = get_type(env, binop->rhs);
-
- if (lhs_t->tag == BigIntType && rhs_t->tag != BigIntType && is_numeric_type(rhs_t) && binop->lhs->tag == Int) {
- lhs_t = rhs_t;
- } else if (rhs_t->tag == BigIntType && lhs_t->tag != BigIntType && is_numeric_type(lhs_t) && binop->rhs->tag == Int) {
+ case Or: {
+ binary_operands_t binop = BINARY_OPERANDS(ast);
+ type_t *lhs_t = get_type(env, binop.lhs);
+ type_t *rhs_t = get_type(env, binop.rhs);
- rhs_t = lhs_t;
+ // `opt? or (x == y)` / `(x == y) or opt?` is a boolean conditional:
+ if ((lhs_t->tag == OptionalType && rhs_t->tag == BoolType)
+ || (lhs_t->tag == BoolType && rhs_t->tag == OptionalType)) {
+ return Type(BoolType);
}
-#define binding_works(name, self, lhs_t, rhs_t, ret_t) \
- ({ binding_t *b = get_namespace_binding(env, self, name); \
- (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \
- (type_eq(fn->ret, ret_t) \
- && (fn->args && type_eq(fn->args->type, lhs_t)) \
- && (fn->args->next && can_promote(rhs_t, fn->args->next->type))); })); })
- // Check for a binop method like plus() etc:
- switch (binop->op) {
- case BINOP_MULT: {
- if (is_numeric_type(lhs_t) && binding_works("scaled_by", binop->rhs, rhs_t, lhs_t, rhs_t))
- return rhs_t;
- else if (is_numeric_type(rhs_t) && binding_works("scaled_by", binop->lhs, lhs_t, rhs_t, lhs_t))
- return lhs_t;
- else if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t))
+ if (lhs_t->tag == OptionalType) {
+ if (rhs_t->tag == OptionalType) {
+ type_t *result = most_complete_type(lhs_t, rhs_t);
+ if (result == NULL)
+ code_err(ast, "I could not determine the type of ", type_to_str(lhs_t), " `or` ", type_to_str(rhs_t));
+ return result;
+ } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) {
+ return Match(lhs_t, OptionalType)->type;
+ }
+ type_t *non_opt = Match(lhs_t, OptionalType)->type;
+ non_opt = most_complete_type(non_opt, rhs_t);
+ if (non_opt != NULL)
+ return non_opt;
+ } else if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType)
+ && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType)
+ && lhs_t->tag != NumType && rhs_t->tag != NumType) {
+ if (can_promote(rhs_t, lhs_t))
return lhs_t;
- break;
+ else if (can_promote(lhs_t, rhs_t))
+ return rhs_t;
+ } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) {
+ return lhs_t;
}
- case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: case BINOP_CONCAT: {
- if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t))
- return lhs_t;
- break;
+ code_err(ast, "I couldn't figure out how to do `or` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
+ }
+ case And: {
+ binary_operands_t binop = BINARY_OPERANDS(ast);
+ type_t *lhs_t = get_type(env, binop.lhs);
+ type_t *rhs_t = get_type(env, binop.rhs);
+
+ // `and` between optionals/bools is a boolean expression like `if opt? and opt?:` or `if x > 0 and opt?:`
+ if ((lhs_t->tag == OptionalType || lhs_t->tag == BoolType)
+ && (rhs_t->tag == OptionalType || rhs_t->tag == BoolType)) {
+ return Type(BoolType);
}
- case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: {
- if (is_numeric_type(rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t))
+
+ // Bitwise AND:
+ if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType)
+ && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType)
+ && lhs_t->tag != NumType && rhs_t->tag != NumType) {
+ if (can_promote(rhs_t, lhs_t))
return lhs_t;
- break;
- }
- case BINOP_LSHIFT: case BINOP_RSHIFT: case BINOP_ULSHIFT: case BINOP_URSHIFT: {
+ else if (can_promote(lhs_t, rhs_t))
+ return rhs_t;
+ } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) {
return lhs_t;
}
- case BINOP_POWER: {
- if (is_numeric_type(rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t))
- return lhs_t;
- break;
- }
- default: break;
- }
-#undef binding_works
+ code_err(ast, "I couldn't figure out how to do `and` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
+ }
+ case Xor: {
+ binary_operands_t binop = BINARY_OPERANDS(ast);
+ type_t *lhs_t = get_type(env, binop.lhs);
+ type_t *rhs_t = get_type(env, binop.rhs);
- switch (binop->op) {
- case BINOP_AND: {
- if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
- return lhs_t;
- } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) ||
- (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) {
- return Type(BoolType);
- } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) {
- return lhs_t;
- } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) {
- return lhs_t;
- } else if (rhs_t->tag == OptionalType) {
- if (can_promote(lhs_t, rhs_t))
- return rhs_t;
- } else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) {
- auto lhs_ptr = Match(lhs_t, PointerType);
- auto rhs_ptr = Match(rhs_t, PointerType);
- if (type_eq(lhs_ptr->pointed, rhs_ptr->pointed))
- return Type(PointerType, .pointed=lhs_ptr->pointed);
- } else if ((is_int_type(lhs_t) && is_int_type(rhs_t))
- || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) {
- return get_math_type(env, ast, lhs_t, rhs_t);
- }
- code_err(ast, "I can't figure out the type of this `and` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t));
+ // `xor` between optionals/bools is a boolean expression like `if opt? xor opt?:` or `if x > 0 xor opt?:`
+ if ((lhs_t->tag == OptionalType || lhs_t->tag == BoolType)
+ && (rhs_t->tag == OptionalType || rhs_t->tag == BoolType)) {
+ return Type(BoolType);
}
- case BINOP_OR: {
- if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
- return lhs_t;
- } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) ||
- (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) {
- return Type(BoolType);
- } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) {
+
+ // Bitwise XOR:
+ if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType)
+ && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType)
+ && lhs_t->tag != NumType && rhs_t->tag != NumType) {
+ if (can_promote(rhs_t, lhs_t))
return lhs_t;
- } else if ((is_int_type(lhs_t) && is_int_type(rhs_t))
- || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) {
- return get_math_type(env, ast, lhs_t, rhs_t);
- } else if (lhs_t->tag == OptionalType) {
- if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)
- return Match(lhs_t, OptionalType)->type;
- if (can_promote(rhs_t, lhs_t))
- return rhs_t;
- } else if (lhs_t->tag == PointerType) {
- auto lhs_ptr = Match(lhs_t, PointerType);
- if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) {
- return Type(PointerType, .pointed=lhs_ptr->pointed);
- } else if (rhs_t->tag == PointerType) {
- auto rhs_ptr = Match(rhs_t, PointerType);
- if (type_eq(rhs_ptr->pointed, lhs_ptr->pointed))
- return Type(PointerType, .pointed=lhs_ptr->pointed);
- }
- } else if (rhs_t->tag == OptionalType) {
- return type_or_type(lhs_t, rhs_t);
- }
- code_err(ast, "I can't figure out the type of this `or` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t));
+ else if (can_promote(lhs_t, rhs_t))
+ return rhs_t;
+ } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) {
+ return lhs_t;
}
- case BINOP_XOR: {
- if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
- return lhs_t;
- } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) ||
- (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) {
- return Type(BoolType);
- } else if ((is_int_type(lhs_t) && is_int_type(rhs_t))
- || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) {
- return get_math_type(env, ast, lhs_t, rhs_t);
- }
+ code_err(ast, "I couldn't figure out how to do `xor` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
+ }
+ case Compare: {
+ binary_operands_t binop = BINARY_OPERANDS(ast);
+ type_t *lhs_t = get_type(env, binop.lhs);
+ type_t *rhs_t = get_type(env, binop.rhs);
- code_err(ast, "I can't figure out the type of this `xor` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t));
- }
- case BINOP_CONCAT: {
- if (!type_eq(lhs_t, rhs_t))
- code_err(ast, "The type on the left side of this concatenation doesn't match the right side: ", type_to_str(lhs_t),
- " vs. ", type_to_str(rhs_t));
- if (lhs_t->tag == ArrayType || lhs_t->tag == TextType || lhs_t->tag == SetType)
- return lhs_t;
+ if (can_promote(rhs_t, lhs_t) || can_promote(lhs_t, rhs_t))
+ return Type(IntType, .bits=TYPE_IBITS32);
- code_err(ast, "Only array/set/text value types support concatenation, not ", type_to_str(lhs_t));
- }
- case BINOP_EQ: case BINOP_NE: case BINOP_LT: case BINOP_LE: case BINOP_GT: case BINOP_GE: {
- if (!can_promote(lhs_t, rhs_t) && !can_promote(rhs_t, lhs_t))
- code_err(ast, "I can't compare these two different types: ", type_to_str(lhs_t), " vs ", type_to_str(rhs_t));
+ code_err(ast, "I don't know how to compare ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
+ }
+ case Equals: case NotEquals: case LessThan: case LessThanOrEquals: case GreaterThan: case GreaterThanOrEquals: {
+ binary_operands_t binop = BINARY_OPERANDS(ast);
+ type_t *lhs_t = get_type(env, binop.lhs);
+ type_t *rhs_t = get_type(env, binop.rhs);
+ if (can_promote(rhs_t, lhs_t) || can_promote(lhs_t, rhs_t))
return Type(BoolType);
+
+ code_err(ast, "I don't know how to compare ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
+ }
+ case Power: case Multiply: case Divide: case Mod: case Mod1: case Plus: case Minus: case LeftShift:
+ case UnsignedLeftShift: case RightShift: case UnsignedRightShift: {
+ binary_operands_t binop = BINARY_OPERANDS(ast);
+ type_t *lhs_t = get_type(env, binop.lhs);
+ type_t *rhs_t = get_type(env, binop.rhs);
+
+ if (ast->tag == LeftShift || ast->tag == UnsignedLeftShift || ast->tag == RightShift || ast->tag == UnsignedRightShift) {
+ if (!is_int_type(rhs_t))
+ code_err(binop.rhs, "I only know how to do bit shifting by integer amounts, not ", type_to_str(rhs_t));
}
- case BINOP_CMP:
- return Type(IntType, .bits=TYPE_IBITS32);
- case BINOP_POWER: {
- type_t *result = get_math_type(env, ast, lhs_t, rhs_t);
- if (result->tag == NumType)
- return result;
- return Type(NumType, .bits=TYPE_NBITS64);
- }
- case BINOP_MULT: case BINOP_DIVIDE: {
- type_t *math_type = get_math_type(env, ast, value_type(lhs_t), value_type(rhs_t));
- if (value_type(lhs_t)->tag == NumType || value_type(rhs_t)->tag == NumType) {
- if (risks_zero_or_inf(binop->lhs) && risks_zero_or_inf(binop->rhs))
- return Type(OptionalType, math_type);
- else
- return math_type;
- }
- return math_type;
- }
- default: {
- return get_math_type(env, ast, lhs_t, rhs_t);
- }
+
+ type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL));
+ if (ast->tag == Multiply || ast->tag == Divide) {
+ binding_t *b = is_numeric_type(lhs_t) ? get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, lhs_t)
+ : get_metamethod_binding(env, ast->tag, binop.rhs, binop.lhs, rhs_t);
+ if (b) return overall_t;
+ } else {
+ if (overall_t == NULL)
+ code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
+
+ binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t);
+ if (b) return overall_t;
}
+ if (is_numeric_type(lhs_t) && is_numeric_type(rhs_t))
+ return overall_t;
+ code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
+ }
+ case Concat: {
+ binary_operands_t binop = BINARY_OPERANDS(ast);
+ type_t *lhs_t = get_type(env, binop.lhs);
+ type_t *rhs_t = get_type(env, binop.rhs);
+
+ type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL));
+ if (overall_t == NULL)
+ code_err(ast, "I don't know how to do operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
+
+ binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t);
+ if (b) return overall_t;
+
+ if (overall_t->tag == ArrayType || overall_t->tag == SetType || overall_t->tag == TextType)
+ return overall_t;
+
+ code_err(ast, "I don't know how to do concatenation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
}
case Reduction: {
auto reduction = Match(ast, Reduction);
type_t *iter_t = get_type(env, reduction->iter);
- if (reduction->op == BINOP_EQ || reduction->op == BINOP_NE || reduction->op == BINOP_LT
- || reduction->op == BINOP_LE || reduction->op == BINOP_GT || reduction->op == BINOP_GE)
+ if (reduction->op == Equals || reduction->op == NotEquals || reduction->op == LessThan
+ || reduction->op == LessThanOrEquals || reduction->op == GreaterThan || reduction->op == GreaterThanOrEquals)
return Type(OptionalType, .type=Type(BoolType));
type_t *iterated = get_iterated_type(iter_t);
@@ -1249,9 +1211,6 @@ type_t *get_type(env_t *env, ast_t *ast)
return iterated->tag == OptionalType ? iterated : Type(OptionalType, .type=iterated);
}
- case UpdateAssign:
- return Type(VoidType);
-
case Min: case Max: {
// Unsafe! These types *should* have the same fields and this saves a lot of duplicate code:
ast_t *lhs = ast->__data.Min.lhs, *rhs = ast->__data.Min.rhs;
@@ -1310,8 +1269,9 @@ type_t *get_type(env_t *env, ast_t *ast)
env_t *truthy_scope = env;
env_t *falsey_scope = env;
if (if_->condition->tag == Declare) {
- type_t *condition_type = get_type(env, Match(if_->condition, Declare)->value);
- const char *varname = Match(Match(if_->condition, Declare)->var, Var)->name;
+ auto decl = Match(if_->condition, Declare);
+ type_t *condition_type = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value);
+ const char *varname = Match(decl->var, Var)->name;
if (streq(varname, "_"))
code_err(if_->condition, "To use `if var := ...:`, you must choose a real variable name, not `_`");
@@ -1456,7 +1416,7 @@ type_t *get_type(env_t *env, ast_t *ast)
PUREFUNC bool is_discardable(env_t *env, ast_t *ast)
{
switch (ast->tag) {
- case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef:
+ case UPDATE_CASES: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef:
case LangDef: case Use: case Extend:
return true;
default: break;
@@ -1610,13 +1570,13 @@ PUREFUNC bool is_constant(env_t *env, ast_t *ast)
}
case Not: return is_constant(env, Match(ast, Not)->value);
case Negative: return is_constant(env, Match(ast, Negative)->value);
- case BinaryOp: {
- auto binop = Match(ast, BinaryOp);
- switch (binop->op) {
- case BINOP_UNKNOWN: case BINOP_POWER: case BINOP_CONCAT: case BINOP_MIN: case BINOP_MAX: case BINOP_CMP:
+ case BINOP_CASES: {
+ binary_operands_t binop = BINARY_OPERANDS(ast);
+ switch (ast->tag) {
+ case Power: case Concat: case Min: case Max: case Compare:
return false;
default:
- return is_constant(env, binop->lhs) && is_constant(env, binop->rhs);
+ return is_constant(env, binop.lhs) && is_constant(env, binop.rhs);
}
}
case Use: return true;
@@ -1626,4 +1586,49 @@ PUREFUNC bool is_constant(env_t *env, ast_t *ast)
}
}
+PUREFUNC bool can_compile_to_type(env_t *env, ast_t *ast, type_t *needed)
+{
+ if (needed->tag == OptionalType && ast->tag == None) {
+ return true;
+ }
+
+ needed = non_optional(needed);
+ if (needed->tag == ArrayType && ast->tag == Array) {
+ type_t *item_type = Match(needed, ArrayType)->item_type;
+ for (ast_list_t *item = Match(ast, Array)->items; item; item = item->next) {
+ if (!can_compile_to_type(env, item->ast, item_type))
+ return false;
+ }
+ return true;
+ } else if (needed->tag == SetType && ast->tag == Set) {
+ type_t *item_type = Match(needed, SetType)->item_type;
+ for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next) {
+ if (!can_compile_to_type(env, item->ast, item_type))
+ return false;
+ }
+ return true;
+ } else if (needed->tag == TableType && ast->tag == Table) {
+ type_t *key_type = Match(needed, TableType)->key_type;
+ type_t *value_type = Match(needed, TableType)->value_type;
+ for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next) {
+ if (entry->ast->tag != TableEntry)
+ continue; // TODO: fix this
+ auto e = Match(entry->ast, TableEntry);
+ if (!can_compile_to_type(env, e->key, key_type) || !can_compile_to_type(env, e->value, value_type))
+ return false;
+ }
+ return true;
+ } else if (needed->tag == PointerType) {
+ auto ptr = Match(needed, PointerType);
+ if (ast->tag == HeapAllocate)
+ return !ptr->is_stack && can_compile_to_type(env, Match(ast, HeapAllocate)->value, ptr->pointed);
+ else if (ast->tag == StackReference)
+ return ptr->is_stack && can_compile_to_type(env, Match(ast, StackReference)->value, ptr->pointed);
+ else
+ return can_promote(needed, get_type(env, ast));
+ } else {
+ return can_promote(needed, get_type(env, ast));
+ }
+}
+
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0