1 // This file defines how to compile reductions like `(+: nums)`
5 #include "../environment.h"
6 #include "../stdlib/text.h"
7 #include "../stdlib/util.h"
8 #include "../typecheck.h"
9 #include "compilation.h"
12 Text_t compile_reduction(env_t *env, ast_t *ast) {
13 DeclareMatch(reduction, ast, Reduction);
14 ast_e op = reduction->op;
15 const char *op_str = binop_info[op].operator;
17 type_t *iter_t = get_type(env, reduction->iter);
18 type_t *item_t = get_iterated_type(iter_t);
20 code_err(reduction->iter, "I couldn't figure out how to iterate over this type: ", type_to_text(iter_t));
22 static int64_t next_id = 1;
23 ast_t *item = FakeAST(Var, String("$it", next_id++));
24 ast_t *body = LiteralCode(Text("{}")); // placeholder
25 ast_t *loop = FakeAST(For, .vars = new (ast_list_t, .ast = item), .iter = reduction->iter, .body = body);
26 env_t *body_scope = for_scope(env, loop);
27 if (op == Equals || op == NotEquals || op == LessThan || op == LessThanOrEquals || op == GreaterThan
28 || op == GreaterThanOrEquals) {
29 // Chained comparisons like ==, <, etc.
30 type_t *item_value_type = item_t;
31 ast_t *item_value = item;
33 set_binding(body_scope, op_str, item_t, compile(body_scope, item));
34 item_value = reduction->key;
35 item_value_type = get_type(body_scope, reduction->key);
38 Text_t code = Texts("({ // Reduction:\n", compile_declaration(item_value_type, Text("prev")),
40 "OptionalBool_t result = NONE_BOOL;\n");
43 new (ast_t, .file = ast->file, .start = ast->start, .end = ast->end, .tag = op,
44 .__data.Plus.lhs = LiteralCode(Text("prev"), .type = item_value_type), .__data.Plus.rhs = item_value);
45 body->__data.InlineCCode.chunks = new (
46 ast_list_t, .ast = FakeAST(TextLiteral, Texts("if (result == NONE_BOOL) {\n"
48 compile(body_scope, item_value),
53 compile(body_scope, comparison), ") {\n",
54 " prev = ", compile(body_scope, item_value), ";\n",
57 " break;\n", " }\n", "}\n")));
58 code = Texts(code, compile_statement(env, loop), "\nresult;})");
60 } else if (op == Min || op == Max) {
62 Text_t superlative = op == Min ? Text("min") : Text("max");
63 Text_t code = Texts("({ // Reduction:\n", compile_declaration(item_t, superlative),
65 "Bool_t has_value = no;\n");
67 Text_t item_code = compile(body_scope, item);
68 ast_e cmp_op = op == Min ? LessThan : GreaterThan;
70 env_t *key_scope = fresh_scope(env);
71 set_binding(key_scope, op_str, item_t, item_code);
72 type_t *key_type = get_type(key_scope, reduction->key);
73 Text_t superlative_key = op == Min ? Text("min_key") : Text("max_key");
74 code = Texts(code, compile_declaration(key_type, superlative_key), ";\n");
76 ast_t *comparison = new (ast_t, .file = ast->file, .start = ast->start, .end = ast->end, .tag = cmp_op,
77 .__data.Plus.lhs = LiteralCode(Text("key"), .type = key_type),
78 .__data.Plus.rhs = LiteralCode(superlative_key, .type = key_type));
80 body->__data.InlineCCode.chunks = new (
81 ast_list_t, .ast = FakeAST(TextLiteral, Texts(compile_declaration(key_type, Text("key")), " = ",
82 compile(key_scope, reduction->key), ";\n",
83 "if (!has_value || ", compile(body_scope, comparison),
86 superlative, " = ", compile(body_scope, item),
95 new (ast_t, .file = ast->file, .start = ast->start, .end = ast->end, .tag = cmp_op,
96 .__data.Plus.lhs = item, .__data.Plus.rhs = LiteralCode(superlative, .type = item_t));
97 body->__data.InlineCCode.chunks = new (
98 ast_list_t, .ast = FakeAST(TextLiteral, Texts("if (!has_value || ", compile(body_scope, comparison),
101 superlative, " = ", compile(body_scope, item),
103 " has_value = yes;\n"
107 code = Texts(code, compile_statement(env, loop), "\nhas_value ? ", promote_to_optional(item_t, superlative),
108 " : ", compile_none(item_t), ";})");
111 // Accumulator-style reductions like +, ++, *, etc.
112 type_t *reduction_type = Match(get_type(env, ast), OptionalType)->type;
113 ast_t *item_value = item;
114 if (reduction->key) {
115 set_binding(body_scope, op_str, item_t, compile(body_scope, item));
116 item_value = reduction->key;
119 Text_t code = Texts("({ // Reduction:\n", compile_declaration(reduction_type, Text("reduction")),
121 "Bool_t has_value = no;\n");
123 // For the special case of (or)/(and), we need to early out if we
125 Text_t early_out = EMPTY_TEXT;
127 if (reduction_type->tag != IntType || Match(reduction_type, IntType)->bits != TYPE_IBITS32)
128 code_err(ast, "<> reductions are only supported for Int32 "
130 } else if (op == And) {
131 if (reduction_type->tag == BoolType) early_out = Text("if (!reduction) break;");
132 else if (reduction_type->tag == OptionalType)
133 early_out = Texts("if (", check_none(reduction_type, Text("reduction")), ") break;");
134 } else if (op == Or) {
135 if (reduction_type->tag == BoolType) early_out = Text("if (reduction) break;");
136 else if (reduction_type->tag == OptionalType)
137 early_out = Texts("if (!", check_none(reduction_type, Text("reduction")), ") break;");
140 ast_t *combination = new (ast_t, .file = ast->file, .start = ast->start, .end = ast->end, .tag = op,
141 .__data.Plus.lhs = LiteralCode(Text("reduction"), .type = reduction_type),
142 .__data.Plus.rhs = item_value);
143 body->__data.InlineCCode.chunks = new (
144 ast_list_t, .ast = FakeAST(TextLiteral, Texts("if (!has_value) {\n"
146 compile(body_scope, item_value),
148 " has_value = yes;\n"
151 compile(body_scope, combination), ";\n", early_out, "}\n")));
154 Texts(code, compile_statement(env, loop), "\nhas_value ? ",
155 promote_to_optional(reduction_type, Text("reduction")), " : ", compile_none(reduction_type), ";})");