diff options
Diffstat (limited to 'typecheck.c')
| -rw-r--r-- | typecheck.c | 77 |
1 files changed, 50 insertions, 27 deletions
diff --git a/typecheck.c b/typecheck.c index d8691544..2d1ba40a 100644 --- a/typecheck.c +++ b/typecheck.c @@ -811,11 +811,11 @@ type_t *get_type(env_t *env, ast_t *ast) if (t->tag == IntType || t->tag == NumType) return t; - binding_t *b = get_namespace_binding(env, value, "__negative"); + binding_t *b = get_namespace_binding(env, value, "negative"); if (b && b->type->tag == FunctionType) { auto fn = Match(b->type, FunctionType); - if (fn->args && can_promote(t, get_arg_type(env, fn->args))) - return fn->ret; + if (fn->args && type_eq(t, get_arg_type(env, fn->args)) && type_eq(t, fn->ret)) + return t; } code_err(ast, "I don't know how to get the negative value of type %T", t); @@ -826,6 +826,14 @@ type_t *get_type(env_t *env, ast_t *ast) return t; if (t->tag == PointerType && Match(t, PointerType)->is_optional) return Type(BoolType); + + ast_t *value = Match(ast, Not)->value; + binding_t *b = get_namespace_binding(env, value, "negated"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (fn->args && type_eq(t, get_arg_type(env, fn->args)) && type_eq(t, fn->ret)) + return t; + } code_err(ast, "I only know how to get 'not' of boolean, numeric, and optional pointer types, not %T", t); } case BinaryOp: { @@ -833,31 +841,46 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *lhs_t = get_type(env, binop->lhs), *rhs_t = get_type(env, binop->rhs); - // Check for a binop method like __add() etc: - const char *method_name = binop_method_names[binop->op]; - if (method_name) { - for (int64_t n = 1; ; ) { - binding_t *b = get_namespace_binding(env, binop->lhs, method_name); - if (b && b->type->tag == FunctionType) { - auto fn = Match(b->type, FunctionType); - if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args)) - && can_promote(rhs_t, get_arg_type(env, fn->args->next))) - return fn->ret; - } - binding_t *b2 = get_namespace_binding(env, binop->rhs, method_name); - if (b2 && b2->type->tag == FunctionType) { - auto fn = Match(b2->type, FunctionType); - if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args)) - && can_promote(rhs_t, get_arg_type(env, fn->args->next))) - return fn->ret; - } - if (!b && !b2) break; - - // If we found __foo, but it didn't match the types, check for - // __foo2, __foo3, etc. until we stop finding methods with that name. - method_name = heap_strf("%s%ld", binop_method_names[binop->op], ++n); - } +#define binding_works(name, self, lhs_t, rhs_t, ret_t) \ + ({ binding_t *b = get_namespace_binding(env, self, name); \ + (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \ + (type_eq(fn->ret, ret_t) \ + && (fn->args && type_eq(fn->args->type, lhs_t)) \ + && (fn->args->next && can_promote(fn->args->next->type, rhs_t))); })); }) + // Check for a binop method like plus() etc: + switch (binop->op) { + case BINOP_MULT: { + if ((lhs_t->tag == NumType || lhs_t->tag == IntType) && binding_works("scaled_by", binop->rhs, rhs_t, lhs_t, rhs_t)) + return rhs_t; + else if ((rhs_t->tag == NumType || rhs_t->tag == IntType) && binding_works("scaled_by", binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + else if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: { + if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: { + if ((rhs_t->tag == NumType || rhs_t->tag == IntType) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + case BINOP_LSHIFT: case BINOP_RSHIFT: { + if (rhs_t->tag == IntType && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + case BINOP_POWER: { + if (rhs_t->tag == NumType && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + default: break; } +#undef binding_works switch (binop->op) { case BINOP_AND: { |
