From 6d3d104363426d9d26a3fa65979899c032a093a7 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sat, 10 Aug 2024 20:50:15 -0400 Subject: Overhaul of operator metamethods --- ast.c | 8 ++-- compile.c | 112 +++++++++++++++++++++++++++------------------ docs/operators.md | 127 ++++++++++++++++++++++++++++++++++++++++++++++++++++ parse.c | 5 ++- test/metamethods.tm | 72 +++++++++++++++++++++++------ typecheck.c | 77 ++++++++++++++++++++----------- 6 files changed, 312 insertions(+), 89 deletions(-) create mode 100644 docs/operators.md diff --git a/ast.c b/ast.c index 0f325320..515ffa7b 100644 --- a/ast.c +++ b/ast.c @@ -18,10 +18,10 @@ static const char *OP_NAMES[] = { }; const char *binop_method_names[BINOP_XOR+1] = { - [BINOP_POWER]="__power", [BINOP_MULT]="__multiply", [BINOP_DIVIDE]="__divide", - [BINOP_MOD]="__mod", [BINOP_MOD1]="__mod1", [BINOP_PLUS]="__add", [BINOP_MINUS]="__subtract", - [BINOP_CONCAT]="__concatenate", [BINOP_LSHIFT]="__left_shift", [BINOP_RSHIFT]="__right_shift", - [BINOP_AND]="__and", [BINOP_OR]="__or", [BINOP_XOR]="__xor", + [BINOP_POWER]="power", [BINOP_MULT]="times", [BINOP_DIVIDE]="divided_by", + [BINOP_MOD]="modulo", [BINOP_MOD1]="modulo1", [BINOP_PLUS]="plus", [BINOP_MINUS]="minus", + [BINOP_CONCAT]="concatenated_with", [BINOP_LSHIFT]="left_shifted", [BINOP_RSHIFT]="right_shifted", + [BINOP_AND]="bit_and", [BINOP_OR]="bit_or", [BINOP_XOR]="bit_xor", }; static CORD ast_list_to_xml(ast_list_t *asts); diff --git a/compile.c b/compile.c index eedf6c98..d2081b4c 100644 --- a/compile.c +++ b/compile.c @@ -16,7 +16,7 @@ static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_depth, bool allow_optional); static env_t *with_enum_scope(env_t *env, type_t *t); -static CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); +static CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); static CORD compile_string(env_t *env, ast_t *ast, 
CORD color); static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args); @@ -445,7 +445,7 @@ CORD compile_statement(env_t *env, ast_t *ast) auto update = Match(ast, UpdateAssign); CORD lhs = compile_lvalue(env, update->lhs); - CORD method_call = compile_math_method(env, ast, update->op, update->lhs, update->rhs, get_type(env, update->lhs)); + CORD method_call = compile_math_method(env, update->op, update->lhs, update->rhs, get_type(env, update->lhs)); if (method_call) return CORD_all(lhs, " = ", method_call, ";"); @@ -1237,9 +1237,9 @@ CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t return code; } -CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type) +CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type) { - // Math methods are things like __add(), __sub(), etc. If we don't find a + // Math methods are things like plus(), minus(), etc. If we don't find a // matching method, return CORD_EMPTY. 
const char *method_name = binop_method_names[op]; if (!method_name) @@ -1247,36 +1247,62 @@ CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t * type_t *lhs_t = get_type(env, lhs); type_t *rhs_t = get_type(env, rhs); - for (int64_t i = 1; ; ) { - binding_t *b = get_namespace_binding(env, lhs, method_name); - if (b && b->type->tag == FunctionType) { - auto fn = Match(b->type, FunctionType); - if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args)) - && can_promote(rhs_t, get_arg_type(env, fn->args->next)) - && (!required_type || can_promote(fn->ret, required_type))) { - return CORD_all( - b->code, "(", - compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs))), - ")"); - } +#define binding_works(b, lhs_t, rhs_t, ret_t) \ + (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \ + (type_eq(fn->ret, ret_t) \ + && (fn->args && type_eq(fn->args->type, lhs_t)) \ + && (fn->args->next && can_promote(fn->args->next->type, rhs_t)) \ + && (!required_type || type_eq(required_type, fn->ret))); })) + switch (op) { + case BINOP_MULT: { + if (lhs_t->tag == NumType || lhs_t->tag == IntType) { + binding_t *b = get_namespace_binding(env, rhs, "scaled_by"); + if (binding_works(b, rhs_t, lhs_t, rhs_t)) + return CORD_all(b->code, "(", compile(env, rhs), ", ", compile(env, lhs), ")"); + } else if (rhs_t->tag == NumType || rhs_t->tag == IntType) { + binding_t *b = get_namespace_binding(env, lhs, "scaled_by"); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); + } else if (type_eq(lhs_t, rhs_t)) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); } - binding_t *b2 = get_namespace_binding(env, rhs, method_name); - if (b2 && b2->type->tag == 
FunctionType) { - auto fn = Match(b2->type, FunctionType); - if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args)) - && can_promote(rhs_t, get_arg_type(env, fn->args->next)) - && (!required_type || can_promote(fn->ret, required_type))) { - return CORD_all( - b2->code, "(", - compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs))), - ")"); - } + break; + } + case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: { + if (type_eq(lhs_t, rhs_t)) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); } - if (!b && !b2) break; - - // If we found __foo, but it didn't match the types, check for - // __foo2, __foo3, etc. until we stop finding methods with that name. - method_name = heap_strf("%s%ld", binop_method_names[op], ++i); + break; + } + case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: { + if (rhs_t->tag == NumType || rhs_t->tag == IntType) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); + } + break; + } + case BINOP_LSHIFT: case BINOP_RSHIFT: { + if (rhs_t->tag == IntType) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); + } + break; + } + case BINOP_POWER: { + if (rhs_t->tag == NumType) { + binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); + if (binding_works(b, lhs_t, rhs_t, lhs_t)) + return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); + } + break; + } + default: break; } return CORD_EMPTY; } @@ -1327,13 +1353,6 @@ CORD compile(env_t *env, ast_t *ast) } } default: { 
- binding_t *b = get_namespace_binding(env, expr, "__length"); - if (b && b->type->tag == FunctionType) { - auto fn = Match(b->type, FunctionType); - if (type_eq(fn->ret, INT_TYPE) && fn->args && can_promote(t, get_arg_type(env, fn->args))) - return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=expr)), ")"); - } - code_err(ast, "Length is not implemented for %T values", t); } } @@ -1350,8 +1369,15 @@ CORD compile(env_t *env, ast_t *ast) return CORD_all("!(", compile(env, WrapAST(ast, Length, value)), ")"); else if (t->tag == TextType) return CORD_all("!(", compile(env, value), ")"); - else - code_err(ast, "I don't know how to negate values of type %T", t); + + binding_t *b = get_namespace_binding(env, value, "negated"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (fn->args && can_promote(t, get_arg_type(env, fn->args))) + return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=value)), ")"); + } + + code_err(ast, "I don't know how to negate values of type %T", t); } case Negative: { ast_t *value = Match(ast, Negative)->value; @@ -1359,7 +1385,7 @@ CORD compile(env_t *env, ast_t *ast) if (t->tag == IntType || t->tag == NumType) return CORD_all("-(", compile(env, value), ")"); - binding_t *b = get_namespace_binding(env, value, "__negative"); + binding_t *b = get_namespace_binding(env, value, "negative"); if (b && b->type->tag == FunctionType) { auto fn = Match(b->type, FunctionType); if (fn->args && can_promote(t, get_arg_type(env, fn->args))) @@ -1382,7 +1408,7 @@ CORD compile(env_t *env, ast_t *ast) } case BinaryOp: { auto binop = Match(ast, BinaryOp); - CORD method_call = compile_math_method(env, ast, binop->op, binop->lhs, binop->rhs, NULL); + CORD method_call = compile_math_method(env, binop->op, binop->lhs, binop->rhs, NULL); if (method_call != CORD_EMPTY) return method_call; diff --git a/docs/operators.md b/docs/operators.md new file mode 100644 
index 00000000..6716b944 --- /dev/null +++ b/docs/operators.md @@ -0,0 +1,127 @@ +# Operator Overloading + +Operator overloading is supported, but _strongly discouraged_. Operator +overloading should only be used for types that represent mathematical concepts +whose behavior with math operators users can reliably be expected to understand, +and for which the implementations are extremely efficient. Operator +overloading should not be used to hide expensive computations or to create +domain-specific syntax to make certain operations more concise. Examples of +good candidates for operator overloading would include: + +- Mathematical vectors +- Matrices +- Quaternions +- Complex numbers + +Bad candidates would include: + +- Arbitrarily sized data structures +- Objects that represent regular expressions +- Objects that represent filesystem paths + +## Available Operator Overloads + +When performing a math operation between any two types that are not both +numerical or boolean, the compiler will look up the appropriate method and +insert a function call to that method. Math overload operations are all assumed +to return a value that is the same type as the first argument. The second +argument must be either the same type as the first argument or a number, +depending on the specifications of the specific operator. + +### Addition + +``` +func plus(T, T)->T +``` + +The `+` operator invokes the `plus()` method, which takes two objects of the +same type and returns a new value of the same type. + +### Subtraction + +``` +func minus(T, T)->T +``` + +The `-` operator invokes the `minus()` method, which takes two objects of the +same type and returns a new value of the same type. + +### Multiplication + +``` +func times(T, T)->T +``` + +The `*` operator invokes the `times()` method, which takes two objects of the +same type and returns a new value of the same type. This should _not_ be used +to implement either a dot product or a cross product. 
Dot products and cross +products should be implemented as explicitly named methods. + +``` +func scaled_by(T, N)->T +``` + +In a multiplication expression, `a*b`, if either `a` or `b` has type `T` and +the other has a numeric type `N` (either `Int8`, `Int16`, `Int32`, `Int`, +`Num32`, or `Num`), then the method `scaled_by()` will be invoked. + +### Division + +``` +func divided_by(T, N)->T +``` + +In a division expression, `a/b`, if `a` has type `T` and `b` has a numeric +type `N`, then the method `divided_by()` will be invoked. + +### Exponentiation + +``` +func power(T, N)->T +``` + +In an exponentiation expression, `a^b`, if `a` has type `T` and `b` has a +floating-point type `N` (such as `Num`), then the method `power()` will be invoked. + +### Modulus + +``` +func modulo(T, N)->T +func modulo1(T, N)->T +``` + +In a modulus expression, `a mod b` or `a mod1 b`, if `a` has type `T` and `b` +has a numeric type `N`, then the method `modulo()` or `modulo1()` will be invoked. + +### Negative + +``` +func negative(T)->T +``` + +In a unary negative expression `-x`, the method `negative()` will be invoked. + +### Bit Operations + +``` +func left_shifted(T, Int)->T +func right_shifted(T, Int)->T +func bit_and(T, T)->T +func bit_or(T, T)->T +func bit_xor(T, T)->T +``` + +In a bit shifting expression, `a << b` or `a >> b`, if `a` has type `T` and `b` +is an `Int`, then the method `left_shifted()` or `right_shifted()` will be invoked. + +In a bitwise binary operation, `a and b`, `a or b`, or `a xor b`, the +method `bit_and()`, `bit_or()`, or `bit_xor()` will be invoked. + +### Bitwise Negation + +``` +func negated(T)->T +``` + +In a unary bitwise negation expression `not x`, the method `negated()` will be +invoked. 
diff --git a/parse.c b/parse.c index 4c9e95aa..60b68edc 100644 --- a/parse.c +++ b/parse.c @@ -2096,7 +2096,10 @@ PARSER(parse_doctest) { *output_end = pos + strcspn(pos, "\r\n"); if (output_end <= output_start) parser_err(ctx, output_start, output_end, "You're missing expected output here"); - output = GC_strndup(output_start, (size_t)(output_end - output_start)); + int64_t trailing_spaces = 0; + while (output_end - trailing_spaces - 1 > output_start && (output_end[-trailing_spaces-1] == ' ' || output_end[-trailing_spaces-1] == '\t')) + ++trailing_spaces; + output = GC_strndup(output_start, (size_t)(output_end - output_start) - trailing_spaces); pos = output_end; } else { pos = expr->end; diff --git a/test/metamethods.tm b/test/metamethods.tm index 6a870ee9..d20670b2 100644 --- a/test/metamethods.tm +++ b/test/metamethods.tm @@ -1,27 +1,51 @@ struct Vec2(x,y:Int): - func __add(a,b:Vec2; inline)->Vec2: + func plus(a,b:Vec2; inline)->Vec2: return Vec2(a.x+b.x, a.y+b.y) - func __subtract(a,b:Vec2; inline)->Vec2: + func minus(a,b:Vec2; inline)->Vec2: return Vec2(a.x-b.x, a.y-b.y) - func __multiply(a,b:Vec2; inline)->Int: + func dot(a,b:Vec2; inline)->Int: return a.x*b.x + a.y*b.y - func __multiply2(a:Vec2,b:Int; inline)->Vec2: - return Vec2(a.x*b, a.y*b) + func scaled_by(a:Vec2, k:Int; inline)->Vec2: + return Vec2(a.x*k, a.y*k) - func __multiply3(a:Int,b:Vec2; inline)->Vec2: - return Vec2(a*b.x, a*b.y) - - func __multiply4(a,b:Vec2; inline)->Vec2: + func times(a,b:Vec2; inline)->Vec2: return Vec2(a.x*b.x, a.y*b.y) - func __negative(v:Vec2; inline)->Vec2: + func divided_by(a:Vec2, k:Int; inline)->Vec2: + return Vec2(a.x/k, a.y/k) + + func negative(v:Vec2; inline)->Vec2: return Vec2(-v.x, -v.y) - func __length(v:Vec2; inline)->Int: - return 2 + func negated(v:Vec2; inline)->Vec2: + return Vec2(not v.x, not v.y) + + func bit_and(a,b:Vec2; inline)->Vec2: + return Vec2(a.x and b.x, a.y and b.y) + + func bit_or(a,b:Vec2; inline)->Vec2: + return Vec2(a.x or b.x, a.y or b.y) 
+ + func bit_xor(a,b:Vec2; inline)->Vec2: + return Vec2(a.x xor b.x, a.y xor b.y) + + func left_shifted(v:Vec2, bits:Int; inline)->Vec2: + return Vec2(v.x >> bits, v.y >> bits) + + func right_shifted(v:Vec2, bits:Int; inline)->Vec2: + return Vec2(v.x << bits, v.y << bits) + + func modulo(v:Vec2, modulus:Int; inline)->Vec2: + return Vec2(v.x mod modulus, v.y mod modulus) + + func modulo1(v:Vec2, modulus:Int; inline)->Vec2: + return Vec2(v.x mod1 modulus, v.y mod1 modulus) + + func power(v:Vec2, exponent:Num; inline)->Vec2: + return Vec2(Int(v.x ^ exponent), Int(v.y ^ exponent)) func main(): >> x := Vec2(10, 20) @@ -31,6 +55,8 @@ func main(): >> x - y = Vec2(x=-90, y=-180) >> x * y + = Vec2(x=1000, y=4000) + >> x:dot(y) = 5000 >> x * -1 = Vec2(x=-10, y=-20) @@ -43,9 +69,27 @@ func main(): >> x *= Vec2(10, -1) = Vec2(x=110, y=-22) + >> x *= -1 + = Vec2(x=-110, y=22) + >> x = Vec2(1, 2) >> -x = Vec2(x=-1, y=-2) - >> #x - = 2 + + x = Vec2(1, 2) + y = Vec2(4, 3) + >> x and y + = Vec2(x=0, y=2) + >> x or y + = Vec2(x=5, y=3) + >> x xor y + = Vec2(x=5, y=1) + >> x / 2 + = Vec2(x=0, y=1) + >> x mod 3 + = Vec2(x=1, y=2) + >> x mod1 3 + = Vec2(x=1, y=2) + >> x^2.0 + = Vec2(x=1, y=4) diff --git a/typecheck.c b/typecheck.c index d8691544..2d1ba40a 100644 --- a/typecheck.c +++ b/typecheck.c @@ -811,11 +811,11 @@ type_t *get_type(env_t *env, ast_t *ast) if (t->tag == IntType || t->tag == NumType) return t; - binding_t *b = get_namespace_binding(env, value, "__negative"); + binding_t *b = get_namespace_binding(env, value, "negative"); if (b && b->type->tag == FunctionType) { auto fn = Match(b->type, FunctionType); - if (fn->args && can_promote(t, get_arg_type(env, fn->args))) - return fn->ret; + if (fn->args && type_eq(t, get_arg_type(env, fn->args)) && type_eq(t, fn->ret)) + return t; } code_err(ast, "I don't know how to get the negative value of type %T", t); @@ -826,6 +826,14 @@ type_t *get_type(env_t *env, ast_t *ast) return t; if (t->tag == PointerType && Match(t, 
PointerType)->is_optional) return Type(BoolType); + + ast_t *value = Match(ast, Not)->value; + binding_t *b = get_namespace_binding(env, value, "negated"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (fn->args && type_eq(t, get_arg_type(env, fn->args)) && type_eq(t, fn->ret)) + return t; + } code_err(ast, "I only know how to get 'not' of boolean, numeric, and optional pointer types, not %T", t); } case BinaryOp: { @@ -833,31 +841,46 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *lhs_t = get_type(env, binop->lhs), *rhs_t = get_type(env, binop->rhs); - // Check for a binop method like __add() etc: - const char *method_name = binop_method_names[binop->op]; - if (method_name) { - for (int64_t n = 1; ; ) { - binding_t *b = get_namespace_binding(env, binop->lhs, method_name); - if (b && b->type->tag == FunctionType) { - auto fn = Match(b->type, FunctionType); - if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args)) - && can_promote(rhs_t, get_arg_type(env, fn->args->next))) - return fn->ret; - } - binding_t *b2 = get_namespace_binding(env, binop->rhs, method_name); - if (b2 && b2->type->tag == FunctionType) { - auto fn = Match(b2->type, FunctionType); - if (fn->args && fn->args->next && can_promote(lhs_t, get_arg_type(env, fn->args)) - && can_promote(rhs_t, get_arg_type(env, fn->args->next))) - return fn->ret; - } - if (!b && !b2) break; - - // If we found __foo, but it didn't match the types, check for - // __foo2, __foo3, etc. until we stop finding methods with that name. 
- method_name = heap_strf("%s%ld", binop_method_names[binop->op], ++n); - } +#define binding_works(name, self, lhs_t, rhs_t, ret_t) \ + ({ binding_t *b = get_namespace_binding(env, self, name); \ + (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \ + (type_eq(fn->ret, ret_t) \ + && (fn->args && type_eq(fn->args->type, lhs_t)) \ + && (fn->args->next && can_promote(fn->args->next->type, rhs_t))); })); }) + // Check for a binop method like plus() etc: + switch (binop->op) { + case BINOP_MULT: { + if ((lhs_t->tag == NumType || lhs_t->tag == IntType) && binding_works("scaled_by", binop->rhs, rhs_t, lhs_t, rhs_t)) + return rhs_t; + else if ((rhs_t->tag == NumType || rhs_t->tag == IntType) && binding_works("scaled_by", binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + else if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: { + if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: { + if ((rhs_t->tag == NumType || rhs_t->tag == IntType) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + case BINOP_LSHIFT: case BINOP_RSHIFT: { + if (rhs_t->tag == IntType && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + case BINOP_POWER: { + if (rhs_t->tag == NumType && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + return lhs_t; + break; + } + default: break; } +#undef binding_works switch (binop->op) { case BINOP_AND: { -- cgit v1.2.3