From 0df908f55fd7f617be35f7fe7a48f2eee1b19d57 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sat, 9 Nov 2024 17:54:32 -0500 Subject: [PATCH] Support iterating over pointers to collections again --- compile.c | 89 +++++++++++++++++++----------------------- environment.c | 2 +- examples/game/world.tm | 4 +- examples/log/log.tm | 8 ++-- 4 files changed, 47 insertions(+), 56 deletions(-) diff --git a/compile.c b/compile.c index e8c24ea..85aab1a 100644 --- a/compile.c +++ b/compile.c @@ -1125,7 +1125,6 @@ CORD compile_statement(env_t *env, ast_t *ast) return compile_statement(env, loop); } - type_t *iter_t = get_type(env, for_->iter); env_t *body_scope = for_scope(env, ast); loop_ctx_t loop_ctx = (loop_ctx_t){ .loop_name="for", @@ -1140,8 +1139,11 @@ CORD compile_statement(env_t *env, ast_t *ast) naked_body = CORD_all(naked_body, "\n", loop_ctx.skip_label, ": continue;"); CORD stop = loop_ctx.stop_label ? CORD_all("\n", loop_ctx.stop_label, ":;") : CORD_EMPTY; - if (iter_t == RANGE_TYPE) { - CORD range = compile(env, for_->iter); + type_t *iter_t = value_type(get_type(env, for_->iter)); + type_t *iter_value_t = value_type(iter_t); + + if (iter_value_t == RANGE_TYPE) { + CORD range = compile_to_pointer_depth(env, for_->iter, 0, false); CORD value = for_->vars ? compile(body_scope, for_->vars->ast) : "i"; if (for_->empty) code_err(ast, "Ranges are never empty, they always contain at least their starting element"); @@ -1157,9 +1159,9 @@ CORD compile_statement(env_t *env, ast_t *ast) "\n}"); } - switch (iter_t->tag) { + switch (iter_value_t->tag) { case ArrayType: { - type_t *item_t = Match(iter_t, ArrayType)->item_type; + type_t *item_t = Match(iter_value_t, ArrayType)->item_type; CORD index = CORD_EMPTY; CORD value = CORD_EMPTY; if (for_->vars) { @@ -1175,27 +1177,7 @@ CORD compile_statement(env_t *env, ast_t *ast) } CORD loop = CORD_EMPTY; - ast_t *array = for_->iter; - // Micro-optimization: inline the logic for iterating over - // `array:from(i)` and `array:to(i)` because these happen inside - // hot path inner loops and can actually meaningfully affect - // performance: - // if (for_->iter->tag == MethodCall && streq(Match(for_->iter, MethodCall)->name, "to") - // && value_type(get_type(env, Match(for_->iter, MethodCall)->self))->tag == ArrayType) { - // array = Match(for_->iter, MethodCall)->self; - // CORD limit = compile_arguments(env, for_->iter, new(arg_t, .type=INT_TYPE, .name="last"), Match(for_->iter, MethodCall)->args); - // loop = CORD_all(loop, "for (int64_t ", index, " = 1, raw_limit = ", limit, - // ", limit = raw_limit < 0 ? iterating.length + raw_limit + 1 : raw_limit; ", - // index, " <= limit; ++", index, ")"); - // } else if (for_->iter->tag == MethodCall && streq(Match(for_->iter, MethodCall)->name, "from") - // && value_type(get_type(env, Match(for_->iter, MethodCall)->self))->tag == ArrayType) { - // array = Match(for_->iter, MethodCall)->self; - // CORD first = compile_arguments(env, for_->iter, new(arg_t, .type=INT_TYPE, .name="last"), Match(for_->iter, MethodCall)->args); - // loop = CORD_all(loop, "for (int64_t first = ", first, ", ", index, " = MAX(1, first < 1 ? iterating.length + first + 1 : first", "); ", - // index, " <= iterating.length; ++", index, ")"); - // } else { - loop = CORD_all(loop, "for (int64_t i = 1; i <= iterating.length; ++i)"); - // } + loop = CORD_all(loop, "for (int64_t i = 1; i <= iterating.length; ++i)"); if (index != CORD_EMPTY) naked_body = CORD_all("Int_t ", index, " = I(i);\n", naked_body); @@ -1209,23 +1191,23 @@ CORD compile_statement(env_t *env, ast_t *ast) loop = CORD_all(loop, "{\n", naked_body, "\n}"); } - if (can_be_mutated(env, array) && is_idempotent(array)) { - CORD array_code = compile(env, array); + if (for_->empty) + loop = CORD_all("if (iterating.length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty)); + + if (iter_t->tag == PointerType) { loop = CORD_all("{\n" - "Array_t iterating = ARRAY_COPY(", array_code, ");\n", + "Array_t *ptr = ", compile_to_pointer_depth(env, for_->iter, 1, false), ";\n" + "\nARRAY_INCREF(*ptr);\n" + "Array_t iterating = *ptr;\n", loop, stop, - "\nARRAY_DECREF(", array_code, ");\n" + "\nARRAY_DECREF(*ptr);\n" "}\n"); - if (for_->empty) - loop = CORD_all("if (", array_code, ".length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty)); } else { loop = CORD_all("{\n" - "Array_t iterating = ", compile(env, array), ";\n", - for_->empty ? "if (iterating.length > 0) {\n" : CORD_EMPTY, + "Array_t iterating = ", compile_to_pointer_depth(env, for_->iter, 0, false), ";\n", loop, - for_->empty ? CORD_all("\n} else ", compile_statement(env, for_->empty)) : CORD_EMPTY, stop, "}\n"); } @@ -1234,16 +1216,16 @@ CORD compile_statement(env_t *env, ast_t *ast) case SetType: case TableType: { CORD loop = "for (int64_t i = 0; i < iterating.length; ++i) {\n"; if (for_->vars) { - if (iter_t->tag == SetType) { + if (iter_value_t->tag == SetType) { if (for_->vars->next) code_err(for_->vars->next->ast, "This is too many variables for this loop"); CORD item = compile(body_scope, for_->vars->ast); - type_t *item_type = Match(iter_t, SetType)->item_type; + type_t *item_type = Match(iter_value_t, SetType)->item_type; loop = CORD_all(loop, compile_declaration(item_type, item), " = *(", compile_type(item_type), "*)(", "iterating.data + i*iterating.stride);\n"); } else { CORD key = compile(body_scope, for_->vars->ast); - type_t *key_t = Match(iter_t, TableType)->key_type; + type_t *key_t = Match(iter_value_t, TableType)->key_type; loop = CORD_all(loop, compile_declaration(key_t, key), " = *(", compile_type(key_t), "*)(", "iterating.data + i*iterating.stride);\n"); @@ -1251,7 +1233,7 @@ CORD compile_statement(env_t *env, ast_t *ast) if (for_->vars->next->next) code_err(for_->vars->next->next->ast, "This is too many variables for this loop"); - type_t *value_t = Match(iter_t, TableType)->value_type; + type_t *value_t = Match(iter_value_t, TableType)->value_type; CORD value = compile(body_scope, for_->vars->next->ast); size_t value_offset = type_size(key_t); if (type_align(value_t) > 1 && value_offset % type_align(value_t)) @@ -1268,12 +1250,21 @@ CORD compile_statement(env_t *env, ast_t *ast) loop = CORD_all("if (iterating.length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty)); } - if (can_be_mutated(env, for_->iter) && is_idempotent(for_->iter)) { + if (iter_t->tag == PointerType) { loop = CORD_all( "{\n", - "Array_t iterating = ARRAY_COPY((", compile(env, for_->iter), ").entries);\n", + "Table_t *t = ", compile_to_pointer_depth(env, for_->iter, 1, false), ";\n" + "ARRAY_INCREF(t->entries);\n" + "Array_t iterating = t->entries;\n", + loop, + "ARRAY_DECREF(t->entries);\n" + "}\n"); + } else if (can_be_mutated(env, for_->iter)) { + loop = CORD_all( + "{\n", + "Table_t t = ", compile_to_pointer_depth(env, for_->iter, 0, false), ";\n" + "Array_t iterating = t.entries;\n", loop, - "ARRAY_DECREF((", compile(env, for_->iter), ").entries);\n" "}\n"); } else { loop = CORD_all( @@ -1312,13 +1303,13 @@ CORD compile_statement(env_t *env, ast_t *ast) } big_n: - n = compile(env, for_->iter); + n = compile_to_pointer_depth(env, for_->iter, 0, false); CORD i = for_->vars ? compile(body_scope, for_->vars->ast) : "i"; CORD n_var = for_->vars ? CORD_all("max", i) : "n"; if (for_->empty) { return CORD_all( "{\n" - "Int_t ", n_var, " = ", compile(env, for_->iter), ";\n" + "Int_t ", n_var, " = ", n, ";\n" "if (Int$compare_value(", n_var, ", I(0)) > 0) {\n" "for (Int_t ", i, " = I(1); Int$compare_value(", i, ", ", n_var, ") <= 0; ", i, " = Int$plus(", i, ", I(1))) {\n", "\t", naked_body, @@ -1340,17 +1331,17 @@ CORD compile_statement(env_t *env, ast_t *ast) CORD next_fn; if (is_idempotent(for_->iter)) { - next_fn = compile(env, for_->iter); + next_fn = compile_to_pointer_depth(env, for_->iter, 0, false); } else { - code = CORD_all(code, compile_declaration(iter_t, "next"), " = ", compile(env, for_->iter), ";\n"); + code = CORD_all(code, compile_declaration(iter_value_t, "next"), " = ", compile_to_pointer_depth(env, for_->iter, 0, false), ";\n"); next_fn = "next"; } - auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType); + auto fn = iter_value_t->tag == ClosureType ? Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType); CORD get_next; - if (iter_t->tag == ClosureType) { - type_t *fn_t = Match(iter_t, ClosureType)->fn; + if (iter_value_t->tag == ClosureType) { + type_t *fn_t = Match(iter_value_t, ClosureType)->fn; arg_t *closure_fn_args = NULL; for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next) closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args); diff --git a/environment.c b/environment.c index 6a0483c..b6a9e7d 100644 --- a/environment.c +++ b/environment.c @@ -529,7 +529,7 @@ env_t *fresh_scope(env_t *env) env_t *for_scope(env_t *env, ast_t *ast) { auto for_ = Match(ast, For); - type_t *iter_t = get_type(env, for_->iter); + type_t *iter_t = value_type(get_type(env, for_->iter)); env_t *scope = fresh_scope(env); if (iter_t == RANGE_TYPE) { diff --git a/examples/game/world.tm b/examples/game/world.tm index 4df7f15..a28fe98 100644 --- a/examples/game/world.tm +++ b/examples/game/world.tm @@ -57,11 +57,11 @@ struct World(player:@Player, goal:@Box, boxes:@[@Box], dt_accum=0.0, won=no): # Resolve player overlapping with any boxes: for i in 3: - for b in w.boxes[]: + for b in w.boxes: w.player.pos += STIFFNESS * solve_overlap(w.player.pos, Player.SIZE, b.pos, b.size) func draw(w:@World): - for b in w.boxes[]: + for b in w.boxes: b:draw() w.goal:draw() w.player:draw() diff --git a/examples/log/log.tm b/examples/log/log.tm index 89557e5..5e32a2b 100644 --- a/examples/log/log.tm +++ b/examples/log/log.tm @@ -17,22 +17,22 @@ func _timestamp(->Text): func info(text:Text, newline=yes): say("$\[2]⚫ $text$\[]", newline) - for file in logfiles[]: + for file in logfiles: file:append("$(_timestamp()) [info] $text$\n") func debug(text:Text, newline=yes): say("$\[32]🟢 $text$\[]", newline) - for file in logfiles[]: + for file in logfiles: file:append("$(_timestamp()) [debug] $text$\n") func warn(text:Text, newline=yes): say("$\[33;1]🟡 $text$\[]", newline) - for file in logfiles[]: + for file in logfiles: file:append("$(_timestamp()) [warn] $text$\n") func error(text:Text, newline=yes): say("$\[31;1]🔴 $text$\[]", newline) - for file in logfiles[]: + for file in logfiles: file:append("$(_timestamp()) [error] $text$\n") func add_logfile(file:Path):