diff --git a/src/compile.c b/src/compile.c index 33621bd..05cc7bc 100644 --- a/src/compile.c +++ b/src/compile.c @@ -606,6 +606,16 @@ static CORD compile_binary_op(env_t *env, ast_t *ast) return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, args), ")"); } } + } else if ((ast->tag == Divide || ast->tag == Mod || ast->tag == Mod1) && is_numeric_type(rhs_t)) { + b = get_namespace_binding(env, binop.lhs, binop_method_name(ast->tag)); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (type_eq(fn->ret, lhs_t)) { + arg_ast_t *args = new(arg_ast_t, .value=binop.lhs, .next=new(arg_ast_t, .value=binop.rhs)); + if (is_valid_call(env, fn->args, args, true)) + return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, args), ")"); + } + } } if (ast->tag == Or && lhs_t->tag == OptionalType) { diff --git a/src/typecheck.c b/src/typecheck.c index 0291b74..e417650 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -1060,6 +1060,11 @@ type_t *get_type(env_t *env, ast_t *ast) return Type(BoolType); } + if (type_eq(lhs_t, rhs_t)) { + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, lhs_t); + if (b) return lhs_t; + } + if (lhs_t->tag == OptionalType) { if (rhs_t->tag == OptionalType) { type_t *result = most_complete_type(lhs_t, rhs_t); @@ -1096,6 +1101,11 @@ type_t *get_type(env_t *env, ast_t *ast) return Type(BoolType); } + if (type_eq(lhs_t, rhs_t)) { + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, lhs_t); + if (b) return lhs_t; + } + // Bitwise AND: if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) @@ -1120,6 +1130,11 @@ type_t *get_type(env_t *env, ast_t *ast) return Type(BoolType); } + if (type_eq(lhs_t, rhs_t)) { + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, lhs_t); + if (b) return lhs_t; + } + // Bitwise XOR: if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) @@ -1203,8 +1218,8 @@ type_t *get_type(env_t *env, ast_t *ast) return lhs_t; } } - } else if (ast->tag == Divide && is_numeric_type(rhs_t)) { - binding_t *b = get_namespace_binding(env, binop.lhs, "divided_by"); + } else if ((ast->tag == Divide || ast->tag == Mod || ast->tag == Mod1) && is_numeric_type(rhs_t)) { + binding_t *b = get_namespace_binding(env, binop.lhs, binop_method_name(ast->tag)); if (b && b->type->tag == FunctionType) { auto fn = Match(b->type, FunctionType); if (type_eq(fn->ret, lhs_t)) {