Support iterating over pointers to collections again

This commit is contained in:
Bruce Hill 2024-11-09 17:54:32 -05:00
parent 145a078387
commit 0df908f55f
4 changed files with 47 additions and 56 deletions

View File

@ -1125,7 +1125,6 @@ CORD compile_statement(env_t *env, ast_t *ast)
return compile_statement(env, loop);
}
type_t *iter_t = get_type(env, for_->iter);
env_t *body_scope = for_scope(env, ast);
loop_ctx_t loop_ctx = (loop_ctx_t){
.loop_name="for",
@ -1140,8 +1139,11 @@ CORD compile_statement(env_t *env, ast_t *ast)
naked_body = CORD_all(naked_body, "\n", loop_ctx.skip_label, ": continue;");
CORD stop = loop_ctx.stop_label ? CORD_all("\n", loop_ctx.stop_label, ":;") : CORD_EMPTY;
if (iter_t == RANGE_TYPE) {
CORD range = compile(env, for_->iter);
type_t *iter_t = value_type(get_type(env, for_->iter));
type_t *iter_value_t = value_type(iter_t);
if (iter_value_t == RANGE_TYPE) {
CORD range = compile_to_pointer_depth(env, for_->iter, 0, false);
CORD value = for_->vars ? compile(body_scope, for_->vars->ast) : "i";
if (for_->empty)
code_err(ast, "Ranges are never empty, they always contain at least their starting element");
@ -1157,9 +1159,9 @@ CORD compile_statement(env_t *env, ast_t *ast)
"\n}");
}
switch (iter_t->tag) {
switch (iter_value_t->tag) {
case ArrayType: {
type_t *item_t = Match(iter_t, ArrayType)->item_type;
type_t *item_t = Match(iter_value_t, ArrayType)->item_type;
CORD index = CORD_EMPTY;
CORD value = CORD_EMPTY;
if (for_->vars) {
@ -1175,27 +1177,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
CORD loop = CORD_EMPTY;
ast_t *array = for_->iter;
// Micro-optimization: inline the logic for iterating over
// `array:from(i)` and `array:to(i)` because these happen inside
// hot path inner loops and can actually meaningfully affect
// performance:
// if (for_->iter->tag == MethodCall && streq(Match(for_->iter, MethodCall)->name, "to")
// && value_type(get_type(env, Match(for_->iter, MethodCall)->self))->tag == ArrayType) {
// array = Match(for_->iter, MethodCall)->self;
// CORD limit = compile_arguments(env, for_->iter, new(arg_t, .type=INT_TYPE, .name="last"), Match(for_->iter, MethodCall)->args);
// loop = CORD_all(loop, "for (int64_t ", index, " = 1, raw_limit = ", limit,
// ", limit = raw_limit < 0 ? iterating.length + raw_limit + 1 : raw_limit; ",
// index, " <= limit; ++", index, ")");
// } else if (for_->iter->tag == MethodCall && streq(Match(for_->iter, MethodCall)->name, "from")
// && value_type(get_type(env, Match(for_->iter, MethodCall)->self))->tag == ArrayType) {
// array = Match(for_->iter, MethodCall)->self;
// CORD first = compile_arguments(env, for_->iter, new(arg_t, .type=INT_TYPE, .name="last"), Match(for_->iter, MethodCall)->args);
// loop = CORD_all(loop, "for (int64_t first = ", first, ", ", index, " = MAX(1, first < 1 ? iterating.length + first + 1 : first", "); ",
// index, " <= iterating.length; ++", index, ")");
// } else {
loop = CORD_all(loop, "for (int64_t i = 1; i <= iterating.length; ++i)");
// }
loop = CORD_all(loop, "for (int64_t i = 1; i <= iterating.length; ++i)");
if (index != CORD_EMPTY)
naked_body = CORD_all("Int_t ", index, " = I(i);\n", naked_body);
@ -1209,23 +1191,23 @@ CORD compile_statement(env_t *env, ast_t *ast)
loop = CORD_all(loop, "{\n", naked_body, "\n}");
}
if (can_be_mutated(env, array) && is_idempotent(array)) {
CORD array_code = compile(env, array);
if (for_->empty)
loop = CORD_all("if (iterating.length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty));
if (iter_t->tag == PointerType) {
loop = CORD_all("{\n"
"Array_t iterating = ARRAY_COPY(", array_code, ");\n",
"Array_t *ptr = ", compile_to_pointer_depth(env, for_->iter, 1, false), ";\n"
"\nARRAY_INCREF(*ptr);\n"
"Array_t iterating = *ptr;\n",
loop,
stop,
"\nARRAY_DECREF(", array_code, ");\n"
"\nARRAY_DECREF(*ptr);\n"
"}\n");
if (for_->empty)
loop = CORD_all("if (", array_code, ".length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty));
} else {
loop = CORD_all("{\n"
"Array_t iterating = ", compile(env, array), ";\n",
for_->empty ? "if (iterating.length > 0) {\n" : CORD_EMPTY,
"Array_t iterating = ", compile_to_pointer_depth(env, for_->iter, 0, false), ";\n",
loop,
for_->empty ? CORD_all("\n} else ", compile_statement(env, for_->empty)) : CORD_EMPTY,
stop,
"}\n");
}
@ -1234,16 +1216,16 @@ CORD compile_statement(env_t *env, ast_t *ast)
case SetType: case TableType: {
CORD loop = "for (int64_t i = 0; i < iterating.length; ++i) {\n";
if (for_->vars) {
if (iter_t->tag == SetType) {
if (iter_value_t->tag == SetType) {
if (for_->vars->next)
code_err(for_->vars->next->ast, "This is too many variables for this loop");
CORD item = compile(body_scope, for_->vars->ast);
type_t *item_type = Match(iter_t, SetType)->item_type;
type_t *item_type = Match(iter_value_t, SetType)->item_type;
loop = CORD_all(loop, compile_declaration(item_type, item), " = *(", compile_type(item_type), "*)(",
"iterating.data + i*iterating.stride);\n");
} else {
CORD key = compile(body_scope, for_->vars->ast);
type_t *key_t = Match(iter_t, TableType)->key_type;
type_t *key_t = Match(iter_value_t, TableType)->key_type;
loop = CORD_all(loop, compile_declaration(key_t, key), " = *(", compile_type(key_t), "*)(",
"iterating.data + i*iterating.stride);\n");
@ -1251,7 +1233,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
if (for_->vars->next->next)
code_err(for_->vars->next->next->ast, "This is too many variables for this loop");
type_t *value_t = Match(iter_t, TableType)->value_type;
type_t *value_t = Match(iter_value_t, TableType)->value_type;
CORD value = compile(body_scope, for_->vars->next->ast);
size_t value_offset = type_size(key_t);
if (type_align(value_t) > 1 && value_offset % type_align(value_t))
@ -1268,12 +1250,21 @@ CORD compile_statement(env_t *env, ast_t *ast)
loop = CORD_all("if (iterating.length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty));
}
if (can_be_mutated(env, for_->iter) && is_idempotent(for_->iter)) {
if (iter_t->tag == PointerType) {
loop = CORD_all(
"{\n",
"Array_t iterating = ARRAY_COPY((", compile(env, for_->iter), ").entries);\n",
"Table_t *t = ", compile_to_pointer_depth(env, for_->iter, 1, false), ";\n"
"ARRAY_INCREF(t->entries);\n"
"Array_t iterating = t->entries;\n",
loop,
"ARRAY_DECREF(t->entries);\n"
"}\n");
} else if (can_be_mutated(env, for_->iter)) {
loop = CORD_all(
"{\n",
"Table_t t = ", compile_to_pointer_depth(env, for_->iter, 0, false), ";\n"
"Array_t iterating = t.entries;\n",
loop,
"ARRAY_DECREF((", compile(env, for_->iter), ").entries);\n"
"}\n");
} else {
loop = CORD_all(
@ -1312,13 +1303,13 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
big_n:
n = compile(env, for_->iter);
n = compile_to_pointer_depth(env, for_->iter, 0, false);
CORD i = for_->vars ? compile(body_scope, for_->vars->ast) : "i";
CORD n_var = for_->vars ? CORD_all("max", i) : "n";
if (for_->empty) {
return CORD_all(
"{\n"
"Int_t ", n_var, " = ", compile(env, for_->iter), ";\n"
"Int_t ", n_var, " = ", n, ";\n"
"if (Int$compare_value(", n_var, ", I(0)) > 0) {\n"
"for (Int_t ", i, " = I(1); Int$compare_value(", i, ", ", n_var, ") <= 0; ", i, " = Int$plus(", i, ", I(1))) {\n",
"\t", naked_body,
@ -1340,17 +1331,17 @@ CORD compile_statement(env_t *env, ast_t *ast)
CORD next_fn;
if (is_idempotent(for_->iter)) {
next_fn = compile(env, for_->iter);
next_fn = compile_to_pointer_depth(env, for_->iter, 0, false);
} else {
code = CORD_all(code, compile_declaration(iter_t, "next"), " = ", compile(env, for_->iter), ";\n");
code = CORD_all(code, compile_declaration(iter_value_t, "next"), " = ", compile_to_pointer_depth(env, for_->iter, 0, false), ";\n");
next_fn = "next";
}
auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
auto fn = iter_value_t->tag == ClosureType ? Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType);
CORD get_next;
if (iter_t->tag == ClosureType) {
type_t *fn_t = Match(iter_t, ClosureType)->fn;
if (iter_value_t->tag == ClosureType) {
type_t *fn_t = Match(iter_value_t, ClosureType)->fn;
arg_t *closure_fn_args = NULL;
for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next)
closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args);

View File

@ -529,7 +529,7 @@ env_t *fresh_scope(env_t *env)
env_t *for_scope(env_t *env, ast_t *ast)
{
auto for_ = Match(ast, For);
type_t *iter_t = get_type(env, for_->iter);
type_t *iter_t = value_type(get_type(env, for_->iter));
env_t *scope = fresh_scope(env);
if (iter_t == RANGE_TYPE) {

View File

@ -57,11 +57,11 @@ struct World(player:@Player, goal:@Box, boxes:@[@Box], dt_accum=0.0, won=no):
# Resolve player overlapping with any boxes:
for i in 3:
for b in w.boxes[]:
for b in w.boxes:
w.player.pos += STIFFNESS * solve_overlap(w.player.pos, Player.SIZE, b.pos, b.size)
func draw(w:@World):
for b in w.boxes[]:
for b in w.boxes:
b:draw()
w.goal:draw()
w.player:draw()

View File

@ -17,22 +17,22 @@ func _timestamp(->Text):
func info(text:Text, newline=yes):
say("$\[2]⚫ $text$\[]", newline)
for file in logfiles[]:
for file in logfiles:
file:append("$(_timestamp()) [info] $text$\n")
func debug(text:Text, newline=yes):
say("$\[32]🟢 $text$\[]", newline)
for file in logfiles[]:
for file in logfiles:
file:append("$(_timestamp()) [debug] $text$\n")
func warn(text:Text, newline=yes):
say("$\[33;1]🟡 $text$\[]", newline)
for file in logfiles[]:
for file in logfiles:
file:append("$(_timestamp()) [warn] $text$\n")
func error(text:Text, newline=yes):
say("$\[31;1]🔴 $text$\[]", newline)
for file in logfiles[]:
for file in logfiles:
file:append("$(_timestamp()) [error] $text$\n")
func add_logfile(file:Path):