aboutsummaryrefslogtreecommitdiff
path: root/compile.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-11-08 14:10:19 -0500
committerBruce Hill <bruce@bruce-hill.com>2024-11-08 14:10:19 -0500
commit5d35f286336878a3529dabdb3f7800b6f84712eb (patch)
treeee21c66d28027e84fd31080c145978fba18fec89 /compile.c
parent9c842201f312edd483ee99dcf3e321bdac2a7073 (diff)
Improve reductions so they work better nested and also have bespoke code
optimized for min/max and argmin/argmax.
Diffstat (limited to 'compile.c')
-rw-r--r--compile.c71
1 files changed, 54 insertions, 17 deletions
diff --git a/compile.c b/compile.c
index c6f23d4e..372cd91d 100644
--- a/compile.c
+++ b/compile.c
@@ -3248,11 +3248,17 @@ CORD compile(env_t *env, ast_t *ast)
}
case Reduction: {
auto reduction = Match(ast, Reduction);
- binop_e op = reduction->combination->tag == BinaryOp ? Match(reduction->combination, BinaryOp)->op : BINOP_UNKNOWN;
+ binop_e op = reduction->op;
type_t *iter_t = get_type(env, reduction->iter);
type_t *item_t = get_iterated_type(iter_t);
if (!item_t) code_err(reduction->iter, "I couldn't figure out how to iterate over this type: %T", iter_t);
+
+ static int64_t next_id = 1;
+ ast_t *item = FakeAST(Var, heap_strf("$it%ld", next_id++));
+ ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
+ ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
+ env_t *body_scope = for_scope(env, loop);
if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) {
// Chained comparisons like ==, <, etc.
CORD code = CORD_all(
@@ -3260,26 +3266,62 @@ CORD compile(env_t *env, ast_t *ast)
compile_declaration(item_t, "prev"), ";\n"
"OptionalBool_t result = NULL_BOOL;\n"
);
- env_t *scope = fresh_scope(env);
- set_binding(scope, "$reduction", new(binding_t, .type=item_t, .code="prev"));
- ast_t *item = FakeAST(Var, "$iter_value");
- ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
- ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
- env_t *body_scope = for_scope(scope, loop);
+ ast_t *comparison = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="prev", .type=item_t), .rhs=item);
body->__data.InlineCCode.code = CORD_all(
"if (result == NULL_BOOL) {\n"
" prev = ", compile(body_scope, item), ";\n"
" result = yes;\n"
"} else {\n"
- " if (", compile(body_scope, reduction->combination), ") {\n",
+ " if (", compile(body_scope, comparison), ") {\n",
" prev = ", compile(body_scope, item), ";\n",
" } else {\n"
" result = no;\n",
" break;\n",
" }\n",
"}\n");
- code = CORD_all(code, compile_statement(scope, loop), "\nresult;})");
+ code = CORD_all(code, compile_statement(env, loop), "\nresult;})");
+ return code;
+ } else if (op == BINOP_MIN || op == BINOP_MAX) {
+ // Min/max:
+ const char *superlative = op == BINOP_MIN ? "min" : "max";
+ CORD code = CORD_all(
+ "({ // Reduction:\n",
+ compile_declaration(item_t, superlative), ";\n"
+ "Bool_t has_value = no;\n"
+ );
+
+ CORD item_code = compile(body_scope, item);
+ binop_e cmp_op = op == BINOP_MIN ? BINOP_LT : BINOP_GT;
+ if (reduction->key) {
+ env_t *key_scope = fresh_scope(env);
+ set_binding(key_scope, "$", new(binding_t, .type=item_t, .code=item_code));
+ type_t *key_type = get_type(key_scope, reduction->key);
+ const char *superlative_key = op == BINOP_MIN ? "min_key" : "max_key";
+ code = CORD_all(code, compile_declaration(key_type, superlative_key), ";\n");
+
+ ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op,
+ .lhs=FakeAST(InlineCCode, .code="key", .type=key_type),
+ .rhs=FakeAST(InlineCCode, .code=superlative_key, .type=key_type));
+ body->__data.InlineCCode.code = CORD_all(
+ compile_declaration(key_type, "key"), " = ", compile(key_scope, reduction->key), ";\n",
+ "if (!has_value || ", compile(body_scope, comparison), ") {\n"
+ " ", superlative, " = ", compile(body_scope, item), ";\n"
+ " ", superlative_key, " = key;\n"
+ " has_value = yes;\n"
+ "}\n");
+ } else {
+ ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op, .lhs=item, .rhs=FakeAST(InlineCCode, .code=superlative, .type=item_t));
+ body->__data.InlineCCode.code = CORD_all(
+ "if (!has_value || ", compile(body_scope, comparison), ") {\n"
+ " ", superlative, " = ", compile(body_scope, item), ";\n"
+ " has_value = yes;\n"
+ "}\n");
+ }
+
+
+ code = CORD_all(code, compile_statement(env, loop), "\nhas_value ? ", promote_to_optional(item_t, superlative),
+ " : ", compile_null(item_t), ";})");
return code;
} else {
// Accumulator-style reductions like +, ++, *, etc.
@@ -3288,12 +3330,6 @@ CORD compile(env_t *env, ast_t *ast)
compile_declaration(item_t, "reduction"), ";\n"
"Bool_t has_value = no;\n"
);
- env_t *scope = fresh_scope(env);
- set_binding(scope, "$reduction", new(binding_t, .type=item_t, .code="reduction"));
- ast_t *item = FakeAST(Var, "$iter_value");
- ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
- ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
- env_t *body_scope = for_scope(scope, loop);
// For the special case of (or)/(and), we need to early out if we can:
CORD early_out = CORD_EMPTY;
@@ -3312,16 +3348,17 @@ CORD compile(env_t *env, ast_t *ast)
early_out = CORD_all("if (!", check_null(item_t, "reduction"), ") break;");
}
+ ast_t *combination = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="reduction", .type=item_t), .rhs=item);
body->__data.InlineCCode.code = CORD_all(
"if (!has_value) {\n"
" reduction = ", compile(body_scope, item), ";\n"
" has_value = yes;\n"
"} else {\n"
- " reduction = ", compile(body_scope, reduction->combination), ";\n",
+ " reduction = ", compile(body_scope, combination), ";\n",
early_out,
"}\n");
- code = CORD_all(code, compile_statement(scope, loop), "\nhas_value ? ", promote_to_optional(item_t, "reduction"),
+ code = CORD_all(code, compile_statement(env, loop), "\nhas_value ? ", promote_to_optional(item_t, "reduction"),
" : ", compile_null(item_t), ";})");
return code;
}