diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-08-10 20:50:15 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-08-10 20:50:15 -0400 |
| commit | 6d3d104363426d9d26a3fa65979899c032a093a7 (patch) | |
| tree | 1d0353fc224d0d97c5f987c5087e8ac018c98d81 /compile.c | |
| parent | f0e56acc5b7930111ddf429f7186f0e72146517e (diff) | |
Overhaul of operator metamethods
Diffstat (limited to 'compile.c')
| -rw-r--r-- | compile.c | 112 |
1 files changed, 69 insertions, 43 deletions
@@ -16,7 +16,7 @@ static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_depth, bool allow_optional); static env_t *with_enum_scope(env_t *env, type_t *t); -static CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); +static CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); static CORD compile_string(env_t *env, ast_t *ast, CORD color); static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args); @@ -445,7 +445,7 @@ CORD compile_statement(env_t *env, ast_t *ast) auto update = Match(ast, UpdateAssign); CORD lhs = compile_lvalue(env, update->lhs); - CORD method_call = compile_math_method(env, ast, update->op, update->lhs, update->rhs, get_type(env, update->lhs)); + CORD method_call = compile_math_method(env, update->op, update->lhs, update->rhs, get_type(env, update->lhs)); if (method_call) return CORD_all(lhs, " = ", method_call, ";"); @@ -1237,9 +1237,9 @@ CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t return code; } -CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type) +CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type) { - // Math methods are things like __add(), __sub(), etc. If we don't find a + // Math methods are things like plus(), minus(), etc. If we don't find a // matching method, return CORD_EMPTY. const char *method_name = binop_method_names[op]; if (!method_name) @@ -1247,36 +1247,62 @@ CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t * type_t *lhs_t = get_type(env, lhs); type_t *rhs_t = get_type(env, rhs); - for (int64_t i = 1; ; ) { - binding_t *b = get_namespace_binding(env, lhs, method_name); - if (b && b->type->tag == FunctionType) { - auto fn = Match(b->type, FunctionType); - if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args)) - && can_promote(rhs_t, get_arg_type(env, fn->args->next)) - && (!required_type || can_promote(fn->ret, required_type))) { - return CORD_all( - b->code, "(", - compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs))), - ")"); - } +#define binding_works(b, lhs_t, rhs_t, ret_t) \ + (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \ + (type_eq(fn->ret, ret_t) \ + && (fn->args && type_eq(fn->args->type, lhs_t)) \ + && (fn->args->next && can_promote(fn->args->next->type, rhs_t)) \ + && (!required_type || type_eq(required_type, fn->ret))); })) + switch (op) { + case BINOP_MULT: { + if (lhs_t->tag == NumType || lhs_t->tag == IntType) { + binding_t *b = get_namespace_binding(env, rhs, "scaled_by"); + if (binding_works(b, rhs_t, lhs_t, rhs_t)) + return CORD_all(b->code, "(", compile(env, rhs), ", ", compile(env, lhs), ")"); + } else if (rhs_t->tag == NumType || rhs_t->tag == IntType) { + binding_t *b = get_namespace_binding(env, lhs, "scaled_by"); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); + } else if (type_eq(lhs_t, rhs_t)) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); } - binding_t *b2 = get_namespace_binding(env, rhs, method_name); - if (b2 && b2->type->tag == FunctionType) { - auto fn = Match(b2->type, FunctionType); - if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args)) - && can_promote(rhs_t, get_arg_type(env, fn->args->next)) - && (!required_type || can_promote(fn->ret, required_type))) { - return CORD_all( - b2->code, "(", - compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs))), - ")"); - } + break; + } + case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: { + if (type_eq(lhs_t, rhs_t)) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); } - if (!b && !b2) break; - - // If we found __foo, but it didn't match the types, check for - // __foo2, __foo3, etc. until we stop finding methods with that name. - method_name = heap_strf("%s%ld", binop_method_names[op], ++i); + break; + } + case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: { + if (rhs_t->tag == NumType || rhs_t->tag == IntType) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); + } + break; + } + case BINOP_LSHIFT: case BINOP_RSHIFT: { + if (rhs_t->tag == IntType) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); + } + break; + } + case BINOP_POWER: { + if (rhs_t->tag == NumType) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); + } + break; + } + default: break; } return CORD_EMPTY; } @@ -1327,13 +1353,6 @@ CORD compile(env_t *env, ast_t *ast) } } default: { - binding_t *b = get_namespace_binding(env, expr, "__length"); - if (b && b->type->tag == FunctionType) { - auto fn = Match(b->type, FunctionType); - if (type_eq(fn->ret, INT_TYPE) && fn->args && can_promote(t, get_arg_type(env, fn->args))) - return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=expr)), ")"); - } - code_err(ast, "Length is not implemented for %T values", t); } } @@ -1350,8 +1369,15 @@ CORD compile(env_t *env, ast_t *ast) return CORD_all("!(", compile(env, WrapAST(ast, Length, value)), ")"); else if (t->tag == TextType) return CORD_all("!(", compile(env, value), ")"); - else - code_err(ast, "I don't know how to negate values of type %T", t); + + binding_t *b = get_namespace_binding(env, value, "negated"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (fn->args && can_promote(t, get_arg_type(env, fn->args))) + return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=value)), ")"); + } + + code_err(ast, "I don't know how to negate values of type %T", t); } case Negative: { ast_t *value = Match(ast, Negative)->value; @@ -1359,7 +1385,7 @@ CORD compile(env_t *env, ast_t *ast) if (t->tag == IntType || t->tag == NumType) return CORD_all("-(", compile(env, value), ")"); - binding_t *b = get_namespace_binding(env, value, "__negative"); + binding_t *b = get_namespace_binding(env, value, "negative"); if (b && b->type->tag == FunctionType) { auto fn = Match(b->type, FunctionType); if (fn->args && can_promote(t, get_arg_type(env, fn->args))) @@ -1382,7 +1408,7 @@ CORD compile(env_t *env, ast_t *ast) } case BinaryOp: { auto binop = Match(ast, BinaryOp); - CORD method_call = compile_math_method(env, ast, binop->op, binop->lhs, binop->rhs, NULL); + CORD method_call = compile_math_method(env, binop->op, binop->lhs, binop->rhs, NULL); if (method_call != CORD_EMPTY) return method_call; |
