aboutsummaryrefslogtreecommitdiff
path: root/src/compile/reductions.c
blob: 1652384c80ac16237af6eb0720c643ccf560aefd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// This file defines how to compile reductions like `(+: nums)`

#include "../ast.h"
#include "../config.h"
#include "../environment.h"
#include "../stdlib/text.h"
#include "../stdlib/util.h"
#include "../typecheck.h"
#include "compilation.h"

public
// Compile a reduction expression such as `(+: nums)` or `(<=: nums)` into a
// C statement-expression `({ ... })`.
//
// Every reduction is lowered to a `for` loop over `reduction->iter` whose body
// is inline C code. Three strategies are used, depending on `reduction->op`:
//   - Chained comparisons (==, !=, <, <=, >, >=): the result is an
//     OptionalBool_t that is NONE_BOOL when the iterable is empty, and
//     otherwise `yes` iff every adjacent pair of values satisfies the
//     comparison. The loop `break`s at the first failing pair.
//   - Min/Max: the result is an optional of the item type holding the
//     smallest/largest item, or none when the iterable is empty. Because the
//     comparison is strict, ties keep the earliest item.
//   - Anything else: items are folded with the binary operator into an
//     accumulator. The result is an optional of the reduction type, none when
//     the iterable is empty. `<>` is only accepted for Int32 values, and
//     `and`/`or` on Bool or optional values stop as soon as the result is
//     decided.
//
// When `reduction->key` is present, the key expression is compiled in a scope
// where the name returned by `binop_operator(op)` is bound to the current item.
// The key's value is then used for comparison/accumulation instead of the item.
//
// Reports a compile error (via code_err) if the iterable's type cannot be
// iterated over, or for `<>` reductions on non-Int32 values.
Text_t compile_reduction(env_t *env, ast_t *ast) {
    DeclareMatch(reduction, ast, Reduction);
    ast_e op = reduction->op;
    const char *op_str = binop_operator(op);

    type_t *iter_t = get_type(env, reduction->iter);
    type_t *item_t = get_iterated_type(iter_t);
    if (!item_t)
        code_err(reduction->iter, "I couldn't figure out how to iterate over this type: ", type_to_str(iter_t));

    // The loop body starts as a placeholder. Its chunks are filled in by each
    // branch below, once the loop variable's compiled form is known via
    // `body_scope`. The counter gives each reduction its own loop-variable name
    // (presumably so nested reductions don't collide).
    static int64_t next_id = 1;
    ast_t *item = FakeAST(Var, String("$it", next_id++));
    ast_t *body = LiteralCode(Text("{}")); // placeholder
    ast_t *loop = FakeAST(For, .vars = new (ast_list_t, .ast = item), .iter = reduction->iter, .body = body);
    env_t *body_scope = for_scope(env, loop);
    if (op == Equals || op == NotEquals || op == LessThan || op == LessThanOrEquals || op == GreaterThan
        || op == GreaterThanOrEquals) {
        // Chained comparisons like ==, <, etc.: every adjacent pair must satisfy `op`.
        type_t *item_value_type = item_t;
        ast_t *item_value = item;
        if (reduction->key) {
            set_binding(body_scope, op_str, item_t, compile(body_scope, item));
            item_value = reduction->key;
            item_value_type = get_type(body_scope, reduction->key);
        }

        Text_t code = Texts("({ // Reduction:\n", compile_declaration(item_value_type, Text("prev")),
                            ";\n"
                            "OptionalBool_t result = NONE_BOOL;\n");

        // Compare `prev` against a temporary `cur` instead of re-emitting the
        // value expression. This way the item/key expression is evaluated
        // exactly once per iteration, even when the comparison succeeds and
        // the value is carried forward as the new `prev`.
        ast_t *comparison =
            new (ast_t, .file = ast->file, .start = ast->start, .end = ast->end, .tag = op,
                 .__data.Plus.lhs = LiteralCode(Text("prev"), .type = item_value_type),
                 .__data.Plus.rhs = LiteralCode(Text("cur"), .type = item_value_type));
        body->__data.InlineCCode.chunks = new (
            ast_list_t, .ast = FakeAST(TextLiteral, Texts("if (result == NONE_BOOL) {\n"
                                                          "    prev = ",
                                                          compile(body_scope, item_value),
                                                          ";\n"
                                                          "    result = yes;\n"
                                                          "} else {\n"
                                                          "    ",
                                                          compile_declaration(item_value_type, Text("cur")), " = ",
                                                          compile(body_scope, item_value),
                                                          ";\n"
                                                          "    if (",
                                                          compile(body_scope, comparison),
                                                          ") {\n"
                                                          "        prev = cur;\n"
                                                          "    } else {\n"
                                                          "        result = no;\n"
                                                          "        break;\n"
                                                          "    }\n"
                                                          "}\n")));
        code = Texts(code, compile_statement(env, loop), "\nresult;})");
        return code;
    } else if (op == Min || op == Max) {
        // Min/max: track the best item seen so far. The strict comparison
        // means ties keep the earliest item.
        Text_t superlative = op == Min ? Text("min") : Text("max");
        Text_t code = Texts("({ // Reduction:\n", compile_declaration(item_t, superlative),
                            ";\n"
                            "Bool_t has_value = no;\n");

        Text_t item_code = compile(body_scope, item);
        ast_e cmp_op = op == Min ? LessThan : GreaterThan;
        if (reduction->key) {
            // With a key, compare keys but store the item. The best key is
            // cached in `min_key`/`max_key` so it isn't recomputed each iteration.
            env_t *key_scope = fresh_scope(env);
            set_binding(key_scope, op_str, item_t, item_code);
            type_t *key_type = get_type(key_scope, reduction->key);
            Text_t superlative_key = op == Min ? Text("min_key") : Text("max_key");
            code = Texts(code, compile_declaration(key_type, superlative_key), ";\n");

            ast_t *comparison = new (ast_t, .file = ast->file, .start = ast->start, .end = ast->end, .tag = cmp_op,
                                     .__data.Plus.lhs = LiteralCode(Text("key"), .type = key_type),
                                     .__data.Plus.rhs = LiteralCode(superlative_key, .type = key_type));

            body->__data.InlineCCode.chunks = new (
                ast_list_t, .ast = FakeAST(TextLiteral, Texts(compile_declaration(key_type, Text("key")), " = ",
                                                              compile(key_scope, reduction->key), ";\n",
                                                              "if (!has_value || ", compile(body_scope, comparison),
                                                              ") {\n"
                                                              "    ",
                                                              superlative, " = ", compile(body_scope, item),
                                                              ";\n"
                                                              "    ",
                                                              superlative_key,
                                                              " = key;\n"
                                                              "    has_value = yes;\n"
                                                              "}\n")));
        } else {
            ast_t *comparison =
                new (ast_t, .file = ast->file, .start = ast->start, .end = ast->end, .tag = cmp_op,
                     .__data.Plus.lhs = item, .__data.Plus.rhs = LiteralCode(superlative, .type = item_t));
            body->__data.InlineCCode.chunks = new (
                ast_list_t, .ast = FakeAST(TextLiteral, Texts("if (!has_value || ", compile(body_scope, comparison),
                                                              ") {\n"
                                                              "    ",
                                                              superlative, " = ", compile(body_scope, item),
                                                              ";\n"
                                                              "    has_value = yes;\n"
                                                              "}\n")));
        }

        code = Texts(code, compile_statement(env, loop), "\nhas_value ? ", promote_to_optional(item_t, superlative),
                     " : ", compile_none(item_t), ";})");
        return code;
    } else {
        // Accumulator-style reductions like +, ++, *, etc.
        type_t *reduction_type = Match(get_type(env, ast), OptionalType)->type;
        ast_t *item_value = item;
        if (reduction->key) {
            set_binding(body_scope, op_str, item_t, compile(body_scope, item));
            item_value = reduction->key;
        }

        Text_t code = Texts("({ // Reduction:\n", compile_declaration(reduction_type, Text("reduction")),
                            ";\n"
                            "Bool_t has_value = no;\n");

        // For the special case of (or)/(and), stop iterating as soon as the
        // accumulated value decides the result.
        Text_t early_out = EMPTY_TEXT;
        if (op == Compare) {
            if (reduction_type->tag != IntType || Match(reduction_type, IntType)->bits != TYPE_IBITS32)
                code_err(ast, "<> reductions are only supported for Int32 "
                              "values");
        } else if (op == And) {
            if (reduction_type->tag == BoolType) early_out = Text("if (!reduction) break;");
            else if (reduction_type->tag == OptionalType)
                early_out = Texts("if (", check_none(reduction_type, Text("reduction")), ") break;");
        } else if (op == Or) {
            if (reduction_type->tag == BoolType) early_out = Text("if (reduction) break;");
            else if (reduction_type->tag == OptionalType)
                early_out = Texts("if (!", check_none(reduction_type, Text("reduction")), ") break;");
        }

        ast_t *combination = new (ast_t, .file = ast->file, .start = ast->start, .end = ast->end, .tag = op,
                                  .__data.Plus.lhs = LiteralCode(Text("reduction"), .type = reduction_type),
                                  .__data.Plus.rhs = item_value);
        // The early-out check runs after BOTH arms, so a first item that
        // already decides the result (e.g. a leading `no` for `and`) stops the
        // loop immediately instead of pulling one more item.
        body->__data.InlineCCode.chunks = new (
            ast_list_t, .ast = FakeAST(TextLiteral, Texts("if (!has_value) {\n"
                                                          "    reduction = ",
                                                          compile(body_scope, item_value),
                                                          ";\n"
                                                          "    has_value = yes;\n"
                                                          "} else {\n"
                                                          "    reduction = ",
                                                          compile(body_scope, combination),
                                                          ";\n"
                                                          "}\n",
                                                          early_out)));

        code =
            Texts(code, compile_statement(env, loop), "\nhas_value ? ",
                  promote_to_optional(reduction_type, Text("reduction")), " : ", compile_none(reduction_type), ";})");
        return code;
    }
}