aboutsummaryrefslogtreecommitdiff
path: root/typecheck.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-08-10 20:50:15 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-08-10 20:50:15 -0400
commit6d3d104363426d9d26a3fa65979899c032a093a7 (patch)
tree1d0353fc224d0d97c5f987c5087e8ac018c98d81 /typecheck.c
parentf0e56acc5b7930111ddf429f7186f0e72146517e (diff)
Overhaul of operator metamethods
Diffstat (limited to 'typecheck.c')
-rw-r--r--typecheck.c77
1 files changed, 50 insertions, 27 deletions
diff --git a/typecheck.c b/typecheck.c
index d8691544..2d1ba40a 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -811,11 +811,11 @@ type_t *get_type(env_t *env, ast_t *ast)
if (t->tag == IntType || t->tag == NumType)
return t;
- binding_t *b = get_namespace_binding(env, value, "__negative");
+ binding_t *b = get_namespace_binding(env, value, "negative");
if (b && b->type->tag == FunctionType) {
auto fn = Match(b->type, FunctionType);
- if (fn->args && can_promote(t, get_arg_type(env, fn->args)))
- return fn->ret;
+ if (fn->args && type_eq(t, get_arg_type(env, fn->args)) && type_eq(t, fn->ret))
+ return t;
}
code_err(ast, "I don't know how to get the negative value of type %T", t);
@@ -826,6 +826,14 @@ type_t *get_type(env_t *env, ast_t *ast)
return t;
if (t->tag == PointerType && Match(t, PointerType)->is_optional)
return Type(BoolType);
+
+ ast_t *value = Match(ast, Not)->value;
+ binding_t *b = get_namespace_binding(env, value, "negated");
+ if (b && b->type->tag == FunctionType) {
+ auto fn = Match(b->type, FunctionType);
+ if (fn->args && type_eq(t, get_arg_type(env, fn->args)) && type_eq(t, fn->ret))
+ return t;
+ }
code_err(ast, "I only know how to get 'not' of boolean, numeric, and optional pointer types, not %T", t);
}
case BinaryOp: {
@@ -833,31 +841,46 @@ type_t *get_type(env_t *env, ast_t *ast)
type_t *lhs_t = get_type(env, binop->lhs),
*rhs_t = get_type(env, binop->rhs);
- // Check for a binop method like __add() etc:
- const char *method_name = binop_method_names[binop->op];
- if (method_name) {
- for (int64_t n = 1; ; ) {
- binding_t *b = get_namespace_binding(env, binop->lhs, method_name);
- if (b && b->type->tag == FunctionType) {
- auto fn = Match(b->type, FunctionType);
- if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args))
- && can_promote(rhs_t, get_arg_type(env, fn->args->next)))
- return fn->ret;
- }
- binding_t *b2 = get_namespace_binding(env, binop->rhs, method_name);
- if (b2 && b2->type->tag == FunctionType) {
- auto fn = Match(b2->type, FunctionType);
- if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args))
- && can_promote(rhs_t, get_arg_type(env, fn->args->next)))
- return fn->ret;
- }
- if (!b && !b2) break;
-
- // If we found __foo, but it didn't match the types, check for
- // __foo2, __foo3, etc. until we stop finding methods with that name.
- method_name = heap_strf("%s%ld", binop_method_names[binop->op], ++n);
- }
+#define binding_works(name, self, lhs_t, rhs_t, ret_t) \
+ ({ binding_t *b = get_namespace_binding(env, self, name); \
+ (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \
+ (type_eq(fn->ret, ret_t) \
+ && (fn->args && type_eq(fn->args->type, lhs_t)) \
+ && (fn->args->next && can_promote(fn->args->next->type, rhs_t))); })); })
+ // Check for a binop method like plus() etc:
+ switch (binop->op) {
+ case BINOP_MULT: {
+ if ((lhs_t->tag == NumType || lhs_t->tag == IntType) && binding_works("scaled_by", binop->rhs, rhs_t, lhs_t, rhs_t))
+ return rhs_t;
+ else if ((rhs_t->tag == NumType || rhs_t->tag == IntType) && binding_works("scaled_by", binop->lhs, lhs_t, rhs_t, lhs_t))
+ return lhs_t;
+ else if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t))
+ return lhs_t;
+ break;
+ }
+ case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: {
+ if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t))
+ return lhs_t;
+ break;
+ }
+ case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: {
+ if ((rhs_t->tag == NumType || rhs_t->tag == IntType) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t))
+ return lhs_t;
+ break;
+ }
+ case BINOP_LSHIFT: case BINOP_RSHIFT: {
+ if (rhs_t->tag == IntType && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t))
+ return lhs_t;
+ break;
+ }
+ case BINOP_POWER: {
+ if (rhs_t->tag == NumType && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t))
+ return lhs_t;
+ break;
+ }
+ default: break;
}
+#undef binding_works
switch (binop->op) {
case BINOP_AND: {