From f0b1a0f227cb7545af9ec27fbb4742b8c2c03bcd Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sat, 5 Apr 2025 00:53:11 -0400 Subject: Fix metamethods for scaled_by and divided_by --- src/compile.c | 32 ++++++++++++++++++++++++++++++++ src/typecheck.c | 48 ++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/src/compile.c b/src/compile.c index d618de71..33621bd5 100644 --- a/src/compile.c +++ b/src/compile.c @@ -576,6 +576,38 @@ static CORD compile_binary_op(env_t *env, ast_t *ast) return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, args), ")"); } + if (ast->tag == Multiply && is_numeric_type(lhs_t)) { + b = get_namespace_binding(env, binop.rhs, "scaled_by"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (type_eq(fn->ret, rhs_t)) { + arg_ast_t *args = new(arg_ast_t, .value=binop.rhs, .next=new(arg_ast_t, .value=binop.lhs)); + if (is_valid_call(env, fn->args, args, true)) + return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, args), ")"); + } + } + } else if (ast->tag == Multiply && is_numeric_type(rhs_t)) { + b = get_namespace_binding(env, binop.lhs, "scaled_by"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (type_eq(fn->ret, lhs_t)) { + arg_ast_t *args = new(arg_ast_t, .value=binop.lhs, .next=new(arg_ast_t, .value=binop.rhs)); + if (is_valid_call(env, fn->args, args, true)) + return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, args), ")"); + } + } + } else if (ast->tag == Divide && is_numeric_type(rhs_t)) { + b = get_namespace_binding(env, binop.lhs, "divided_by"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (type_eq(fn->ret, lhs_t)) { + arg_ast_t *args = new(arg_ast_t, .value=binop.lhs, .next=new(arg_ast_t, .value=binop.rhs)); + if (is_valid_call(env, fn->args, args, true)) + return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, args), ")"); + } + } + } + if (ast->tag == Or && lhs_t->tag == OptionalType) { if (is_incomplete_type(rhs_t)) { type_t *complete = most_complete_type(rhs_t, Match(lhs_t, OptionalType)->type); diff --git a/src/typecheck.c b/src/typecheck.c index 8d4cc946..0291b748 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -1183,20 +1183,48 @@ type_t *get_type(env_t *env, ast_t *ast) } } + if (ast->tag == Multiply && is_numeric_type(lhs_t)) { + binding_t *b = get_namespace_binding(env, binop.rhs, "scaled_by"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (type_eq(fn->ret, rhs_t)) { + arg_ast_t *args = new(arg_ast_t, .value=binop.rhs, .next=new(arg_ast_t, .value=binop.lhs)); + if (is_valid_call(env, fn->args, args, true)) + return rhs_t; + } + } + } else if (ast->tag == Multiply && is_numeric_type(rhs_t)) { + binding_t *b = get_namespace_binding(env, binop.lhs, "scaled_by"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (type_eq(fn->ret, lhs_t)) { + arg_ast_t *args = new(arg_ast_t, .value=binop.lhs, .next=new(arg_ast_t, .value=binop.rhs)); + if (is_valid_call(env, fn->args, args, true)) + return lhs_t; + } + } + } else if (ast->tag == Divide && is_numeric_type(rhs_t)) { + binding_t *b = get_namespace_binding(env, binop.lhs, "divided_by"); + if (b && b->type->tag == FunctionType) { + auto fn = Match(b->type, FunctionType); + if (type_eq(fn->ret, lhs_t)) { + arg_ast_t *args = new(arg_ast_t, .value=binop.lhs, .next=new(arg_ast_t, .value=binop.rhs)); + if (is_valid_call(env, fn->args, args, true)) + return lhs_t; + } + } + } + type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL)); - if (ast->tag == Multiply || ast->tag == Divide) { - binding_t *b = is_numeric_type(lhs_t) ? get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, lhs_t) - : get_metamethod_binding(env, ast->tag, binop.rhs, binop.lhs, rhs_t); - if (b) return overall_t; - } else { - if (overall_t == NULL) - code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + if (overall_t == NULL) + code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); + if (b) return overall_t; - binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); - if (b) return overall_t; - } if (is_numeric_type(lhs_t) && is_numeric_type(rhs_t)) return overall_t; + code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); } case Concat: { -- cgit v1.2.3