1 // This file defines how to compile loops
7 #include "../environment.h"
8 #include "../stdlib/datatypes.h"
9 #include "../stdlib/integers.h"
10 #include "../stdlib/text.h"
11 #include "../stdlib/util.h"
12 #include "../typecheck.h"
13 #include "compilation.h"
// Compile a `For` AST node into C source text.
// Strategies visible in this function, in order:
//   1. a Comprehension iterator is rewritten into a plain For and recompiled;
//   2. an integer `to` method call gets a dedicated counting loop;
//   3. a BigInt `onward` method call gets a loop without a termination test;
//   4. otherwise the code switches on the value type of the iterator
//      expression (list, table, integer count, iterator function/closure).
// NOTE(review): this listing has gaps (some original lines are not visible),
// so the comments below describe only the statements that can be seen here.
16 Text_t compile_for_loop(env_t *env, ast_t *ast) {
17 DeclareMatch(for_, ast, For);
19 // If we're iterating over a comprehension, that's actually just doing
20 // one loop, we don't need to compile the comprehension as a list
21 // comprehension. This is a common case for reducers like `(+: i*2 for i
22 // in 5)` or `(and) x.is_good() for x in xs`
23 if (for_->iter->tag == Comprehension) {
24 DeclareMatch(comp, for_->iter, Comprehension);
25 ast_t *body = for_->body;
// Only one loop variable is accepted when the iterator is a comprehension.
27 if (for_->vars->next) code_err(for_->vars->next->ast, "This is too many variables for iteration");
32 new (ast_list_t, .ast = WrapAST(ast, Declare, .var = for_->vars->ast, .value = comp->expr),
33 .next = body->tag == Block ? Match(body, Block)->statements : new (ast_list_t, .ast = body)));
// A comprehension filter becomes an If wrapped around the loop body.
36 if (comp->filter) body = WrapAST(for_->body, If, .condition = comp->filter, .body = body);
// Build a For over the vars and iterable of the comprehension and compile that instead.
37 ast_t *loop = WrapAST(ast, For, .vars = comp->vars, .iter = comp->iter, .body = body);
38 return compile_statement(env, loop);
// Install a loop context on the body scope. Its skip_label and stop_label
// fields are read back below, after the body has been compiled.
41 env_t *body_scope = for_scope(env, ast);
42 loop_ctx_t loop_ctx = (loop_ctx_t){
44 .loop_vars = for_->vars,
45 .deferred = body_scope->deferred,
46 .next = body_scope->loop_ctx,
48 body_scope->loop_ctx = &loop_ctx;
49 // Naked means no enclosing braces:
50 Text_t naked_body = compile_inline_block(body_scope, for_->body);
// A non-empty skip label is emitted at the end of the body, followed by `continue`.
51 if (loop_ctx.skip_label.length > 0) naked_body = Texts(naked_body, "\n", loop_ctx.skip_label, ": continue;");
// A non-empty stop label is emitted as an empty statement appended after the loop text.
52 Text_t stop = loop_ctx.stop_label.length > 0 ? Texts("\n", loop_ctx.stop_label, ":;") : EMPTY_TEXT;
54 // Special case for improving performance for numeric iteration:
55 if (for_->iter->tag == MethodCall && streq(Match(for_->iter, MethodCall)->name, "to")
56 && is_int_type(get_type(env, Match(for_->iter, MethodCall)->self))) {
57 // TODO: support other integer types
58 arg_ast_t *args = Match(for_->iter, MethodCall)->args;
59 if (!args) code_err(for_->iter, "to() needs at least one argument");
// The step uses an 8-bit IntType when the receiver is a ByteType, otherwise
// it has the same type as the receiver.
61 type_t *int_type = get_type(env, Match(for_->iter, MethodCall)->self);
62 type_t *step_type = int_type->tag == ByteType ? Type(IntType, .bits = TYPE_IBITS8) : int_type;
64 Text_t last = EMPTY_TEXT, step = EMPTY_TEXT, optional_step = EMPTY_TEXT;
// Argument order 1: first argument is unnamed or named `last`; the following
// argument, if named, must be `step`. An optional-typed step is kept
// separately in optional_step so a default can be chosen in the emitted code.
65 if (!args->name || streq(args->name, "last")) {
66 last = compile_to_type(env, args->value, int_type);
68 if (args->next->name && !streq(args->next->name, "step"))
69 code_err(args->next->value, "Invalid argument name: ", args->next->name);
70 if (get_type(env, args->next->value)->tag == OptionalType)
71 optional_step = compile_to_type(env, args->next->value, Type(OptionalType, step_type));
72 else step = compile_to_type(env, args->next->value, step_type);
// Argument order 2: first argument is named `step`; the following argument,
// if named, must be `last`.
74 } else if (streq(args->name, "step")) {
75 if (get_type(env, args->value)->tag == OptionalType)
76 optional_step = compile_to_type(env, args->value, Type(OptionalType, step_type));
77 else step = compile_to_type(env, args->value, step_type);
79 if (args->next->name && !streq(args->next->name, "last"))
80 code_err(args->next->value, "Invalid argument name: ", args->next->name);
81 last = compile_to_type(env, args->next->value, int_type);
85 if (last.length == 0) code_err(for_->iter, "No `last` argument was given");
// The loop variable name comes from the first declared var, or `i` if none.
87 Text_t type_code = compile_type(int_type);
88 Text_t value = for_->vars ? compile(body_scope, for_->vars->ast) : Text("i");
// BigInt receivers: the emitted loop uses Int$compare_value and Int$plus.
// When no step is available the emitted code picks +1 if last >= first, else -1.
// NOTE(review): the emitted snippet declares maybe_step as an OptionalInt_t
// value and then accesses it with `->small`; confirm that is valid for that type.
89 if (int_type->tag == BigIntType) {
90 if (optional_step.length > 0)
91 step = Texts("({ OptionalInt_t maybe_step = ", optional_step,
92 "; maybe_step->small == 0 ? "
93 "(Int$compare_value(last, first) >= 0 "
94 "? I_small(1) : I_small(-1)) : (Int_t)maybe_step; "
96 else if (step.length == 0)
97 step = Text("Int$compare_value(last, first) >= 0 ? "
98 "I_small(1) : I_small(-1)");
99 return Texts("for (", type_code, " first = ", compile(env, Match(for_->iter, MethodCall)->self), ", ",
100 value, " = first, last = ", last, ", step = ", step,
102 "Int$compare_value(",
103 value, ", last) != Int$compare_value(step, I_small(0)); ", value, " = Int$plus(", value,
106 naked_body, "}", stop);
// Other integer types: the emitted loop uses the native comparison operators,
// with the same +1 / -1 default when the step is none or absent.
108 if (optional_step.length > 0)
109 step = Texts("({ ", compile_type(Type(OptionalType, step_type)), " maybe_step = ", optional_step,
111 "maybe_step.is_none ? (",
112 type_code, ")(last >= first ? 1 : -1) : maybe_step.value; })");
113 else if (step.length == 0) step = Texts("(", type_code, ")(last >= first ? 1 : -1)");
114 return Texts("for (", type_code, " first = ", compile(env, Match(for_->iter, MethodCall)->self), ", ",
115 value, " = first, last = ", last, ", step = ", step, "; (", compile_type(step_type),
116 ")step > 0 ? ", value, " <= last : ", value, " >= last; ", value,
119 naked_body, "}", stop);
121 } else if (for_->iter->tag == MethodCall && streq(Match(for_->iter, MethodCall)->name, "onward")
122 && get_type(env, Match(for_->iter, MethodCall)->self)->tag == BigIntType) {
123 // Special case for Int.onward()
// The `step` argument defaults to the Int literal 1. The emitted for-loop has
// an empty condition clause, so it only ends via code inside the body.
124 arg_ast_t *args = Match(for_->iter, MethodCall)->args;
126 new (arg_t, .name = "step", .type = INT_TYPE, .default_val = FakeAST(Int, .str = "1"), .next = NULL);
127 Text_t step = compile_arguments(env, for_->iter, arg_spec, args);
128 Text_t value = for_->vars ? compile(body_scope, for_->vars->ast) : Text("i");
129 return Texts("for (Int_t ", value, " = ", compile(env, Match(for_->iter, MethodCall)->self), ", ",
130 "step = ", step, "; ; ", value, " = Int$plus(", value,
133 naked_body, "}", stop);
// General case: dispatch on the tag of the value type of the iterator expression.
136 type_t *iter_t = get_type(env, for_->iter);
137 type_t *iter_value_t = value_type(iter_t);
139 switch (iter_value_t->tag) {
// List iteration. With two loop vars the first is the index and the second
// the item; with one var it is the item.
141 type_t *item_t = Match(iter_value_t, ListType)->item_type;
142 Text_t index = EMPTY_TEXT;
143 Text_t value = EMPTY_TEXT;
145 if (for_->vars->next) {
146 if (for_->vars->next->next)
147 code_err(for_->vars->next->next->ast, "This is too many variables for this loop");
149 index = compile(body_scope, for_->vars->ast);
150 value = compile(body_scope, for_->vars->next->ast);
152 value = compile(body_scope, for_->vars->ast);
// The emitted counter `i` runs from 1 to iterating.length inclusive; items
// are read at offset (i-1)*stride from iterating.data.
156 Text_t loop = EMPTY_TEXT;
157 loop = Texts(loop, "for (int64_t i = 1; i <= iterating.length; ++i)");
159 if (index.length > 0) naked_body = Texts("Int_t ", index, " = I(i);\n", naked_body);
161 if (value.length > 0) {
162 loop = Texts(loop, "{\n", compile_declaration(item_t, value), " = *(", compile_type(item_t),
163 "*)(iterating.data + (i-1)*iterating.stride);\n", naked_body, "\n}");
165 loop = Texts(loop, "{\n", naked_body, "\n}");
// Wrap the loop in a length check whose else branch is the compiled
// for_->empty block (the guard for this statement is not visible here).
169 loop = Texts("if (iterating.length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty));
// For a pointer iterator the emitted code brackets the loop with
// LIST_INCREF / LIST_DECREF on the pointed-to list and iterates over a copy
// of the List_t struct named `iterating`.
171 if (iter_t->tag == PointerType) {
174 compile_to_pointer_depth(env, for_->iter, 1, false),
176 "\nLIST_INCREF(*ptr);\n"
177 "List_t iterating = *ptr;\n",
179 "\nLIST_DECREF(*ptr);\n"
184 "List_t iterating = ",
185 compile_to_pointer_depth(env, for_->iter, 0, false), ";\n", loop, stop, "}\n");
// Table iteration: the emitted counter is 0-based over `iterating`, which is
// bound to the entries list of the table further below.
190 Text_t loop = Text("for (int64_t i = 0; i < (int64_t)iterating.length; ++i) {\n");
// The first loop var is the key, read from the start of each entry.
192 Text_t key = compile(body_scope, for_->vars->ast);
193 type_t *key_t = Match(iter_value_t, TableType)->key_type;
194 loop = Texts(loop, compile_declaration(key_t, key), " = *(", compile_type(key_t), "*)(",
195 "iterating.data + i*iterating.stride);\n");
// An optional second loop var is the value. Its offset within an entry is
// computed in the emitted code with offsetof on an anonymous struct holding
// a key member `k` followed by a value member `v`.
197 if (for_->vars->next) {
198 if (for_->vars->next->next)
199 code_err(for_->vars->next->next->ast, "This is too many variables for this loop");
201 type_t *value_t = Match(iter_value_t, TableType)->value_type;
202 Text_t value = compile(body_scope, for_->vars->next->ast);
203 Text_t value_offset = Texts("offsetof(struct { ", compile_declaration(key_t, Text("k")), "; ",
204 compile_declaration(value_t, Text("v")), "; }, v)");
205 loop = Texts(loop, compile_declaration(value_t, value), " = *(", compile_type(value_t), "*)(",
206 "iterating.data + i*iterating.stride + ", value_offset, ");\n");
210 loop = Texts(loop, naked_body, "\n}");
// Same length check / empty-block wrapping as for lists.
213 loop = Texts("if (iterating.length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty));
// For a pointer iterator the entries list is INCREF-ed before and DECREF-ed
// after the loop in the emitted code.
216 if (iter_t->tag == PointerType) {
217 loop = Texts("{\n", "Table_t *t = ", compile_to_pointer_depth(env, for_->iter, 1, false),
219 "LIST_INCREF(t->entries);\n"
220 "List_t iterating = t->entries;\n",
222 "LIST_DECREF(t->entries);\n"
225 loop = Texts("{\n", "List_t iterating = (", compile_to_pointer_depth(env, for_->iter, 0, false),
226 ").entries;\n", loop, "}\n");
// Integer-literal iterator: the count is parsed at compile time.
232 if (for_->iter->tag == Int) {
233 const char *str = Match(for_->iter, Int)->str;
234 Int_t int_val = Int$from_str(str);
235 if (int_val.small == 0) code_err(for_->iter, "Failed to parse this integer");
// When the magnitude is at most BIGGEST_SMALL_INT, the decimal text of the
// literal is used as the loop bound `n`.
237 mpz_init_set_int(i, int_val);
238 if (mpz_cmpabs_ui(i, BIGGEST_SMALL_INT) <= 0) n = Text$from_str(mpz_get_str(NULL, 10, i));
// A literal count <= 0 with an empty block compiles to just the empty block.
241 if (for_->empty && mpz_cmp_si(i, 0) <= 0) {
242 return compile_statement(env, for_->empty);
// Emit a 1-based int64_t counting loop; the loop var, if any, is declared
// inside the body as an Int_t built with I_small(i).
244 return Texts("for (int64_t i = 1; i <= ", n, "; ++i) {\n",
245 for_->vars ? Texts("\tInt_t ", compile(body_scope, for_->vars->ast), " = I_small(i);\n")
247 "\t", naked_body, "}\n", stop, "\n");
// Non-literal count: the emitted loop counts an Int_t from I(1) up to the
// bound using Int$compare_value and Int$plus. The bound variable is named
// `max` followed by the loop var name, or `n` when there is no loop var.
252 n = compile_to_pointer_depth(env, for_->iter, 0, false);
253 Text_t i = for_->vars ? compile(body_scope, for_->vars->ast) : Text("i");
254 Text_t n_var = for_->vars ? Texts("max", i) : Text("n");
260 "if (Int$compare_value(",
264 i, " = I(1); Int$compare_value(", i, ", ", n_var, ") <= 0; ", i, " = Int$plus(", i,
265 ", I(1))) {\n", "\t", naked_body,
268 compile_statement(env, for_->empty), stop,
272 return Texts("for (Int_t ", i, " = I(1), ", n_var, " = ", n, "; Int$compare_value(", i, ", ", n_var,
273 ") <= 0; ", i, " = Int$plus(", i, ", I(1))) {\n", "\t", naked_body, "}\n", stop, "\n");
278 // Iterator function:
279 Text_t code = Text("{\n");
// An idempotent iterator expression is used directly as the callee text;
// the other visible path stores it in a local named `next` first.
282 if (is_idempotent(for_->iter)) {
283 next_fn = compile_to_pointer_depth(env, for_->iter, 0, false);
285 code = Texts(code, compile_declaration(iter_value_t, Text("next")), " = ",
286 compile_to_pointer_depth(env, for_->iter, 0, false), ";\n");
287 next_fn = Text("next");
// `fn` is the FunctionType info, unwrapped from the closure when needed.
290 __typeof(iter_value_t->__data.FunctionType) *fn =
291 iter_value_t->tag == ClosureType ? Match(Match(iter_value_t, ClosureType)->fn, FunctionType)
292 : Match(iter_value_t, FunctionType);
// Closures are called through their .fn member cast to a function type that
// has an added `userdata` pointer parameter, passing the closure .userdata.
// The arg list is built by prepending and then passed to REVERSE_LIST.
295 if (iter_value_t->tag == ClosureType) {
296 type_t *fn_t = Match(iter_value_t, ClosureType)->fn;
297 arg_t *closure_fn_args = NULL;
298 for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next)
299 closure_fn_args = new (arg_t, .name = arg->name, .type = arg->type, .default_val = arg->default_val,
300 .next = closure_fn_args);
301 closure_fn_args = new (arg_t, .name = "userdata", .type = Type(PointerType, .pointed = Type(MemoryType)),
302 .next = closure_fn_args);
303 REVERSE_LIST(closure_fn_args);
304 Text_t fn_type_code =
305 compile_type(Type(FunctionType, .args = closure_fn_args, .ret = Match(fn_t, FunctionType)->ret));
306 get_next = Texts("((", fn_type_code, ")", next_fn, ".fn)(", next_fn, ".userdata)");
308 get_next = Texts(next_fn, "()");
311 if (fn->ret->tag == OptionalType) {
312 // Use an optional variable `cur` for each iteration step, which
313 // will be checked for none
314 code = Texts(code, compile_declaration(fn->ret, Text("cur")), ";\n");
315 get_next = Texts("(cur=", get_next, ", !", check_none(fn->ret, Text("cur")), ")");
// The loop var is declared in the body from the unwrapped `cur`; its emitted
// name is the var name prefixed with `_$`.
317 naked_body = Texts(compile_declaration(Match(fn->ret, OptionalType)->type,
318 Texts("_$", Match(for_->vars->ast, Var)->name)),
319 " = ", optional_into_nonnone(fn->ret, Text("cur")), ";\n", naked_body);
// Two emitted shapes are visible: an if / do-while pair with the empty block
// as the alternative, and a plain while loop.
322 code = Texts(code, "if (", get_next,
325 naked_body, "\t} while(", get_next,
328 compile_statement(env, for_->empty), "}", stop, "\n}\n");
330 code = Texts(code, "while(", get_next, ") {\n\t", naked_body, "}\n", stop, "\n}\n");
// Non-optional return type: the call is made at the top of each iteration of
// an unconditional for(;;) loop, optionally binding the result to the loop var.
334 naked_body = Texts(compile_declaration(fn->ret, Texts("_$", Match(for_->vars->ast, Var)->name)), " = ",
335 get_next, ";\n", naked_body);
337 naked_body = Texts(get_next, ";\n", naked_body);
// An empty block is reported as an error on this path (guard not visible here).
340 code_err(for_->empty, "This iteration loop will always have values, "
341 "so this block will never run");
342 code = Texts(code, "for (;;) {\n\t", naked_body, "}\n", stop, "\n}\n");
// Any other iterator type is a compile error.
347 default: code_err(for_->iter, "Iteration is not implemented for type: ", type_to_text(iter_t));
// Compile a `Repeat` AST node into an unconditional C loop, `for (;;)`.
// The body is compiled in a fresh scope carrying a loop context named
// `repeat`, so skip/stop statements inside the body can attach labels to it.
// NOTE(review): the return statement is not visible in this listing.
352 Text_t compile_repeat(env_t *env, ast_t *ast) {
353 ast_t *body = Match(ast, Repeat)->body;
354 env_t *scope = fresh_scope(env);
// The new context chains to the enclosing loop context via .next.
355 loop_ctx_t loop_ctx = (loop_ctx_t){
356 .loop_name = "repeat",
357 .deferred = scope->deferred,
358 .next = env->loop_ctx,
360 scope->loop_ctx = &loop_ctx;
361 Text_t body_code = compile_statement(scope, body);
// A skip label set during body compilation goes at the end of the body, followed by `continue`.
362 if (loop_ctx.skip_label.length > 0) body_code = Texts(body_code, "\n", loop_ctx.skip_label, ": continue;");
363 Text_t loop = Texts("for (;;) {\n\t", body_code, "\n}");
// A stop label set during body compilation goes right after the loop as an empty statement.
364 if (loop_ctx.stop_label.length > 0) loop = Texts(loop, "\n", loop_ctx.stop_label, ":;");
// Compile a `While` AST node into a C `while` loop.
// The body and the condition are compiled in a fresh scope carrying a loop
// context named `while`. A missing condition is emitted as the text `yes`.
// NOTE(review): the declaration of `loop` and the return statement are not
// visible in this listing.
369 Text_t compile_while(env_t *env, ast_t *ast) {
370 DeclareMatch(while_, ast, While);
371 env_t *scope = fresh_scope(env);
// The new context chains to the enclosing loop context via .next.
372 loop_ctx_t loop_ctx = (loop_ctx_t){
373 .loop_name = "while",
374 .deferred = scope->deferred,
375 .next = env->loop_ctx,
377 scope->loop_ctx = &loop_ctx;
378 Text_t body = compile_statement(scope, while_->body);
// A skip label set during body compilation goes at the end of the body, followed by `continue`.
379 if (loop_ctx.skip_label.length > 0) body = Texts(body, "\n", loop_ctx.skip_label, ": continue;");
381 Texts("while (", while_->condition ? compile(scope, while_->condition) : Text("yes"), ") {\n\t", body, "\n}");
// A stop label set during body compilation goes right after the loop as an empty statement.
382 if (loop_ctx.stop_label.length > 0) loop = Texts(loop, "\n", loop_ctx.stop_label, ":;");
// Compile a `Skip` AST node into a `goto` to the skip label of a loop.
// The loop context chain is walked starting at env->loop_ctx. A context
// matches when there is no target, or when the target equals the loop name
// or the name of one of the loop variables.
// NOTE(review): some lines of this function are not visible in this listing.
387 Text_t compile_skip(env_t *env, ast_t *ast) {
388 const char *target = Match(ast, Skip)->target;
389 for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
390 bool matched = !target || strcmp(target, ctx->loop_name) == 0;
// Also compare the target against each loop variable name until one matches.
391 for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var ? var->next : NULL)
392 matched = (strcmp(target, Match(var->ast, Var)->name) == 0);
// The label is created on first use from a function-local static counter.
395 if (ctx->skip_label.length == 0) {
396 static int64_t skip_label_count = 1;
397 ctx->skip_label = Texts("skip_", skip_label_count);
// Deferred blocks registered between the current env and the target loop
// context are compiled ahead of the goto.
400 Text_t code = EMPTY_TEXT;
401 for (deferral_t *deferred = env->deferred; deferred && deferred != ctx->deferred; deferred = deferred->next)
402 code = Texts(code, compile_statement(deferred->defer_env, deferred->block));
403 if (code.length > 0) return Texts("{\n", code, "goto ", ctx->skip_label, ";\n}\n");
404 else return Texts("goto ", ctx->skip_label, ";");
// Fall-through handling when no context matched.
// NOTE(review): the first condition looks inverted. It reports a missing
// enclosing loop when env->loop_ctx is non-NULL, and the final branch emits
// `continue;` when env->loop_ctx is NULL. Confirm the intended behavior.
407 if (env->loop_ctx) code_err(ast, "This is not inside any loop");
408 else if (target) code_err(ast, "No loop target named '", target, "' was found");
409 else return Text("continue;");
// Compile a `Stop` AST node into a `goto` to the stop label of a loop.
// The loop context chain is walked starting at env->loop_ctx. A context
// matches when there is no target, or when the target equals the loop name
// or the name of one of the loop variables.
// NOTE(review): some lines of this function are not visible in this listing.
413 Text_t compile_stop(env_t *env, ast_t *ast) {
414 const char *target = Match(ast, Stop)->target;
415 for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
416 bool matched = !target || strcmp(target, ctx->loop_name) == 0;
// Also compare the target against each loop variable name until one matches.
417 for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var ? var->next : var)
418 matched = (strcmp(target, Match(var->ast, Var)->name) == 0);
// The label is created on first use from a function-local static counter.
421 if (ctx->stop_label.length == 0) {
422 static int64_t stop_label_count = 1;
423 ctx->stop_label = Texts("stop_", stop_label_count);
// Deferred blocks registered between the current env and the target loop
// context are compiled ahead of the goto.
426 Text_t code = EMPTY_TEXT;
427 for (deferral_t *deferred = env->deferred; deferred && deferred != ctx->deferred; deferred = deferred->next)
428 code = Texts(code, compile_statement(deferred->defer_env, deferred->block));
429 if (code.length > 0) return Texts("{\n", code, "goto ", ctx->stop_label, ";\n}\n");
430 else return Texts("goto ", ctx->stop_label, ";");
// Fall-through handling when no context matched.
// NOTE(review): the first condition looks inverted. It reports a missing
// enclosing loop when env->loop_ctx is non-NULL, and the final branch emits
// `break;` when env->loop_ctx is NULL. Confirm the intended behavior.
433 if (env->loop_ctx) code_err(ast, "This is not inside any loop");
434 else if (target) code_err(ast, "No loop target named '", target, "' was found");
435 else return Text("break;");