aboutsummaryrefslogtreecommitdiff
path: root/compile.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-08-10 20:50:15 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-08-10 20:50:15 -0400
commit6d3d104363426d9d26a3fa65979899c032a093a7 (patch)
tree1d0353fc224d0d97c5f987c5087e8ac018c98d81 /compile.c
parentf0e56acc5b7930111ddf429f7186f0e72146517e (diff)
Overhaul of operator metamethods
Diffstat (limited to 'compile.c')
-rw-r--r--compile.c112
1 files changed, 69 insertions, 43 deletions
diff --git a/compile.c b/compile.c
index eedf6c98..d2081b4c 100644
--- a/compile.c
+++ b/compile.c
@@ -16,7 +16,7 @@
static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_depth, bool allow_optional);
static env_t *with_enum_scope(env_t *env, type_t *t);
-static CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type);
+static CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type);
static CORD compile_string(env_t *env, ast_t *ast, CORD color);
static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args);
@@ -445,7 +445,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
auto update = Match(ast, UpdateAssign);
CORD lhs = compile_lvalue(env, update->lhs);
- CORD method_call = compile_math_method(env, ast, update->op, update->lhs, update->rhs, get_type(env, update->lhs));
+ CORD method_call = compile_math_method(env, update->op, update->lhs, update->rhs, get_type(env, update->lhs));
if (method_call)
return CORD_all(lhs, " = ", method_call, ";");
@@ -1237,9 +1237,9 @@ CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t
return code;
}
-CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type)
+CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type)
{
- // Math methods are things like __add(), __sub(), etc. If we don't find a
+ // Math methods are things like plus(), minus(), etc. If we don't find a
// matching method, return CORD_EMPTY.
const char *method_name = binop_method_names[op];
if (!method_name)
@@ -1247,36 +1247,62 @@ CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *
type_t *lhs_t = get_type(env, lhs);
type_t *rhs_t = get_type(env, rhs);
- for (int64_t i = 1; ; ) {
- binding_t *b = get_namespace_binding(env, lhs, method_name);
- if (b && b->type->tag == FunctionType) {
- auto fn = Match(b->type, FunctionType);
- if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args))
- && can_promote(rhs_t, get_arg_type(env, fn->args->next))
- && (!required_type || can_promote(fn->ret, required_type))) {
- return CORD_all(
- b->code, "(",
- compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs))),
- ")");
- }
+#define binding_works(b, lhs_t, rhs_t, ret_t) \
+ (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \
+ (type_eq(fn->ret, ret_t) \
+ && (fn->args && type_eq(fn->args->type, lhs_t)) \
+ && (fn->args->next && can_promote(fn->args->next->type, rhs_t)) \
+ && (!required_type || type_eq(required_type, fn->ret))); }))
+ switch (op) {
+ case BINOP_MULT: {
+ if (lhs_t->tag == NumType || lhs_t->tag == IntType) {
+ binding_t *b = get_namespace_binding(env, rhs, "scaled_by");
+ if (binding_works(b, rhs_t, lhs_t, rhs_t))
+ return CORD_all(b->code, "(", compile(env, rhs), ", ", compile(env, lhs), ")");
+ } else if (rhs_t->tag == NumType || rhs_t->tag == IntType) {
+ binding_t *b = get_namespace_binding(env, lhs, "scaled_by");
+ if (binding_works(b, lhs_t, rhs_t, lhs_t))
+ return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")");
+ } else if (type_eq(lhs_t, rhs_t)) {
+ binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]);
+ if (binding_works(b, lhs_t, rhs_t, lhs_t))
+ return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")");
}
- binding_t *b2 = get_namespace_binding(env, rhs, method_name);
- if (b2 && b2->type->tag == FunctionType) {
- auto fn = Match(b2->type, FunctionType);
- if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args))
- && can_promote(rhs_t, get_arg_type(env, fn->args->next))
- && (!required_type || can_promote(fn->ret, required_type))) {
- return CORD_all(
- b2->code, "(",
- compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs))),
- ")");
- }
+ break;
+ }
+ case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: {
+ if (type_eq(lhs_t, rhs_t)) {
+ binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]);
+ if (binding_works(b, lhs_t, rhs_t, lhs_t))
+ return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")");
}
- if (!b && !b2) break;
-
- // If we found __foo, but it didn't match the types, check for
- // __foo2, __foo3, etc. until we stop finding methods with that name.
- method_name = heap_strf("%s%ld", binop_method_names[op], ++i);
+ break;
+ }
+ case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: {
+ if (rhs_t->tag == NumType || rhs_t->tag == IntType) {
+ binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]);
+ if (binding_works(b, lhs_t, rhs_t, lhs_t))
+ return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")");
+ }
+ break;
+ }
+ case BINOP_LSHIFT: case BINOP_RSHIFT: {
+ if (rhs_t->tag == IntType) {
+ binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]);
+ if (binding_works(b, lhs_t, rhs_t, lhs_t))
+ return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")");
+ }
+ break;
+ }
+ case BINOP_POWER: {
+ if (rhs_t->tag == NumType) {
+ binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]);
+ if (binding_works(b, lhs_t, rhs_t, lhs_t))
+ return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")");
+ }
+ break;
+ }
+ default: break;
}
return CORD_EMPTY;
}
@@ -1327,13 +1353,6 @@ CORD compile(env_t *env, ast_t *ast)
}
}
default: {
- binding_t *b = get_namespace_binding(env, expr, "__length");
- if (b && b->type->tag == FunctionType) {
- auto fn = Match(b->type, FunctionType);
- if (type_eq(fn->ret, INT_TYPE) && fn->args && can_promote(t, get_arg_type(env, fn->args)))
- return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=expr)), ")");
- }
-
code_err(ast, "Length is not implemented for %T values", t);
}
}
@@ -1350,8 +1369,15 @@ CORD compile(env_t *env, ast_t *ast)
return CORD_all("!(", compile(env, WrapAST(ast, Length, value)), ")");
else if (t->tag == TextType)
return CORD_all("!(", compile(env, value), ")");
- else
- code_err(ast, "I don't know how to negate values of type %T", t);
+
+ binding_t *b = get_namespace_binding(env, value, "negated");
+ if (b && b->type->tag == FunctionType) {
+ auto fn = Match(b->type, FunctionType);
+ if (fn->args && can_promote(t, get_arg_type(env, fn->args)))
+ return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=value)), ")");
+ }
+
+ code_err(ast, "I don't know how to negate values of type %T", t);
}
case Negative: {
ast_t *value = Match(ast, Negative)->value;
@@ -1359,7 +1385,7 @@ CORD compile(env_t *env, ast_t *ast)
if (t->tag == IntType || t->tag == NumType)
return CORD_all("-(", compile(env, value), ")");
- binding_t *b = get_namespace_binding(env, value, "__negative");
+ binding_t *b = get_namespace_binding(env, value, "negative");
if (b && b->type->tag == FunctionType) {
auto fn = Match(b->type, FunctionType);
if (fn->args && can_promote(t, get_arg_type(env, fn->args)))
@@ -1382,7 +1408,7 @@ CORD compile(env_t *env, ast_t *ast)
}
case BinaryOp: {
auto binop = Match(ast, BinaryOp);
- CORD method_call = compile_math_method(env, ast, binop->op, binop->lhs, binop->rhs, NULL);
+ CORD method_call = compile_math_method(env, binop->op, binop->lhs, binop->rhs, NULL);
if (method_call != CORD_EMPTY)
return method_call;