diff options
Diffstat (limited to 'src/typecheck.c')
| -rw-r--r-- | src/typecheck.c | 539 |
1 files changed, 272 insertions, 267 deletions
diff --git a/src/typecheck.c b/src/typecheck.c index 8a2ee32b..cd6ff1c2 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -135,27 +135,27 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast) errx(1, "Unreachable"); } -static PUREFUNC bool risks_zero_or_inf(ast_t *ast) -{ - switch (ast->tag) { - case Int: { - const char *str = Match(ast, Int)->str; - OptionalInt_t int_val = Int$from_str(str); - return (int_val.small == 0x1); // zero - } - case Num: { - return Match(ast, Num)->n == 0.0; - } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - if (binop->op == BINOP_MULT || binop->op == BINOP_DIVIDE || binop->op == BINOP_MIN || binop->op == BINOP_MAX) - return risks_zero_or_inf(binop->lhs) || risks_zero_or_inf(binop->rhs); - else - return true; - } - default: return true; - } -} +// static PUREFUNC bool risks_zero_or_inf(ast_t *ast) +// { +// switch (ast->tag) { +// case Int: { +// const char *str = Match(ast, Int)->str; +// OptionalInt_t int_val = Int$from_str(str); +// return (int_val.small == 0x1); // zero +// } +// case Num: { +// return Match(ast, Num)->n == 0.0; +// } +// case BINOP_CASES: { +// binary_operands_t binop = BINARY_OPERANDS(ast); +// if (ast->tag == Multiply || ast->tag == Divide || ast->tag == Min || ast->tag == Max) +// return risks_zero_or_inf(binop.lhs) || risks_zero_or_inf(binop.rhs); +// else +// return true; +// } +// default: return true; +// } +// } PUREFUNC type_t *get_math_type(env_t *env, ast_t *ast, type_t *lhs_t, type_t *rhs_t) { @@ -312,7 +312,7 @@ void bind_statement(env_t *env, ast_t *statement) if (get_binding(env, name)) code_err(decl->var, "A ", type_to_str(get_binding(env, name)->type), " called ", quoted(name), " has already been defined"); bind_statement(env, decl->value); - type_t *type = get_type(env, decl->value); + type_t *type = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); if (!type) code_err(decl->value, "I couldn't figure out the type of this value"); if (type->tag == FunctionType) @@ -617,12 +617,7 @@ type_t *get_type(env_t *env, ast_t *ast) #endif switch (ast->tag) { case None: { - if (!Match(ast, None)->type) - return Type(OptionalType, .type=NULL); - type_t *t = parse_type_ast(env, Match(ast, None)->type); - if (t->tag == OptionalType) - code_err(ast, "Nested optional types are not supported. This should be: `none:", type_to_str(Match(t, OptionalType)->type), "`"); - return Type(OptionalType, .type=t); + return Type(OptionalType, .type=NULL); } case Bool: { return Type(BoolType); @@ -714,115 +709,91 @@ type_t *get_type(env_t *env, ast_t *ast) case Array: { auto array = Match(ast, Array); type_t *item_type = NULL; - if (array->item_type) { - item_type = parse_type_ast(env, array->item_type); - } else if (array->items) { - for (ast_list_t *item = array->items; item; item = item->next) { - ast_t *item_ast = item->ast; - env_t *scope = env; - while (item_ast->tag == Comprehension) { - auto comp = Match(item_ast, Comprehension); - scope = for_scope( - scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); - item_ast = comp->expr; - } - type_t *t2 = get_type(scope, item_ast); - type_t *merged = item_type ? type_or_type(item_type, t2) : t2; - if (!merged) - code_err(item->ast, - "This array item has type ", type_to_str(t2), - ", which is different from earlier array items which have type ", type_to_str(item_type)); - item_type = merged; + for (ast_list_t *item = array->items; item; item = item->next) { + ast_t *item_ast = item->ast; + env_t *scope = env; + while (item_ast->tag == Comprehension) { + auto comp = Match(item_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + item_ast = comp->expr; } - } else { - code_err(ast, "I can't figure out what type this array has because it has no members or explicit type"); + type_t *t2 = get_type(scope, item_ast); + type_t *merged = item_type ? type_or_type(item_type, t2) : t2; + if (!merged) + code_err(item->ast, + "This array item has type ", type_to_str(t2), + ", which is different from earlier array items which have type ", type_to_str(item_type)); + item_type = merged; } - if (has_stack_memory(item_type)) - code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in."); - if (!item_type) - code_err(ast, "I couldn't figure out the item type for this array!"); + if (item_type && has_stack_memory(item_type)) + code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in."); return Type(ArrayType, .item_type=item_type); } case Set: { auto set = Match(ast, Set); type_t *item_type = NULL; - if (set->item_type) { - item_type = parse_type_ast(env, set->item_type); - } else { - for (ast_list_t *item = set->items; item; item = item->next) { - ast_t *item_ast = item->ast; - env_t *scope = env; - while (item_ast->tag == Comprehension) { - auto comp = Match(item_ast, Comprehension); - scope = for_scope( - scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); - item_ast = comp->expr; - } - - type_t *this_item_type = get_type(scope, item_ast); - type_t *item_merged = type_or_type(item_type, this_item_type); - if (!item_merged) - code_err(item_ast, - "This set item has type ", type_to_str(this_item_type), - ", which is different from earlier set items which have type ", type_to_str(item_type)); - item_type = item_merged; + for (ast_list_t *item = set->items; item; item = item->next) { + ast_t *item_ast = item->ast; + env_t *scope = env; + while (item_ast->tag == Comprehension) { + auto comp = Match(item_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + item_ast = comp->expr; } - } - if (!item_type) - code_err(ast, "I couldn't figure out the item type for this set!"); + type_t *this_item_type = get_type(scope, item_ast); + type_t *item_merged = type_or_type(item_type, this_item_type); + if (!item_merged) + code_err(item_ast, + "This set item has type ", type_to_str(this_item_type), + ", which is different from earlier set items which have type ", type_to_str(item_type)); + item_type = item_merged; + } - if (has_stack_memory(item_type)) + if (item_type && has_stack_memory(item_type)) code_err(ast, "Sets cannot hold stack references because the set may outlive the reference's stack frame."); + return Type(SetType, .item_type=item_type); } case Table: { auto table = Match(ast, Table); type_t *key_type = NULL, *value_type = NULL; - if (table->key_type && table->value_type) { - key_type = parse_type_ast(env, table->key_type); - value_type = parse_type_ast(env, table->value_type); - } else if (table->key_type && table->default_value) { - key_type = parse_type_ast(env, table->key_type); - value_type = get_type(env, table->default_value); - } else { - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - ast_t *entry_ast = entry->ast; - env_t *scope = env; - while (entry_ast->tag == Comprehension) { - auto comp = Match(entry_ast, Comprehension); - scope = for_scope( - scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); - entry_ast = comp->expr; - } - - auto e = Match(entry_ast, TableEntry); - type_t *key_t = get_type(scope, e->key); - type_t *value_t = get_type(scope, e->value); - - type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; - if (!key_merged) - code_err(entry->ast, - "This table entry has type ", type_to_str(key_t), - ", which is different from earlier table entries which have type ", type_to_str(key_type)); - key_type = key_merged; - - type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t; - if (!val_merged) - code_err(entry->ast, - "This table entry has type ", type_to_str(value_t), - ", which is different from earlier table entries which have type ", type_to_str(value_type)); - value_type = val_merged; + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + ast_t *entry_ast = entry->ast; + env_t *scope = env; + while (entry_ast->tag == Comprehension) { + auto comp = Match(entry_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + entry_ast = comp->expr; } - } - if (!key_type || !value_type) - code_err(ast, "I couldn't figure out the key and value types for this table!"); + auto e = Match(entry_ast, TableEntry); + type_t *key_t = get_type(scope, e->key); + type_t *value_t = get_type(scope, e->value); + + type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; + if (!key_merged) + code_err(entry->ast, + "This table entry has type ", type_to_str(key_t), + ", which is different from earlier table entries which have type ", type_to_str(key_type)); + key_type = key_merged; - if (has_stack_memory(key_type) || has_stack_memory(value_type)) + type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t; + if (!val_merged) + code_err(entry->ast, + "This table entry has type ", type_to_str(value_t), + ", which is different from earlier table entries which have type ", type_to_str(value_type)); + value_type = val_merged; + } + + if ((key_type && has_stack_memory(key_type)) || (value_type && has_stack_memory(value_type))) code_err(ast, "Tables cannot hold stack references because the table may outlive the reference's stack frame."); + return Type(TableType, .key_type=key_type, .value_type=value_type, .default_value=table->default_value, .env=env); } case TableEntry: { @@ -998,7 +969,7 @@ type_t *get_type(env_t *env, ast_t *ast) // Early out if the type is knowable without any context from the block: switch (last->ast->tag) { - case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Extend: + case UPDATE_CASES: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Extend: return Type(VoidType); default: break; } @@ -1022,7 +993,7 @@ type_t *get_type(env_t *env, ast_t *ast) case Extern: { return parse_type_ast(env, Match(ast, Extern)->type); } - case Declare: case Assign: case DocTest: { + case Declare: case Assign: case UPDATE_CASES: case DocTest: { return Type(VoidType); } case Use: { @@ -1078,169 +1049,160 @@ type_t *get_type(env_t *env, ast_t *ast) } code_err(ast, "I only know how to get 'not' of boolean, numeric, and optional pointer types, not ", type_to_str(t)); } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - type_t *lhs_t = get_type(env, binop->lhs), - *rhs_t = get_type(env, binop->rhs); - - if (lhs_t->tag == BigIntType && rhs_t->tag != BigIntType && is_numeric_type(rhs_t) && binop->lhs->tag == Int) { - lhs_t = rhs_t; - } else if (rhs_t->tag == BigIntType && lhs_t->tag != BigIntType && is_numeric_type(lhs_t) && binop->rhs->tag == Int) { + case Or: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); - rhs_t = lhs_t; + // `opt? or (x == y)` / `(x == y) or opt?` is a boolean conditional: + if ((lhs_t->tag == OptionalType && rhs_t->tag == BoolType) + || (lhs_t->tag == BoolType && rhs_t->tag == OptionalType)) { + return Type(BoolType); } -#define binding_works(name, self, lhs_t, rhs_t, ret_t) \ - ({ binding_t *b = get_namespace_binding(env, self, name); \ - (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \ - (type_eq(fn->ret, ret_t) \ - && (fn->args && type_eq(fn->args->type, lhs_t)) \ - && (fn->args->next && can_promote(rhs_t, fn->args->next->type))); })); }) - // Check for a binop method like plus() etc: - switch (binop->op) { - case BINOP_MULT: { - if (is_numeric_type(lhs_t) && binding_works("scaled_by", binop->rhs, rhs_t, lhs_t, rhs_t)) - return rhs_t; - else if (is_numeric_type(rhs_t) && binding_works("scaled_by", binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - else if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + if (lhs_t->tag == OptionalType) { + if (rhs_t->tag == OptionalType) { + type_t *result = most_complete_type(lhs_t, rhs_t); + if (result == NULL) + code_err(ast, "I could not determine the type of ", type_to_str(lhs_t), " `or` ", type_to_str(rhs_t)); + return result; + } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { + return Match(lhs_t, OptionalType)->type; + } + type_t *non_opt = Match(lhs_t, OptionalType)->type; + non_opt = most_complete_type(non_opt, rhs_t); + if (non_opt != NULL) + return non_opt; + } else if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) + && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) + && lhs_t->tag != NumType && rhs_t->tag != NumType) { + if (can_promote(rhs_t, lhs_t)) return lhs_t; - break; + else if (can_promote(lhs_t, rhs_t)) + return rhs_t; + } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { + return lhs_t; } - case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: case BINOP_CONCAT: { - if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - break; + code_err(ast, "I couldn't figure out how to do `or` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case And: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + // `and` between optionals/bools is a boolean expression like `if opt? and opt?:` or `if x > 0 and opt?:` + if ((lhs_t->tag == OptionalType || lhs_t->tag == BoolType) + && (rhs_t->tag == OptionalType || rhs_t->tag == BoolType)) { + return Type(BoolType); } - case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: { - if (is_numeric_type(rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + + // Bitwise AND: + if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) + && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) + && lhs_t->tag != NumType && rhs_t->tag != NumType) { + if (can_promote(rhs_t, lhs_t)) return lhs_t; - break; - } - case BINOP_LSHIFT: case BINOP_RSHIFT: case BINOP_ULSHIFT: case BINOP_URSHIFT: { + else if (can_promote(lhs_t, rhs_t)) + return rhs_t; + } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { return lhs_t; } - case BINOP_POWER: { - if (is_numeric_type(rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - break; - } - default: break; - } -#undef binding_works + code_err(ast, "I couldn't figure out how to do `and` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Xor: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); - switch (binop->op) { - case BINOP_AND: { - if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { - return lhs_t; - } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) || - (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) { - return Type(BoolType); - } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) { - return lhs_t; - } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return lhs_t; - } else if (rhs_t->tag == OptionalType) { - if (can_promote(lhs_t, rhs_t)) - return rhs_t; - } else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) { - auto lhs_ptr = Match(lhs_t, PointerType); - auto rhs_ptr = Match(rhs_t, PointerType); - if (type_eq(lhs_ptr->pointed, rhs_ptr->pointed)) - return Type(PointerType, .pointed=lhs_ptr->pointed); - } else if ((is_int_type(lhs_t) && is_int_type(rhs_t)) - || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) { - return get_math_type(env, ast, lhs_t, rhs_t); - } - code_err(ast, "I can't figure out the type of this `and` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t)); + // `xor` between optionals/bools is a boolean expression like `if opt? xor opt?:` or `if x > 0 xor opt?:` + if ((lhs_t->tag == OptionalType || lhs_t->tag == BoolType) + && (rhs_t->tag == OptionalType || rhs_t->tag == BoolType)) { + return Type(BoolType); } - case BINOP_OR: { - if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { - return lhs_t; - } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) || - (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) { - return Type(BoolType); - } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) { + + // Bitwise XOR: + if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) + && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) + && lhs_t->tag != NumType && rhs_t->tag != NumType) { + if (can_promote(rhs_t, lhs_t)) return lhs_t; - } else if ((is_int_type(lhs_t) && is_int_type(rhs_t)) - || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) { - return get_math_type(env, ast, lhs_t, rhs_t); - } else if (lhs_t->tag == OptionalType) { - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) - return Match(lhs_t, OptionalType)->type; - if (can_promote(rhs_t, lhs_t)) - return rhs_t; - } else if (lhs_t->tag == PointerType) { - auto lhs_ptr = Match(lhs_t, PointerType); - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return Type(PointerType, .pointed=lhs_ptr->pointed); - } else if (rhs_t->tag == PointerType) { - auto rhs_ptr = Match(rhs_t, PointerType); - if (type_eq(rhs_ptr->pointed, lhs_ptr->pointed)) - return Type(PointerType, .pointed=lhs_ptr->pointed); - } - } else if (rhs_t->tag == OptionalType) { - return type_or_type(lhs_t, rhs_t); - } - code_err(ast, "I can't figure out the type of this `or` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t)); + else if (can_promote(lhs_t, rhs_t)) + return rhs_t; + } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { + return lhs_t; } - case BINOP_XOR: { - if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { - return lhs_t; - } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) || - (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) { - return Type(BoolType); - } else if ((is_int_type(lhs_t) && is_int_type(rhs_t)) - || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) { - return get_math_type(env, ast, lhs_t, rhs_t); - } + code_err(ast, "I couldn't figure out how to do `xor` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Compare: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); - code_err(ast, "I can't figure out the type of this `xor` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t)); - } - case BINOP_CONCAT: { - if (!type_eq(lhs_t, rhs_t)) - code_err(ast, "The type on the left side of this concatenation doesn't match the right side: ", type_to_str(lhs_t), - " vs. ", type_to_str(rhs_t)); - if (lhs_t->tag == ArrayType || lhs_t->tag == TextType || lhs_t->tag == SetType) - return lhs_t; + if (can_promote(rhs_t, lhs_t) || can_promote(lhs_t, rhs_t)) + return Type(IntType, .bits=TYPE_IBITS32); - code_err(ast, "Only array/set/text value types support concatenation, not ", type_to_str(lhs_t)); - } - case BINOP_EQ: case BINOP_NE: case BINOP_LT: case BINOP_LE: case BINOP_GT: case BINOP_GE: { - if (!can_promote(lhs_t, rhs_t) && !can_promote(rhs_t, lhs_t)) - code_err(ast, "I can't compare these two different types: ", type_to_str(lhs_t), " vs ", type_to_str(rhs_t)); + code_err(ast, "I don't know how to compare ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Equals: case NotEquals: case LessThan: case LessThanOrEquals: case GreaterThan: case GreaterThanOrEquals: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + if (can_promote(rhs_t, lhs_t) || can_promote(lhs_t, rhs_t)) return Type(BoolType); + + code_err(ast, "I don't know how to compare ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Power: case Multiply: case Divide: case Mod: case Mod1: case Plus: case Minus: case LeftShift: + case UnsignedLeftShift: case RightShift: case UnsignedRightShift: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + if (ast->tag == LeftShift || ast->tag == UnsignedLeftShift || ast->tag == RightShift || ast->tag == UnsignedRightShift) { + if (!is_int_type(rhs_t)) + code_err(binop.rhs, "I only know how to do bit shifting by integer amounts, not ", type_to_str(rhs_t)); } - case BINOP_CMP: - return Type(IntType, .bits=TYPE_IBITS32); - case BINOP_POWER: { - type_t *result = get_math_type(env, ast, lhs_t, rhs_t); - if (result->tag == NumType) - return result; - return Type(NumType, .bits=TYPE_NBITS64); - } - case BINOP_MULT: case BINOP_DIVIDE: { - type_t *math_type = get_math_type(env, ast, value_type(lhs_t), value_type(rhs_t)); - if (value_type(lhs_t)->tag == NumType || value_type(rhs_t)->tag == NumType) { - if (risks_zero_or_inf(binop->lhs) && risks_zero_or_inf(binop->rhs)) - return Type(OptionalType, math_type); - else - return math_type; - } - return math_type; - } - default: { - return get_math_type(env, ast, lhs_t, rhs_t); - } + + type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL)); + if (ast->tag == Multiply || ast->tag == Divide) { + binding_t *b = is_numeric_type(lhs_t) ? get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, lhs_t) + : get_metamethod_binding(env, ast->tag, binop.rhs, binop.lhs, rhs_t); + if (b) return overall_t; + } else { + if (overall_t == NULL) + code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); + if (b) return overall_t; } + if (is_numeric_type(lhs_t) && is_numeric_type(rhs_t)) + return overall_t; + code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Concat: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL)); + if (overall_t == NULL) + code_err(ast, "I don't know how to do operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); + if (b) return overall_t; + + if (overall_t->tag == ArrayType || overall_t->tag == SetType || overall_t->tag == TextType) + return overall_t; + + code_err(ast, "I don't know how to do concatenation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); } case Reduction: { auto reduction = Match(ast, Reduction); type_t *iter_t = get_type(env, reduction->iter); - if (reduction->op == BINOP_EQ || reduction->op == BINOP_NE || reduction->op == BINOP_LT - || reduction->op == BINOP_LE || reduction->op == BINOP_GT || reduction->op == BINOP_GE) + if (reduction->op == Equals || reduction->op == NotEquals || reduction->op == LessThan + || reduction->op == LessThanOrEquals || reduction->op == GreaterThan || reduction->op == GreaterThanOrEquals) return Type(OptionalType, .type=Type(BoolType)); type_t *iterated = get_iterated_type(iter_t); @@ -1249,9 +1211,6 @@ type_t *get_type(env_t *env, ast_t *ast) return iterated->tag == OptionalType ? iterated : Type(OptionalType, .type=iterated); } - case UpdateAssign: - return Type(VoidType); - case Min: case Max: { // Unsafe! These types *should* have the same fields and this saves a lot of duplicate code: ast_t *lhs = ast->__data.Min.lhs, *rhs = ast->__data.Min.rhs; @@ -1310,8 +1269,9 @@ type_t *get_type(env_t *env, ast_t *ast) env_t *truthy_scope = env; env_t *falsey_scope = env; if (if_->condition->tag == Declare) { - type_t *condition_type = get_type(env, Match(if_->condition, Declare)->value); - const char *varname = Match(Match(if_->condition, Declare)->var, Var)->name; + auto decl = Match(if_->condition, Declare); + type_t *condition_type = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); + const char *varname = Match(decl->var, Var)->name; if (streq(varname, "_")) code_err(if_->condition, "To use `if var := ...:`, you must choose a real variable name, not `_`"); @@ -1456,7 +1416,7 @@ type_t *get_type(env_t *env, ast_t *ast) PUREFUNC bool is_discardable(env_t *env, ast_t *ast) { switch (ast->tag) { - case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: + case UPDATE_CASES: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Use: case Extend: return true; default: break; @@ -1610,13 +1570,13 @@ PUREFUNC bool is_constant(env_t *env, ast_t *ast) } case Not: return is_constant(env, Match(ast, Not)->value); case Negative: return is_constant(env, Match(ast, Negative)->value); - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - switch (binop->op) { - case BINOP_UNKNOWN: case BINOP_POWER: case BINOP_CONCAT: case BINOP_MIN: case BINOP_MAX: case BINOP_CMP: + case BINOP_CASES: { + binary_operands_t binop = BINARY_OPERANDS(ast); + switch (ast->tag) { + case Power: case Concat: case Min: case Max: case Compare: return false; default: - return is_constant(env, binop->lhs) && is_constant(env, binop->rhs); + return is_constant(env, binop.lhs) && is_constant(env, binop.rhs); } } case Use: return true; @@ -1626,4 +1586,49 @@ PUREFUNC bool is_constant(env_t *env, ast_t *ast) } } +PUREFUNC bool can_compile_to_type(env_t *env, ast_t *ast, type_t *needed) +{ + if (needed->tag == OptionalType && ast->tag == None) { + return true; + } + + needed = non_optional(needed); + if (needed->tag == ArrayType && ast->tag == Array) { + type_t *item_type = Match(needed, ArrayType)->item_type; + for (ast_list_t *item = Match(ast, Array)->items; item; item = item->next) { + if (!can_compile_to_type(env, item->ast, item_type)) + return false; + } + return true; + } else if (needed->tag == SetType && ast->tag == Set) { + type_t *item_type = Match(needed, SetType)->item_type; + for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next) { + if (!can_compile_to_type(env, item->ast, item_type)) + return false; + } + return true; + } else if (needed->tag == TableType && ast->tag == Table) { + type_t *key_type = Match(needed, TableType)->key_type; + type_t *value_type = Match(needed, TableType)->value_type; + for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next) { + if (entry->ast->tag != TableEntry) + continue; // TODO: fix this + auto e = Match(entry->ast, TableEntry); + if (!can_compile_to_type(env, e->key, key_type) || !can_compile_to_type(env, e->value, value_type)) + return false; + } + return true; + } else if (needed->tag == PointerType) { + auto ptr = Match(needed, PointerType); + if (ast->tag == HeapAllocate) + return !ptr->is_stack && can_compile_to_type(env, Match(ast, HeapAllocate)->value, ptr->pointed); + else if (ast->tag == StackReference) + return ptr->is_stack && can_compile_to_type(env, Match(ast, StackReference)->value, ptr->pointed); + else + return can_promote(needed, get_type(env, ast)); + } else { + return can_promote(needed, get_type(env, ast)); + } +} + // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 |
