From 92a593b80fe935eb21615dc45b4d7868b254bec6 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sat, 2 Nov 2024 22:34:35 -0400 Subject: [PATCH] Support reductions for comparison operators like == and < --- compile.c | 115 +++++++++++++++++++++++++++++---------------- test/reductions.tm | 17 +++++-- typecheck.c | 6 +++ types.c | 20 ++++++++ types.h | 1 + 5 files changed, 114 insertions(+), 45 deletions(-) diff --git a/compile.c b/compile.c index ff3cdf7..d947fe2 100644 --- a/compile.c +++ b/compile.c @@ -3153,49 +3153,82 @@ CORD compile(env_t *env, ast_t *ast) auto reduction = Match(ast, Reduction); type_t *optional_t = get_type(env, ast); type_t *t = Match(optional_t, OptionalType)->type; - CORD code = CORD_all( - "({ // Reduction:\n", - compile_declaration(t, "reduction"), ";\n" - "Bool_t has_value = no;\n" - ); - env_t *scope = fresh_scope(env); - set_binding(scope, "$reduction", new(binding_t, .type=t, .code="reduction")); - ast_t *item = FakeAST(Var, "$iter_value"); - ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder - ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body); - env_t *body_scope = for_scope(scope, loop); + binop_e op = reduction->combination->tag == BinaryOp ? Match(reduction->combination, BinaryOp)->op : BINOP_UNKNOWN; - // For the special case of (or)/(and), we need to early out if we can: - CORD early_out = CORD_EMPTY; - if (reduction->combination->tag == BinaryOp) { - auto binop = Match(reduction->combination, BinaryOp); - if (t->tag != BoolType && (binop->op == BINOP_EQ || binop->op == BINOP_NE - || binop->op == BINOP_LT || binop->op == BINOP_LE - || binop->op == BINOP_GT || binop->op == BINOP_GE)) - code_err(ast, "Reductions are not supported for this type of infix operator"); - else if ((t->tag != IntType || Match(t, IntType)->bits != TYPE_IBITS32) && binop->op == BINOP_CMP) - code_err(ast, "<> reductions are only supported for Int32 values"); - else if (t->tag == BoolType && binop->op == BINOP_AND) - early_out = "if (!reduction) break;"; - else if (t->tag == BoolType && binop->op == BINOP_OR) - early_out = "if (reduction) break;"; - else if (t->tag == OptionalType && binop->op == BINOP_AND) - early_out = CORD_all("if (", check_null(t, "reduction"), ") break;"); - else if (t->tag == OptionalType && binop->op == BINOP_OR) - early_out = CORD_all("if (!", check_null(t, "reduction"), ") break;"); + if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) { + // Chained comparisons like ==, <, etc. + type_t *iter_t = get_type(env, reduction->iter); + type_t *item_t = get_iterated_type(iter_t); + if (!item_t) code_err(reduction->iter, "I couldn't figure out how to iterate over this type: %T", iter_t); + CORD code = CORD_all( + "({ // Reduction:\n", + compile_declaration(item_t, "prev"), ";\n" + "OptionalBool_t result = NULL_BOOL;\n" + ); + env_t *scope = fresh_scope(env); + set_binding(scope, "$reduction", new(binding_t, .type=item_t, .code="prev")); + ast_t *item = FakeAST(Var, "$iter_value"); + ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder + ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body); + env_t *body_scope = for_scope(scope, loop); + + body->__data.InlineCCode.code = CORD_all( + "if (result == NULL_BOOL) {\n" + " prev = ", compile(body_scope, item), ";\n" + " result = yes;\n" + "} else {\n" + " if (", compile(body_scope, reduction->combination), ") {\n", + " prev = ", compile(body_scope, item), ";\n", + " } else {\n" + " result = no;\n", + " break;\n", + " }\n", + "}\n"); + code = CORD_all(code, compile_statement(scope, loop), "\nresult;})"); + return code; + } else { + // Accumulator-style reductions like +, ++, *, etc. + CORD code = CORD_all( + "({ // Reduction:\n", + compile_declaration(t, "reduction"), ";\n" + "Bool_t has_value = no;\n" + ); + env_t *scope = fresh_scope(env); + set_binding(scope, "$reduction", new(binding_t, .type=t, .code="reduction")); + ast_t *item = FakeAST(Var, "$iter_value"); + ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder + ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body); + env_t *body_scope = for_scope(scope, loop); + + // For the special case of (or)/(and), we need to early out if we can: + CORD early_out = CORD_EMPTY; + if (op == BINOP_CMP) { + if (t->tag != IntType || Match(t, IntType)->bits != TYPE_IBITS32) + code_err(ast, "<> reductions are only supported for Int32 values"); + } else if (op == BINOP_AND) { + if (t->tag == BoolType) + early_out = "if (!reduction) break;"; + else if (t->tag == OptionalType) + early_out = CORD_all("if (", check_null(t, "reduction"), ") break;"); + } else if (op == BINOP_OR) { + if (t->tag == BoolType) + early_out = "if (reduction) break;"; + else if (t->tag == OptionalType) + early_out = CORD_all("if (!", check_null(t, "reduction"), ") break;"); + } + + body->__data.InlineCCode.code = CORD_all( + "if (!has_value) {\n" + " reduction = ", compile(body_scope, item), ";\n" + " has_value = yes;\n" + "} else {\n" + " reduction = ", compile(body_scope, reduction->combination), ";\n", + early_out, + "}\n"); + + code = CORD_all(code, compile_statement(scope, loop), "\nhas_value ? ", promote_to_optional(t, "reduction"), " : ", compile_null(t), ";})"); + return code; } - - body->__data.InlineCCode.code = CORD_all( - "if (!has_value) {\n" - " reduction = ", compile(body_scope, item), ";\n" - " has_value = yes;\n" - "} else {\n" - " reduction = ", compile(body_scope, reduction->combination), ";\n", - early_out, - "}\n"); - - code = CORD_all(code, compile_statement(scope, loop), "\nhas_value ? ", promote_to_optional(t, "reduction"), " : ", compile_null(t), ";})"); - return code; } case FieldAccess: { auto f = Match(ast, FieldAccess); diff --git a/test/reductions.tm b/test/reductions.tm index 7bfe212..a844081 100644 --- a/test/reductions.tm +++ b/test/reductions.tm @@ -27,8 +27,17 @@ func main(): = Foo(x=0, y=-999) !! (or) and (and) have early out behavior: - >> (or: i == 3 for i in 9999999999999999999999999999)! - = yes + >> (or: i == 3 for i in 9999999999999999999999999999)! + = yes - >> (and: i < 10 for i in 9999999999999999999999999999)! - = no + >> (and: i < 10 for i in 9999999999999999999999999999)! + = no + + >> (<=: [1, 2, 2, 3, 4])! + = yes + + >> (<=: [:Int]) + = !Bool + + >> (<=: [5, 4, 3, 2, 1])! + = no diff --git a/typecheck.c b/typecheck.c index bc3c798..bff7d08 100644 --- a/typecheck.c +++ b/typecheck.c @@ -1033,6 +1033,12 @@ type_t *get_type(env_t *env, ast_t *ast) auto reduction = Match(ast, Reduction); type_t *iter_t = get_type(env, reduction->iter); + if (reduction->combination && reduction->combination->tag == BinaryOp) { + binop_e op = Match(reduction->combination, BinaryOp)->op; + if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) + return Type(OptionalType, .type=Type(BoolType)); + } + type_t *value_t; type_t *iter_value_t = value_type(iter_t); switch (iter_value_t->tag) { diff --git a/types.c b/types.c index bcb0e9f..ddb3f07 100644 --- a/types.c +++ b/types.c @@ -626,4 +626,24 @@ type_t *get_field_type(type_t *t, const char *field_name) } } +PUREFUNC type_t *get_iterated_type(type_t *t) +{ + type_t *iter_value_t = value_type(t); + switch (iter_value_t->tag) { + case BigIntType: case IntType: return iter_value_t; break; + case ArrayType: return Match(iter_value_t, ArrayType)->item_type; break; + case SetType: return Match(iter_value_t, SetType)->item_type; break; + case TableType: return NULL; + case FunctionType: case ClosureType: { + // Iterator function + auto fn = iter_value_t->tag == ClosureType ? + Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType); + if (fn->args || fn->ret->tag != OptionalType) + return NULL; + return Match(fn->ret, OptionalType)->type; + } + default: return NULL; + } +} + // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/types.h b/types.h index 96e0b6c..c983dc2 100644 --- a/types.h +++ b/types.h @@ -153,5 +153,6 @@ PUREFUNC size_t type_size(type_t *t); PUREFUNC size_t type_align(type_t *t); PUREFUNC size_t unpadded_struct_size(type_t *t); type_t *get_field_type(type_t *t, const char *field_name); +PUREFUNC type_t *get_iterated_type(type_t *t); // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0