From 445f79cb70e72698283539b65e43fc71a47ad311 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sat, 13 Jul 2024 17:17:58 -0400 Subject: Add iterator functions --- compile.c | 129 +++++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 90 insertions(+), 39 deletions(-) (limited to 'compile.c') diff --git a/compile.c b/compile.c index cac7a345..9db1642e 100644 --- a/compile.c +++ b/compile.c @@ -18,6 +18,7 @@ static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_dept static env_t *with_enum_scope(env_t *env, type_t *t); static CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); static CORD compile_string(env_t *env, ast_t *ast, CORD color); +static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args); static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed) { @@ -647,8 +648,11 @@ CORD compile_statement(env_t *env, ast_t *ast) case Skip: { const char *target = Match(ast, Skip)->target; for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) { - if (!target || CORD_cmp(target, ctx->loop_name) == 0 - || CORD_cmp(target, ctx->key_name) == 0 || CORD_cmp(target, ctx->value_name) == 0) { + bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0; + for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next) + matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0); + + if (matched) { if (!ctx->skip_label) { static int64_t skip_label_count = 1; CORD_sprintf(&ctx->skip_label, "skip_%ld", skip_label_count); @@ -670,8 +674,11 @@ CORD compile_statement(env_t *env, ast_t *ast) case Stop: { const char *target = Match(ast, Stop)->target; for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) { - if (!target || CORD_cmp(target, ctx->loop_name) == 0 - || CORD_cmp(target, ctx->key_name) == 0 || CORD_cmp(target, ctx->value_name) == 0) { + bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0; + for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next) + matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0); + + if (matched) { if (!ctx->stop_label) { static int64_t stop_label_count = 1; CORD_sprintf(&ctx->stop_label, "stop_%ld", stop_label_count); @@ -778,8 +785,7 @@ CORD compile_statement(env_t *env, ast_t *ast) env_t *body_scope = for_scope(env, ast); loop_ctx_t loop_ctx = (loop_ctx_t){ .loop_name="for", - .key_name=for_->index ? Match(for_->index, Var)->name : CORD_EMPTY, - .value_name=for_->value ? Match(for_->value, Var)->name : CORD_EMPTY, + .loop_vars=for_->vars, .deferred=body_scope->deferred, .next=body_scope->loop_ctx, }; @@ -792,8 +798,20 @@ CORD compile_statement(env_t *env, ast_t *ast) switch (iter_t->tag) { case ArrayType: { type_t *item_t = Match(iter_t, ArrayType)->item_type; - CORD index = for_->index ? compile(env, for_->index) : "i"; - CORD value = compile(env, for_->value); + CORD index = "i"; + CORD value = "value"; + if (for_->vars) { + if (for_->vars->next) { + if (for_->vars->next->next) + code_err(for_->vars->next->next->ast, "This is too many variables for this loop"); + + index = compile(env, for_->vars->ast); + value = compile(env, for_->vars->next->ast); + } else { + value = compile(env, for_->vars->ast); + } + } + CORD array = is_idempotent(for_->iter) ? compile(env, for_->iter) : "arr"; CORD loop = CORD_all("ARRAY_INCREF(", array, ");\n" "for (int64_t ", index, " = 1; ", index, " <= ", array, ".length; ++", index, ") {\n", @@ -814,18 +832,30 @@ CORD compile_statement(env_t *env, ast_t *ast) CORD table = is_idempotent(for_->iter) ? compile(env, for_->iter) : "table"; CORD loop = CORD_all("ARRAY_INCREF(", table, ".entries);\n" "for (int64_t i = 0; i < ",table,".entries.length; ++i) {\n"); - if (for_->index) { - loop = CORD_all(loop, compile_type(key_t), " ", compile(env, for_->index), " = *(", compile_type(key_t), "*)(", - table,".entries.data + i*", table, ".entries.stride);\n"); + CORD key = CORD_EMPTY, value = CORD_EMPTY; + if (for_->vars) { + if (for_->vars->next) { + if (for_->vars->next->next) + code_err(for_->vars->next->next->ast, "This is too many variables for this loop"); + + key = compile(env, for_->vars->ast); + value = compile(env, for_->vars->next->ast); + } else { + key = compile(env, for_->vars->ast); + } + } + + if (key) { + loop = CORD_all(loop, compile_type(key_t), " ", key, " = *(", compile_type(key_t), "*)(", + table,".entries.data + i*", table, ".entries.stride);\n"); + } + if (value) { size_t value_offset = type_size(key_t); if (type_align(value_t) > 1 && value_offset % type_align(value_t)) value_offset += type_align(value_t) - (value_offset % type_align(value_t)); // padding - loop = CORD_all(loop, compile_type(value_t), " ", compile(env, for_->value), " = *(", compile_type(value_t), "*)(", + loop = CORD_all(loop, compile_type(value_t), " ", value, " = *(", compile_type(value_t), "*)(", table,".entries.data + i*", table, ".entries.stride + ", heap_strf("%zu", value_offset), ");\n"); - } else { - loop = CORD_all(loop, compile_type(key_t), " ", compile(env, for_->value), " = *(", compile_type(key_t), "*)(", - table,".entries.data + i*", table, ".entries.stride);\n"); } loop = CORD_all(loop, body, "\n}"); if (for_->empty) @@ -836,19 +866,9 @@ CORD compile_statement(env_t *env, ast_t *ast) return loop; } case IntType: { - CORD value = compile(env, for_->value); + CORD value = for_->vars ? compile(env, for_->vars->ast) : "i"; CORD n = compile(env, for_->iter); - CORD index = for_->index ? compile(env, for_->index) : CORD_EMPTY; - if (for_->empty && index) { - return CORD_all( - "{\n" - "int64_t n = ", n, ";\n" - "if (n > 0) {\n" - "for (int64_t ", index, " = 1, ", value, "; (", value, "=", index,") <= n; ++", index, ") {\n" - "\t", body, "\n}" - "\n} else ", compile_statement(env, for_->empty), - stop, "\n}"); - } else if (for_->empty) { + if (for_->empty) { return CORD_all( "{\n" "int64_t n = ", n, ";\n" @@ -859,13 +879,6 @@ CORD compile_statement(env_t *env, ast_t *ast) "\n} else ", compile_statement(env, for_->empty), stop, "\n}"); - } else if (index) { - return CORD_all( - "for (int64_t ", value, ", ", index, " = 1, n = ", n, "; (", value, "=", index,") <= n; ++", value, ") {\n" - "\t", body, - "\n}", - stop, - "\n"); } else { return CORD_all( "for (int64_t ", value, " = 1, n = ", compile(env, for_->iter), "; ", value, " <= n; ++", value, ") {\n" @@ -875,6 +888,44 @@ CORD compile_statement(env_t *env, ast_t *ast) "\n"); } } + case FunctionType: case ClosureType: { + CORD code = "{\n"; + auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType); + arg_t *next_arg = fn->args; + for (ast_list_t *var = for_->vars; var; var = var->next) { + const char *name = Match(var->ast, Var)->name; + type_t *t = Match(get_arg_type(env, next_arg), PointerType)->pointed; + code = CORD_all(code, compile_declaration(t, CORD_cat("$", name)), ";\n"); + } + + code = CORD_all(code, compile_declaration(iter_t, "next"), " = ", compile(env, for_->iter), ";\n"); + + CORD next_fn; + if (iter_t->tag == ClosureType) { + type_t *fn_t = Match(iter_t, ClosureType)->fn; + arg_t *closure_fn_args = NULL; + for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next) + closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args); + closure_fn_args = new(arg_t, .name="userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args); + REVERSE_LIST(closure_fn_args); + CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret)); + next_fn = CORD_all("((", fn_type_code, ")next.fn)"); + } else { + next_fn = "next"; + } + + code = CORD_all(code, "for(; ", next_fn, "("); + for (ast_list_t *var = for_->vars; var; var = var->next) { + const char *name = Match(var->ast, Var)->name; + code = CORD_all(code, "&$", name); + if (var->next || iter_t->tag == ClosureType) + code = CORD_all(code, ", "); + } + if (iter_t->tag == ClosureType) + code = CORD_all(code, "next.userdata"); + code = CORD_all(code, "); ) {\n\t", body, "}\n", stop, "\n}\n"); + return code; + } default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t); } } @@ -914,7 +965,7 @@ CORD compile_statement(env_t *env, ast_t *ast) assert(env->comprehension_var); if (comp->expr->tag == Comprehension) { // Nested comprehension ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr; - ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); return compile_statement(env, loop); } else if (comp->expr->tag == TableEntry) { // Table comprehension auto e = Match(comp->expr, TableEntry); @@ -922,14 +973,14 @@ CORD compile_statement(env_t *env, ast_t *ast) .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))); if (comp->filter) body = WrapAST(body, If, .condition=comp->filter, .body=body); - ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); return compile_statement(env, loop); } else { // Array comprehension ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)), .args=new(arg_ast_t, .value=comp->expr)); if (comp->filter) body = WrapAST(body, If, .condition=comp->filter, .body=body); - ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); return compile_statement(env, loop); } } @@ -1038,7 +1089,7 @@ env_t *with_enum_scope(env_t *env, type_t *t) return env; } -static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args) +CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args) { table_t used_args = {}; CORD code = CORD_EMPTY; @@ -1976,7 +2027,7 @@ CORD compile(env_t *env, ast_t *ast) ast_t *item = FakeAST(Var, "$iter_value"); set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="iter_value")); ast_t *body = FakeAST(InlineCCode, CORD_all("reduction = $$i == 1 ? iter_value : ", compile(scope, reduction->combination), ";")); - ast_t *loop = FakeAST(For, .index=i, .value=item, .iter=reduction->iter, .body=body, .empty=empty); + ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=i, .next=new(ast_list_t, .ast=item)), .iter=reduction->iter, .body=body, .empty=empty); code = CORD_all(code, compile_statement(scope, loop), "\nreduction;})"); return code; } -- cgit v1.2.3