diff --git a/ast.c b/ast.c index d9ca65a..8626545 100644 --- a/ast.c +++ b/ast.c @@ -142,8 +142,8 @@ CORD ast_to_xml(ast_t *ast) T(Repeat, "%r", optional_tagged("body", data.body)) T(If, "%r%r%r", optional_tagged("condition", data.condition), optional_tagged("body", data.body), optional_tagged("else", data.else_body)) T(When, "%r%r%r", ast_to_xml(data.subject), when_clauses_to_xml(data.clauses), optional_tagged("else", data.else_body)) - T(Reduction, "%r%r", optional_tagged("iterable", data.iter), - optional_tagged("combination", data.combination)) + T(Reduction, "%r", xml_escape(OP_NAMES[data.op]), optional_tagged("key", data.key), + optional_tagged("iterable", data.iter)) T(Skip, "%r", data.target) T(Stop, "%r", data.target) T(PrintStatement, "%r", ast_list_to_xml(data.to_print)) diff --git a/ast.h b/ast.h index 234fa9a..3911f2a 100644 --- a/ast.h +++ b/ast.h @@ -265,7 +265,8 @@ struct ast_s { ast_t *else_body; } When; struct { - ast_t *iter, *combination; + ast_t *iter, *key; + binop_e op; } Reduction; struct { const char *target; diff --git a/compile.c b/compile.c index c6f23d4..372cd91 100644 --- a/compile.c +++ b/compile.c @@ -3248,11 +3248,17 @@ CORD compile(env_t *env, ast_t *ast) } case Reduction: { auto reduction = Match(ast, Reduction); - binop_e op = reduction->combination->tag == BinaryOp ? Match(reduction->combination, BinaryOp)->op : BINOP_UNKNOWN; + binop_e op = reduction->op; type_t *iter_t = get_type(env, reduction->iter); type_t *item_t = get_iterated_type(iter_t); if (!item_t) code_err(reduction->iter, "I couldn't figure out how to iterate over this type: %T", iter_t); + + static int64_t next_id = 1; + ast_t *item = FakeAST(Var, heap_strf("$it%ld", next_id++)); + ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder + ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body); + env_t *body_scope = for_scope(env, loop); if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) { // Chained comparisons like ==, <, etc. CORD code = CORD_all( @@ -3260,26 +3266,62 @@ CORD compile(env_t *env, ast_t *ast) compile_declaration(item_t, "prev"), ";\n" "OptionalBool_t result = NULL_BOOL;\n" ); - env_t *scope = fresh_scope(env); - set_binding(scope, "$reduction", new(binding_t, .type=item_t, .code="prev")); - ast_t *item = FakeAST(Var, "$iter_value"); - ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder - ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body); - env_t *body_scope = for_scope(scope, loop); + ast_t *comparison = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="prev", .type=item_t), .rhs=item); body->__data.InlineCCode.code = CORD_all( "if (result == NULL_BOOL) {\n" " prev = ", compile(body_scope, item), ";\n" " result = yes;\n" "} else {\n" - " if (", compile(body_scope, reduction->combination), ") {\n", + " if (", compile(body_scope, comparison), ") {\n", " prev = ", compile(body_scope, item), ";\n", " } else {\n" " result = no;\n", " break;\n", " }\n", "}\n"); - code = CORD_all(code, compile_statement(scope, loop), "\nresult;})"); + code = CORD_all(code, compile_statement(env, loop), "\nresult;})"); + return code; + } else if (op == BINOP_MIN || op == BINOP_MAX) { + // Min/max: + const char *superlative = op == BINOP_MIN ? "min" : "max"; + CORD code = CORD_all( + "({ // Reduction:\n", + compile_declaration(item_t, superlative), ";\n" + "Bool_t has_value = no;\n" + ); + + CORD item_code = compile(body_scope, item); + binop_e cmp_op = op == BINOP_MIN ? BINOP_LT : BINOP_GT; + if (reduction->key) { + env_t *key_scope = fresh_scope(env); + set_binding(key_scope, "$", new(binding_t, .type=item_t, .code=item_code)); + type_t *key_type = get_type(key_scope, reduction->key); + const char *superlative_key = op == BINOP_MIN ? "min_key" : "max_key"; + code = CORD_all(code, compile_declaration(key_type, superlative_key), ";\n"); + + ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op, + .lhs=FakeAST(InlineCCode, .code="key", .type=key_type), + .rhs=FakeAST(InlineCCode, .code=superlative_key, .type=key_type)); + body->__data.InlineCCode.code = CORD_all( + compile_declaration(key_type, "key"), " = ", compile(key_scope, reduction->key), ";\n", + "if (!has_value || ", compile(body_scope, comparison), ") {\n" + " ", superlative, " = ", compile(body_scope, item), ";\n" + " ", superlative_key, " = key;\n" + " has_value = yes;\n" + "}\n"); + } else { + ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op, .lhs=item, .rhs=FakeAST(InlineCCode, .code=superlative, .type=item_t)); + body->__data.InlineCCode.code = CORD_all( + "if (!has_value || ", compile(body_scope, comparison), ") {\n" + " ", superlative, " = ", compile(body_scope, item), ";\n" + " has_value = yes;\n" + "}\n"); + } + + + code = CORD_all(code, compile_statement(env, loop), "\nhas_value ? ", promote_to_optional(item_t, superlative), + " : ", compile_null(item_t), ";})"); return code; } else { // Accumulator-style reductions like +, ++, *, etc. @@ -3288,12 +3330,6 @@ CORD compile(env_t *env, ast_t *ast) compile_declaration(item_t, "reduction"), ";\n" "Bool_t has_value = no;\n" ); - env_t *scope = fresh_scope(env); - set_binding(scope, "$reduction", new(binding_t, .type=item_t, .code="reduction")); - ast_t *item = FakeAST(Var, "$iter_value"); - ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder - ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body); - env_t *body_scope = for_scope(scope, loop); // For the special case of (or)/(and), we need to early out if we can: CORD early_out = CORD_EMPTY; @@ -3312,16 +3348,17 @@ CORD compile(env_t *env, ast_t *ast) early_out = CORD_all("if (!", check_null(item_t, "reduction"), ") break;"); } + ast_t *combination = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="reduction", .type=item_t), .rhs=item); body->__data.InlineCCode.code = CORD_all( "if (!has_value) {\n" " reduction = ", compile(body_scope, item), ";\n" " has_value = yes;\n" "} else {\n" - " reduction = ", compile(body_scope, reduction->combination), ";\n", + " reduction = ", compile(body_scope, combination), ";\n", early_out, "}\n"); - code = CORD_all(code, compile_statement(scope, loop), "\nhas_value ? ", promote_to_optional(item_t, "reduction"), + code = CORD_all(code, compile_statement(env, loop), "\nhas_value ? ", promote_to_optional(item_t, "reduction"), " : ", compile_null(item_t), ";})"); return code; } diff --git a/parse.c b/parse.c index 895b286..3903dd1 100644 --- a/parse.c +++ b/parse.c @@ -954,15 +954,12 @@ PARSER(parse_reduction) { if (!match(&pos, "(")) return NULL; whitespace(&pos); - const char *combo_start = pos; binop_e op = match_binary_operator(&pos); if (op == BINOP_UNKNOWN) return NULL; - ast_t *combination; - ast_t *lhs = NewAST(ctx->file, pos, pos, Var, .name="$reduction"); - ast_t *rhs = NewAST(ctx->file, pos, pos, Var, .name="$iter_value"); + ast_t *key = NULL; if (op == BINOP_MIN || op == BINOP_MAX) { - ast_t *key = NewAST(ctx->file, pos, pos, Var, .name="$"); + key = NewAST(ctx->file, pos, pos, Var, .name="$"); for (bool progress = true; progress; ) { ast_t *new_term; progress = (false @@ -977,11 +974,6 @@ PARSER(parse_reduction) { } if (key->tag == Var) key = NULL; else pos = key->end; - combination = op == BINOP_MIN ? - NewAST(ctx->file, combo_start, pos, Min, .lhs=lhs, .rhs=rhs, .key=key) - : NewAST(ctx->file, combo_start, pos, Max, .lhs=lhs, .rhs=rhs, .key=key); - } else { - combination = NewAST(ctx->file, combo_start, pos, BinaryOp, .op=op, .lhs=lhs, .rhs=rhs); } whitespace(&pos); @@ -999,7 +991,7 @@ PARSER(parse_reduction) { whitespace(&pos); expect_closing(ctx, &pos, ")", "I wasn't able to parse the rest of this reduction"); - return NewAST(ctx->file, start, pos, Reduction, .iter=iter, .combination=combination); + return NewAST(ctx->file, start, pos, Reduction, .iter=iter, .op=op, .key=key); } ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs) { diff --git a/typecheck.c b/typecheck.c index 49eba2c..a000393 100644 --- a/typecheck.c +++ b/typecheck.c @@ -1019,11 +1019,9 @@ type_t *get_type(env_t *env, ast_t *ast) auto reduction = Match(ast, Reduction); type_t *iter_t = get_type(env, reduction->iter); - if (reduction->combination && reduction->combination->tag == BinaryOp) { - binop_e op = Match(reduction->combination, BinaryOp)->op; - if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) - return Type(OptionalType, .type=Type(BoolType)); - } + if (reduction->op == BINOP_EQ || reduction->op == BINOP_NE || reduction->op == BINOP_LT + || reduction->op == BINOP_LE || reduction->op == BINOP_GT || reduction->op == BINOP_GE) + return Type(OptionalType, .type=Type(BoolType)); type_t *iterated = get_iterated_type(iter_t); if (!iterated)