diff options
| -rw-r--r-- | ast.c | 6 | ||||
| -rw-r--r-- | ast.h | 6 | ||||
| -rw-r--r-- | compile.c | 129 | ||||
| -rw-r--r-- | environment.c | 97 | ||||
| -rw-r--r-- | environment.h | 3 | ||||
| -rw-r--r-- | parse.c | 40 | ||||
| -rw-r--r-- | typecheck.c | 8 |
7 files changed, 191 insertions, 98 deletions
@@ -120,7 +120,7 @@ CORD ast_to_xml(ast_t *ast) optional_tagged("default", data.default_value)) T(TableEntry, "<TableEntry>%r%r</TableEntry>", ast_to_xml(data.key), ast_to_xml(data.value)) T(Comprehension, "<Comprehension>%r%r%r%r%r</Comprehension>", optional_tagged("expr", data.expr), - optional_tagged("key", data.key), optional_tagged("value", data.value), optional_tagged("iter", data.iter), + ast_list_to_xml(data.vars), optional_tagged("iter", data.iter), optional_tagged("filter", data.filter)) T(FunctionDef, "<FunctionDef name=\"%r\">%r%r<body>%r</body></FunctionDef>", ast_to_xml(data.name), arg_list_to_xml(data.args), optional_tagged_type("return-type", data.ret_type), ast_to_xml(data.body)) @@ -128,8 +128,8 @@ CORD ast_to_xml(ast_t *ast) T(FunctionCall, "<FunctionCall><function>%r</function>%r</FunctionCall>", ast_to_xml(data.fn), arg_list_to_xml(data.args)) T(MethodCall, "<MethodCall><self>%r</self><method>%s</method>%r</MethodCall>", ast_to_xml(data.self), data.name, arg_list_to_xml(data.args)) T(Block, "<Block>%r</Block>", ast_list_to_xml(data.statements)) - T(For, "<For>%r%r%r%r%r</For>", optional_tagged("index", data.index), optional_tagged("value", data.value), - optional_tagged("iterable", data.iter), optional_tagged("body", data.body), optional_tagged("empty", data.empty)) + T(For, "<For>%r%r%r%r%r</For>", ast_list_to_xml(data.vars), optional_tagged("iterable", data.iter), + optional_tagged("body", data.body), optional_tagged("empty", data.empty)) T(While, "<While>%r%r</While>", optional_tagged("condition", data.condition), optional_tagged("body", data.body)) T(If, "<If>%r%r%r</If>", optional_tagged("condition", data.condition), optional_tagged("body", data.body), optional_tagged("else", data.else_body)) T(When, "<When><subject>%r</subject>%r%r</When>", ast_to_xml(data.subject), when_clauses_to_xml(data.clauses), optional_tagged("else", data.else_body)) @@ -186,7 +186,8 @@ struct ast_s { ast_t *key, *value; } TableEntry; struct { - ast_t *expr, *key, *value, *iter, *filter; + ast_list_t *vars; + ast_t *expr, *iter, *filter; } Comprehension; struct { ast_t *name; @@ -214,7 +215,8 @@ struct ast_s { ast_list_t *statements; } Block; struct { - ast_t *index, *value, *iter, *body, *empty; + ast_list_t *vars; + ast_t *iter, *body, *empty; } For; struct { ast_t *condition, *body; @@ -18,6 +18,7 @@ static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_dept static env_t *with_enum_scope(env_t *env, type_t *t); static CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); static CORD compile_string(env_t *env, ast_t *ast, CORD color); +static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args); static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed) { @@ -647,8 +648,11 @@ CORD compile_statement(env_t *env, ast_t *ast) case Skip: { const char *target = Match(ast, Skip)->target; for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) { - if (!target || CORD_cmp(target, ctx->loop_name) == 0 - || CORD_cmp(target, ctx->key_name) == 0 || CORD_cmp(target, ctx->value_name) == 0) { + bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0; + for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next) + matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0); + + if (matched) { if (!ctx->skip_label) { static int64_t skip_label_count = 1; CORD_sprintf(&ctx->skip_label, "skip_%ld", skip_label_count); @@ -670,8 +674,11 @@ CORD compile_statement(env_t *env, ast_t *ast) case Stop: { const char *target = Match(ast, Stop)->target; for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) { - if (!target || CORD_cmp(target, ctx->loop_name) == 0 - || CORD_cmp(target, ctx->key_name) == 0 || CORD_cmp(target, ctx->value_name) == 0) { + bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0; + for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next) + matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0); + + if (matched) { if (!ctx->stop_label) { static int64_t stop_label_count = 1; CORD_sprintf(&ctx->stop_label, "stop_%ld", stop_label_count); @@ -778,8 +785,7 @@ CORD compile_statement(env_t *env, ast_t *ast) env_t *body_scope = for_scope(env, ast); loop_ctx_t loop_ctx = (loop_ctx_t){ .loop_name="for", - .key_name=for_->index ? Match(for_->index, Var)->name : CORD_EMPTY, - .value_name=for_->value ? Match(for_->value, Var)->name : CORD_EMPTY, + .loop_vars=for_->vars, .deferred=body_scope->deferred, .next=body_scope->loop_ctx, }; @@ -792,8 +798,20 @@ CORD compile_statement(env_t *env, ast_t *ast) switch (iter_t->tag) { case ArrayType: { type_t *item_t = Match(iter_t, ArrayType)->item_type; - CORD index = for_->index ? compile(env, for_->index) : "i"; - CORD value = compile(env, for_->value); + CORD index = "i"; + CORD value = "value"; + if (for_->vars) { + if (for_->vars->next) { + if (for_->vars->next->next) + code_err(for_->vars->next->next->ast, "This is too many variables for this loop"); + + index = compile(env, for_->vars->ast); + value = compile(env, for_->vars->next->ast); + } else { + value = compile(env, for_->vars->ast); + } + } + CORD array = is_idempotent(for_->iter) ? compile(env, for_->iter) : "arr"; CORD loop = CORD_all("ARRAY_INCREF(", array, ");\n" "for (int64_t ", index, " = 1; ", index, " <= ", array, ".length; ++", index, ") {\n", @@ -814,18 +832,30 @@ CORD compile_statement(env_t *env, ast_t *ast) CORD table = is_idempotent(for_->iter) ? compile(env, for_->iter) : "table"; CORD loop = CORD_all("ARRAY_INCREF(", table, ".entries);\n" "for (int64_t i = 0; i < ",table,".entries.length; ++i) {\n"); - if (for_->index) { - loop = CORD_all(loop, compile_type(key_t), " ", compile(env, for_->index), " = *(", compile_type(key_t), "*)(", - table,".entries.data + i*", table, ".entries.stride);\n"); + CORD key = CORD_EMPTY, value = CORD_EMPTY; + if (for_->vars) { + if (for_->vars->next) { + if (for_->vars->next->next) + code_err(for_->vars->next->next->ast, "This is too many variables for this loop"); + + key = compile(env, for_->vars->ast); + value = compile(env, for_->vars->next->ast); + } else { + key = compile(env, for_->vars->ast); + } + } + + if (key) { + loop = CORD_all(loop, compile_type(key_t), " ", key, " = *(", compile_type(key_t), "*)(", + table,".entries.data + i*", table, ".entries.stride);\n"); + } + if (value) { size_t value_offset = type_size(key_t); if (type_align(value_t) > 1 && value_offset % type_align(value_t)) value_offset += type_align(value_t) - (value_offset % type_align(value_t)); // padding - loop = CORD_all(loop, compile_type(value_t), " ", compile(env, for_->value), " = *(", compile_type(value_t), "*)(", + loop = CORD_all(loop, compile_type(value_t), " ", value, " = *(", compile_type(value_t), "*)(", table,".entries.data + i*", table, ".entries.stride + ", heap_strf("%zu", value_offset), ");\n"); - } else { - loop = CORD_all(loop, compile_type(key_t), " ", compile(env, for_->value), " = *(", compile_type(key_t), "*)(", - table,".entries.data + i*", table, ".entries.stride);\n"); } loop = CORD_all(loop, body, "\n}"); if (for_->empty) @@ -836,19 +866,9 @@ CORD compile_statement(env_t *env, ast_t *ast) return loop; } case IntType: { - CORD value = compile(env, for_->value); + CORD value = for_->vars ? compile(env, for_->vars->ast) : "i"; CORD n = compile(env, for_->iter); - CORD index = for_->index ? compile(env, for_->index) : CORD_EMPTY; - if (for_->empty && index) { - return CORD_all( - "{\n" - "int64_t n = ", n, ";\n" - "if (n > 0) {\n" - "for (int64_t ", index, " = 1, ", value, "; (", value, "=", index,") <= n; ++", index, ") {\n" - "\t", body, "\n}" - "\n} else ", compile_statement(env, for_->empty), - stop, "\n}"); - } else if (for_->empty) { + if (for_->empty) { return CORD_all( "{\n" "int64_t n = ", n, ";\n" @@ -859,13 +879,6 @@ CORD compile_statement(env_t *env, ast_t *ast) "\n} else ", compile_statement(env, for_->empty), stop, "\n}"); - } else if (index) { - return CORD_all( - "for (int64_t ", value, ", ", index, " = 1, n = ", n, "; (", value, "=", index,") <= n; ++", value, ") {\n" - "\t", body, - "\n}", - stop, - "\n"); } else { return CORD_all( "for (int64_t ", value, " = 1, n = ", compile(env, for_->iter), "; ", value, " <= n; ++", value, ") {\n" @@ -875,6 +888,44 @@ CORD compile_statement(env_t *env, ast_t *ast) "\n"); } } + case FunctionType: case ClosureType: { + CORD code = "{\n"; + auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType); + arg_t *next_arg = fn->args; + for (ast_list_t *var = for_->vars; var; var = var->next) { + const char *name = Match(var->ast, Var)->name; + type_t *t = Match(get_arg_type(env, next_arg), PointerType)->pointed; + code = CORD_all(code, compile_declaration(t, CORD_cat("$", name)), ";\n"); + } + + code = CORD_all(code, compile_declaration(iter_t, "next"), " = ", compile(env, for_->iter), ";\n"); + + CORD next_fn; + if (iter_t->tag == ClosureType) { + type_t *fn_t = Match(iter_t, ClosureType)->fn; + arg_t *closure_fn_args = NULL; + for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next) + closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args); + closure_fn_args = new(arg_t, .name="userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args); + REVERSE_LIST(closure_fn_args); + CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret)); + next_fn = CORD_all("((", fn_type_code, ")next.fn)"); + } else { + next_fn = "next"; + } + + code = CORD_all(code, "for(; ", next_fn, "("); + for (ast_list_t *var = for_->vars; var; var = var->next) { + const char *name = Match(var->ast, Var)->name; + code = CORD_all(code, "&$", name); + if (var->next || iter_t->tag == ClosureType) + code = CORD_all(code, ", "); + } + if (iter_t->tag == ClosureType) + code = CORD_all(code, "next.userdata"); + code = CORD_all(code, "); ) {\n\t", body, "}\n", stop, "\n}\n"); + return code; + } default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t); } } @@ -914,7 +965,7 @@ CORD compile_statement(env_t *env, ast_t *ast) assert(env->comprehension_var); if (comp->expr->tag == Comprehension) { // Nested comprehension ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr; - ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); return compile_statement(env, loop); } else if (comp->expr->tag == TableEntry) { // Table comprehension auto e = Match(comp->expr, TableEntry); @@ -922,14 +973,14 @@ CORD compile_statement(env_t *env, ast_t *ast) .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))); if (comp->filter) body = WrapAST(body, If, .condition=comp->filter, .body=body); - ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); return compile_statement(env, loop); } else { // Array comprehension ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)), .args=new(arg_ast_t, .value=comp->expr)); if (comp->filter) body = WrapAST(body, If, .condition=comp->filter, .body=body); - ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); return compile_statement(env, loop); } } @@ -1038,7 +1089,7 @@ env_t *with_enum_scope(env_t *env, type_t *t) return env; } -static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args) +CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args) { table_t used_args = {}; CORD code = CORD_EMPTY; @@ -1976,7 +2027,7 @@ CORD compile(env_t *env, ast_t *ast) ast_t *item = FakeAST(Var, "$iter_value"); set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="iter_value")); ast_t *body = FakeAST(InlineCCode, CORD_all("reduction = $$i == 1 ? iter_value : ", compile(scope, reduction->combination), ";")); - ast_t *loop = FakeAST(For, .index=i, .value=item, .iter=reduction->iter, .body=body, .empty=empty); + ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=i, .next=new(ast_list_t, .ast=item)), .iter=reduction->iter, .body=body, .empty=empty); code = CORD_all(code, compile_statement(scope, loop), "\nreduction;})"); return code; } diff --git a/environment.c b/environment.c index b239dd3c..65130c37 100644 --- a/environment.c +++ b/environment.c @@ -286,48 +286,73 @@ env_t *for_scope(env_t *env, ast_t *ast) auto for_ = Match(ast, For); type_t *iter_t = get_type(env, for_->iter); env_t *scope = fresh_scope(env); - const char *value = Match(for_->value, Var)->name; - if (for_->index) { - const char *index = Match(for_->index, Var)->name; - switch (iter_t->tag) { - case ArrayType: { - type_t *item_t = Match(iter_t, ArrayType)->item_type; - set_binding(scope, index, new(binding_t, .type=Type(IntType, .bits=64), .code=CORD_cat("$", index))); - set_binding(scope, value, new(binding_t, .type=item_t, .code=CORD_cat("$", value))); - return scope; + switch (iter_t->tag) { + case ArrayType: { + type_t *item_t = Match(iter_t, ArrayType)->item_type; + const char *vars[2] = {}; + int64_t num_vars = 0; + for (ast_list_t *var = for_->vars; var; var = var->next) { + if (num_vars >= 2) + code_err(var->ast, "This is too many variables for this loop"); + vars[num_vars++] = Match(var->ast, Var)->name; } - case TableType: { - type_t *key_t = Match(iter_t, TableType)->key_type; - type_t *value_t = Match(iter_t, TableType)->value_type; - set_binding(scope, index, new(binding_t, .type=key_t, .code=CORD_cat("$", index))); - set_binding(scope, value, new(binding_t, .type=value_t, .code=CORD_cat("$", value))); - return scope; - } - case IntType: { - set_binding(scope, index, new(binding_t, .type=Type(IntType, .bits=64), .code=CORD_cat("$", index))); - set_binding(scope, value, new(binding_t, .type=iter_t, .code=CORD_cat("$", value))); - return scope; - } - default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t); + if (num_vars == 1) { + set_binding(scope, vars[0], new(binding_t, .type=item_t, .code=CORD_cat("$", vars[0]))); + } else if (num_vars == 2) { + set_binding(scope, vars[0], new(binding_t, .type=Type(IntType, .bits=64), .code=CORD_cat("$", vars[0]))); + set_binding(scope, vars[1], new(binding_t, .type=item_t, .code=CORD_cat("$", vars[1]))); } - } else { - switch (iter_t->tag) { - case ArrayType: { - type_t *item_t = Match(iter_t, ArrayType)->item_type; - set_binding(scope, value, new(binding_t, .type=item_t, .code=CORD_cat("$", value))); - return scope; + return scope; + } + case TableType: { + const char *vars[2] = {}; + int64_t num_vars = 0; + for (ast_list_t *var = for_->vars; var; var = var->next) { + if (num_vars >= 2) + code_err(var->ast, "This is too many variables for this loop"); + vars[num_vars++] = Match(var->ast, Var)->name; } - case TableType: { - type_t *key_t = Match(iter_t, TableType)->key_type; - set_binding(scope, value, new(binding_t, .type=key_t, .code=CORD_cat("$", value))); - return scope; + + type_t *key_t = Match(iter_t, TableType)->key_type; + if (num_vars == 1) { + set_binding(scope, vars[0], new(binding_t, .type=key_t, .code=CORD_cat("$", vars[0]))); + } else if (num_vars == 2) { + set_binding(scope, vars[0], new(binding_t, .type=key_t, .code=CORD_cat("$", vars[0]))); + type_t *value_t = Match(iter_t, TableType)->value_type; + set_binding(scope, vars[1], new(binding_t, .type=value_t, .code=CORD_cat("$", vars[1]))); } - case IntType: { - set_binding(scope, value, new(binding_t, .type=iter_t, .code=CORD_cat("$", value))); - return scope; + return scope; + } + case IntType: { + if (for_->vars) { + if (for_->vars->next) + code_err(for_->vars->next->ast, "This is too many variables for this loop"); + const char *var = Match(for_->vars->ast, Var)->name; + set_binding(scope, var, new(binding_t, .type=Type(IntType, .bits=64), .code=CORD_cat("$", var))); } - default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t); + return scope; + } + case FunctionType: case ClosureType: { + auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType); + arg_t *next_arg = fn->args; + for (ast_list_t *var = for_->vars; var; var = var->next) { + if (next_arg == NULL) + code_err(var->ast, "This is too many variables for this iterator function"); + const char *name = Match(var->ast, Var)->name; + type_t *t = get_arg_type(env, next_arg); + if (t->tag != PointerType) + code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t); + auto ptr = Match(t, PointerType); + if (!ptr->is_stack || ptr->is_readonly) + code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t); + set_binding(scope, name, new(binding_t, .type=ptr->pointed, .code=CORD_cat("$", name))); + next_arg = next_arg->next; } + if (next_arg) + code_err(ast, "There are not enough variables given for this loop with an iterator that has type %T", iter_t); + return scope; + } + default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t); } } diff --git a/environment.h b/environment.h index a06f3d42..be87b857 100644 --- a/environment.h +++ b/environment.h @@ -28,7 +28,8 @@ typedef struct deferral_s { typedef struct loop_ctx_s { struct loop_ctx_s *next; - const char *loop_name, *key_name, *value_name; + const char *loop_name; + ast_list_t *loop_vars; deferral_t *deferred; CORD skip_label, stop_label; } loop_ctx_t; @@ -858,22 +858,25 @@ ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *expr) { whitespace(&pos); if (!match_word(&pos, "for")) return NULL; - ast_t *key = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'"); - whitespace(&pos); - ast_t *value = NULL; - if (match(&pos, ",")) { - value = expect(ctx, pos-1, &pos, parse_var, "I expected a variable after this comma"); - } else { - value = key; - key = NULL; + ast_list_t *vars = NULL; + for (;;) { + ast_t *var = optional(ctx, &pos, parse_var); + if (var) + vars = new(ast_list_t, .ast=var, .next=vars); + + spaces(&pos); + if (!match(&pos, ",")) + break; } + REVERSE_LIST(vars); + expect_str(ctx, start, &pos, "in", "I expected an 'in' for this 'for'"); ast_t *iter = expect(ctx, start, &pos, parse_expr, "I expected an iterable value for this 'for'"); whitespace(&pos); ast_t *filter = NULL; if (match_word(&pos, "if")) filter = expect(ctx, pos-2, &pos, parse_expr, "I expected a condition for this 'if'"); - return NewAST(ctx->file, start, pos, Comprehension, .expr=expr, .key=key, .value=value, .iter=iter, .filter=filter); + return NewAST(ctx->file, start, pos, Comprehension, .expr=expr, .vars=vars, .iter=iter, .filter=filter); } PARSER(parse_if) { @@ -968,13 +971,21 @@ PARSER(parse_for) { const char *start = pos; if (!match_word(&pos, "for")) return NULL; int64_t starting_indent = get_indent(ctx, pos); - ast_t *index = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'"); spaces(&pos); - ast_t *value = NULL; - if (match(&pos, ",")) { - value = expect(ctx, pos-1, &pos, parse_var, "I expected a variable after this comma"); + ast_list_t *vars = NULL; + for (;;) { + ast_t *var = optional(ctx, &pos, parse_var); + if (var) + vars = new(ast_list_t, .ast=var, .next=vars); + + spaces(&pos); + if (!match(&pos, ",")) + break; } + + spaces(&pos); expect_str(ctx, start, &pos, "in", "I expected an 'in' for this 'for'"); + ast_t *iter = expect(ctx, start, &pos, parse_expr, "I expected an iterable value for this 'for'"); ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'for'"); @@ -985,7 +996,8 @@ PARSER(parse_for) { pos = else_start; empty = expect(ctx, pos, &pos, parse_block, "I expected a body for this 'else'"); } - return NewAST(ctx->file, start, pos, For, .index=value ? index : NULL, .value=value ? value : index, .iter=iter, .body=body, .empty=empty); + REVERSE_LIST(vars); + return NewAST(ctx->file, start, pos, For, .vars=vars, .iter=iter, .body=body, .empty=empty); } PARSER(parse_do) { diff --git a/typecheck.c b/typecheck.c index 43e384f4..ba04ad53 100644 --- a/typecheck.c +++ b/typecheck.c @@ -509,7 +509,8 @@ type_t *get_type(env_t *env, ast_t *ast) env_t *scope = env; while (item_ast->tag == Comprehension) { auto comp = Match(item_ast, Comprehension); - scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); item_ast = comp->expr; } type_t *t2 = get_type(scope, item_ast); @@ -541,7 +542,8 @@ type_t *get_type(env_t *env, ast_t *ast) env_t *scope = env; while (entry_ast->tag == Comprehension) { auto comp = Match(entry_ast, Comprehension); - scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); entry_ast = comp->expr; } @@ -573,7 +575,7 @@ type_t *get_type(env_t *env, ast_t *ast) } case Comprehension: { auto comp = Match(ast, Comprehension); - env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); + env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); if (comp->expr->tag == Comprehension) { return get_type(scope, comp->expr); } else if (comp->expr->tag == TableEntry) { |
