Improve reductions so they work better nested and also have bespoke code
optimized for min/max and argmin/argmax.
This commit is contained in:
parent
9c842201f3
commit
5d35f28633
4
ast.c
4
ast.c
@ -142,8 +142,8 @@ CORD ast_to_xml(ast_t *ast)
|
||||
T(Repeat, "<Repeat>%r</Repeat>", optional_tagged("body", data.body))
|
||||
T(If, "<If>%r%r%r</If>", optional_tagged("condition", data.condition), optional_tagged("body", data.body), optional_tagged("else", data.else_body))
|
||||
T(When, "<When><subject>%r</subject>%r%r</When>", ast_to_xml(data.subject), when_clauses_to_xml(data.clauses), optional_tagged("else", data.else_body))
|
||||
T(Reduction, "<Reduction>%r%r</Reduction>", optional_tagged("iterable", data.iter),
|
||||
optional_tagged("combination", data.combination))
|
||||
T(Reduction, "<Reduction op=%r%r>%r</Reduction>", xml_escape(OP_NAMES[data.op]), optional_tagged("key", data.key),
|
||||
optional_tagged("iterable", data.iter))
|
||||
T(Skip, "<Skip>%r</Skip>", data.target)
|
||||
T(Stop, "<Stop>%r</Stop>", data.target)
|
||||
T(PrintStatement, "<PrintStatement>%r</PrintStatement>", ast_list_to_xml(data.to_print))
|
||||
|
3
ast.h
3
ast.h
@ -265,7 +265,8 @@ struct ast_s {
|
||||
ast_t *else_body;
|
||||
} When;
|
||||
struct {
|
||||
ast_t *iter, *combination;
|
||||
ast_t *iter, *key;
|
||||
binop_e op;
|
||||
} Reduction;
|
||||
struct {
|
||||
const char *target;
|
||||
|
71
compile.c
71
compile.c
@ -3248,11 +3248,17 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
}
|
||||
case Reduction: {
|
||||
auto reduction = Match(ast, Reduction);
|
||||
binop_e op = reduction->combination->tag == BinaryOp ? Match(reduction->combination, BinaryOp)->op : BINOP_UNKNOWN;
|
||||
binop_e op = reduction->op;
|
||||
|
||||
type_t *iter_t = get_type(env, reduction->iter);
|
||||
type_t *item_t = get_iterated_type(iter_t);
|
||||
if (!item_t) code_err(reduction->iter, "I couldn't figure out how to iterate over this type: %T", iter_t);
|
||||
|
||||
static int64_t next_id = 1;
|
||||
ast_t *item = FakeAST(Var, heap_strf("$it%ld", next_id++));
|
||||
ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
|
||||
ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
|
||||
env_t *body_scope = for_scope(env, loop);
|
||||
if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) {
|
||||
// Chained comparisons like ==, <, etc.
|
||||
CORD code = CORD_all(
|
||||
@ -3260,26 +3266,62 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
compile_declaration(item_t, "prev"), ";\n"
|
||||
"OptionalBool_t result = NULL_BOOL;\n"
|
||||
);
|
||||
env_t *scope = fresh_scope(env);
|
||||
set_binding(scope, "$reduction", new(binding_t, .type=item_t, .code="prev"));
|
||||
ast_t *item = FakeAST(Var, "$iter_value");
|
||||
ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
|
||||
ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
|
||||
env_t *body_scope = for_scope(scope, loop);
|
||||
|
||||
ast_t *comparison = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="prev", .type=item_t), .rhs=item);
|
||||
body->__data.InlineCCode.code = CORD_all(
|
||||
"if (result == NULL_BOOL) {\n"
|
||||
" prev = ", compile(body_scope, item), ";\n"
|
||||
" result = yes;\n"
|
||||
"} else {\n"
|
||||
" if (", compile(body_scope, reduction->combination), ") {\n",
|
||||
" if (", compile(body_scope, comparison), ") {\n",
|
||||
" prev = ", compile(body_scope, item), ";\n",
|
||||
" } else {\n"
|
||||
" result = no;\n",
|
||||
" break;\n",
|
||||
" }\n",
|
||||
"}\n");
|
||||
code = CORD_all(code, compile_statement(scope, loop), "\nresult;})");
|
||||
code = CORD_all(code, compile_statement(env, loop), "\nresult;})");
|
||||
return code;
|
||||
} else if (op == BINOP_MIN || op == BINOP_MAX) {
|
||||
// Min/max:
|
||||
const char *superlative = op == BINOP_MIN ? "min" : "max";
|
||||
CORD code = CORD_all(
|
||||
"({ // Reduction:\n",
|
||||
compile_declaration(item_t, superlative), ";\n"
|
||||
"Bool_t has_value = no;\n"
|
||||
);
|
||||
|
||||
CORD item_code = compile(body_scope, item);
|
||||
binop_e cmp_op = op == BINOP_MIN ? BINOP_LT : BINOP_GT;
|
||||
if (reduction->key) {
|
||||
env_t *key_scope = fresh_scope(env);
|
||||
set_binding(key_scope, "$", new(binding_t, .type=item_t, .code=item_code));
|
||||
type_t *key_type = get_type(key_scope, reduction->key);
|
||||
const char *superlative_key = op == BINOP_MIN ? "min_key" : "max_key";
|
||||
code = CORD_all(code, compile_declaration(key_type, superlative_key), ";\n");
|
||||
|
||||
ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op,
|
||||
.lhs=FakeAST(InlineCCode, .code="key", .type=key_type),
|
||||
.rhs=FakeAST(InlineCCode, .code=superlative_key, .type=key_type));
|
||||
body->__data.InlineCCode.code = CORD_all(
|
||||
compile_declaration(key_type, "key"), " = ", compile(key_scope, reduction->key), ";\n",
|
||||
"if (!has_value || ", compile(body_scope, comparison), ") {\n"
|
||||
" ", superlative, " = ", compile(body_scope, item), ";\n"
|
||||
" ", superlative_key, " = key;\n"
|
||||
" has_value = yes;\n"
|
||||
"}\n");
|
||||
} else {
|
||||
ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op, .lhs=item, .rhs=FakeAST(InlineCCode, .code=superlative, .type=item_t));
|
||||
body->__data.InlineCCode.code = CORD_all(
|
||||
"if (!has_value || ", compile(body_scope, comparison), ") {\n"
|
||||
" ", superlative, " = ", compile(body_scope, item), ";\n"
|
||||
" has_value = yes;\n"
|
||||
"}\n");
|
||||
}
|
||||
|
||||
|
||||
code = CORD_all(code, compile_statement(env, loop), "\nhas_value ? ", promote_to_optional(item_t, superlative),
|
||||
" : ", compile_null(item_t), ";})");
|
||||
return code;
|
||||
} else {
|
||||
// Accumulator-style reductions like +, ++, *, etc.
|
||||
@ -3288,12 +3330,6 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
compile_declaration(item_t, "reduction"), ";\n"
|
||||
"Bool_t has_value = no;\n"
|
||||
);
|
||||
env_t *scope = fresh_scope(env);
|
||||
set_binding(scope, "$reduction", new(binding_t, .type=item_t, .code="reduction"));
|
||||
ast_t *item = FakeAST(Var, "$iter_value");
|
||||
ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
|
||||
ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
|
||||
env_t *body_scope = for_scope(scope, loop);
|
||||
|
||||
// For the special case of (or)/(and), we need to early out if we can:
|
||||
CORD early_out = CORD_EMPTY;
|
||||
@ -3312,16 +3348,17 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
early_out = CORD_all("if (!", check_null(item_t, "reduction"), ") break;");
|
||||
}
|
||||
|
||||
ast_t *combination = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="reduction", .type=item_t), .rhs=item);
|
||||
body->__data.InlineCCode.code = CORD_all(
|
||||
"if (!has_value) {\n"
|
||||
" reduction = ", compile(body_scope, item), ";\n"
|
||||
" has_value = yes;\n"
|
||||
"} else {\n"
|
||||
" reduction = ", compile(body_scope, reduction->combination), ";\n",
|
||||
" reduction = ", compile(body_scope, combination), ";\n",
|
||||
early_out,
|
||||
"}\n");
|
||||
|
||||
code = CORD_all(code, compile_statement(scope, loop), "\nhas_value ? ", promote_to_optional(item_t, "reduction"),
|
||||
code = CORD_all(code, compile_statement(env, loop), "\nhas_value ? ", promote_to_optional(item_t, "reduction"),
|
||||
" : ", compile_null(item_t), ";})");
|
||||
return code;
|
||||
}
|
||||
|
14
parse.c
14
parse.c
@ -954,15 +954,12 @@ PARSER(parse_reduction) {
|
||||
if (!match(&pos, "(")) return NULL;
|
||||
|
||||
whitespace(&pos);
|
||||
const char *combo_start = pos;
|
||||
binop_e op = match_binary_operator(&pos);
|
||||
if (op == BINOP_UNKNOWN) return NULL;
|
||||
|
||||
ast_t *combination;
|
||||
ast_t *lhs = NewAST(ctx->file, pos, pos, Var, .name="$reduction");
|
||||
ast_t *rhs = NewAST(ctx->file, pos, pos, Var, .name="$iter_value");
|
||||
ast_t *key = NULL;
|
||||
if (op == BINOP_MIN || op == BINOP_MAX) {
|
||||
ast_t *key = NewAST(ctx->file, pos, pos, Var, .name="$");
|
||||
key = NewAST(ctx->file, pos, pos, Var, .name="$");
|
||||
for (bool progress = true; progress; ) {
|
||||
ast_t *new_term;
|
||||
progress = (false
|
||||
@ -977,11 +974,6 @@ PARSER(parse_reduction) {
|
||||
}
|
||||
if (key->tag == Var) key = NULL;
|
||||
else pos = key->end;
|
||||
combination = op == BINOP_MIN ?
|
||||
NewAST(ctx->file, combo_start, pos, Min, .lhs=lhs, .rhs=rhs, .key=key)
|
||||
: NewAST(ctx->file, combo_start, pos, Max, .lhs=lhs, .rhs=rhs, .key=key);
|
||||
} else {
|
||||
combination = NewAST(ctx->file, combo_start, pos, BinaryOp, .op=op, .lhs=lhs, .rhs=rhs);
|
||||
}
|
||||
|
||||
whitespace(&pos);
|
||||
@ -999,7 +991,7 @@ PARSER(parse_reduction) {
|
||||
whitespace(&pos);
|
||||
expect_closing(ctx, &pos, ")", "I wasn't able to parse the rest of this reduction");
|
||||
|
||||
return NewAST(ctx->file, start, pos, Reduction, .iter=iter, .combination=combination);
|
||||
return NewAST(ctx->file, start, pos, Reduction, .iter=iter, .op=op, .key=key);
|
||||
}
|
||||
|
||||
ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs) {
|
||||
|
@ -1019,11 +1019,9 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
auto reduction = Match(ast, Reduction);
|
||||
type_t *iter_t = get_type(env, reduction->iter);
|
||||
|
||||
if (reduction->combination && reduction->combination->tag == BinaryOp) {
|
||||
binop_e op = Match(reduction->combination, BinaryOp)->op;
|
||||
if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE)
|
||||
return Type(OptionalType, .type=Type(BoolType));
|
||||
}
|
||||
if (reduction->op == BINOP_EQ || reduction->op == BINOP_NE || reduction->op == BINOP_LT
|
||||
|| reduction->op == BINOP_LE || reduction->op == BINOP_GT || reduction->op == BINOP_GE)
|
||||
return Type(OptionalType, .type=Type(BoolType));
|
||||
|
||||
type_t *iterated = get_iterated_type(iter_t);
|
||||
if (!iterated)
|
||||
|
Loading…
Reference in New Issue
Block a user