aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-07-13 17:17:58 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-07-13 17:17:58 -0400
commit445f79cb70e72698283539b65e43fc71a47ad311 (patch)
tree9a1b0b027a1957fc0f6351e21ef53ce7ff53259a
parent3db57b4d2e16ab25fbd07401ec7b3a738f8dae8a (diff)
Add iterator functions
-rw-r--r--ast.c6
-rw-r--r--ast.h6
-rw-r--r--compile.c129
-rw-r--r--environment.c97
-rw-r--r--environment.h3
-rw-r--r--parse.c40
-rw-r--r--typecheck.c8
7 files changed, 191 insertions, 98 deletions
diff --git a/ast.c b/ast.c
index d4cf4140..73556e49 100644
--- a/ast.c
+++ b/ast.c
@@ -120,7 +120,7 @@ CORD ast_to_xml(ast_t *ast)
optional_tagged("default", data.default_value))
T(TableEntry, "<TableEntry>%r%r</TableEntry>", ast_to_xml(data.key), ast_to_xml(data.value))
T(Comprehension, "<Comprehension>%r%r%r%r%r</Comprehension>", optional_tagged("expr", data.expr),
- optional_tagged("key", data.key), optional_tagged("value", data.value), optional_tagged("iter", data.iter),
+ ast_list_to_xml(data.vars), optional_tagged("iter", data.iter),
optional_tagged("filter", data.filter))
T(FunctionDef, "<FunctionDef name=\"%r\">%r%r<body>%r</body></FunctionDef>", ast_to_xml(data.name),
arg_list_to_xml(data.args), optional_tagged_type("return-type", data.ret_type), ast_to_xml(data.body))
@@ -128,8 +128,8 @@ CORD ast_to_xml(ast_t *ast)
T(FunctionCall, "<FunctionCall><function>%r</function>%r</FunctionCall>", ast_to_xml(data.fn), arg_list_to_xml(data.args))
T(MethodCall, "<MethodCall><self>%r</self><method>%s</method>%r</MethodCall>", ast_to_xml(data.self), data.name, arg_list_to_xml(data.args))
T(Block, "<Block>%r</Block>", ast_list_to_xml(data.statements))
- T(For, "<For>%r%r%r%r%r</For>", optional_tagged("index", data.index), optional_tagged("value", data.value),
- optional_tagged("iterable", data.iter), optional_tagged("body", data.body), optional_tagged("empty", data.empty))
+ T(For, "<For>%r%r%r%r%r</For>", ast_list_to_xml(data.vars), optional_tagged("iterable", data.iter),
+ optional_tagged("body", data.body), optional_tagged("empty", data.empty))
T(While, "<While>%r%r</While>", optional_tagged("condition", data.condition), optional_tagged("body", data.body))
T(If, "<If>%r%r%r</If>", optional_tagged("condition", data.condition), optional_tagged("body", data.body), optional_tagged("else", data.else_body))
T(When, "<When><subject>%r</subject>%r%r</When>", ast_to_xml(data.subject), when_clauses_to_xml(data.clauses), optional_tagged("else", data.else_body))
diff --git a/ast.h b/ast.h
index 849f91c8..be37c8cb 100644
--- a/ast.h
+++ b/ast.h
@@ -186,7 +186,8 @@ struct ast_s {
ast_t *key, *value;
} TableEntry;
struct {
- ast_t *expr, *key, *value, *iter, *filter;
+ ast_list_t *vars;
+ ast_t *expr, *iter, *filter;
} Comprehension;
struct {
ast_t *name;
@@ -214,7 +215,8 @@ struct ast_s {
ast_list_t *statements;
} Block;
struct {
- ast_t *index, *value, *iter, *body, *empty;
+ ast_list_t *vars;
+ ast_t *iter, *body, *empty;
} For;
struct {
ast_t *condition, *body;
diff --git a/compile.c b/compile.c
index cac7a345..9db1642e 100644
--- a/compile.c
+++ b/compile.c
@@ -18,6 +18,7 @@ static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_dept
static env_t *with_enum_scope(env_t *env, type_t *t);
static CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type);
static CORD compile_string(env_t *env, ast_t *ast, CORD color);
+static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args);
static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed)
{
@@ -647,8 +648,11 @@ CORD compile_statement(env_t *env, ast_t *ast)
case Skip: {
const char *target = Match(ast, Skip)->target;
for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
- if (!target || CORD_cmp(target, ctx->loop_name) == 0
- || CORD_cmp(target, ctx->key_name) == 0 || CORD_cmp(target, ctx->value_name) == 0) {
+ bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0;
+ for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next)
+ matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0);
+
+ if (matched) {
if (!ctx->skip_label) {
static int64_t skip_label_count = 1;
CORD_sprintf(&ctx->skip_label, "skip_%ld", skip_label_count);
@@ -670,8 +674,11 @@ CORD compile_statement(env_t *env, ast_t *ast)
case Stop: {
const char *target = Match(ast, Stop)->target;
for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
- if (!target || CORD_cmp(target, ctx->loop_name) == 0
- || CORD_cmp(target, ctx->key_name) == 0 || CORD_cmp(target, ctx->value_name) == 0) {
+ bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0;
+ for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next)
+ matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0);
+
+ if (matched) {
if (!ctx->stop_label) {
static int64_t stop_label_count = 1;
CORD_sprintf(&ctx->stop_label, "stop_%ld", stop_label_count);
@@ -778,8 +785,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
env_t *body_scope = for_scope(env, ast);
loop_ctx_t loop_ctx = (loop_ctx_t){
.loop_name="for",
- .key_name=for_->index ? Match(for_->index, Var)->name : CORD_EMPTY,
- .value_name=for_->value ? Match(for_->value, Var)->name : CORD_EMPTY,
+ .loop_vars=for_->vars,
.deferred=body_scope->deferred,
.next=body_scope->loop_ctx,
};
@@ -792,8 +798,20 @@ CORD compile_statement(env_t *env, ast_t *ast)
switch (iter_t->tag) {
case ArrayType: {
type_t *item_t = Match(iter_t, ArrayType)->item_type;
- CORD index = for_->index ? compile(env, for_->index) : "i";
- CORD value = compile(env, for_->value);
+ CORD index = "i";
+ CORD value = "value";
+ if (for_->vars) {
+ if (for_->vars->next) {
+ if (for_->vars->next->next)
+ code_err(for_->vars->next->next->ast, "This is too many variables for this loop");
+
+ index = compile(env, for_->vars->ast);
+ value = compile(env, for_->vars->next->ast);
+ } else {
+ value = compile(env, for_->vars->ast);
+ }
+ }
+
CORD array = is_idempotent(for_->iter) ? compile(env, for_->iter) : "arr";
CORD loop = CORD_all("ARRAY_INCREF(", array, ");\n"
"for (int64_t ", index, " = 1; ", index, " <= ", array, ".length; ++", index, ") {\n",
@@ -814,18 +832,30 @@ CORD compile_statement(env_t *env, ast_t *ast)
CORD table = is_idempotent(for_->iter) ? compile(env, for_->iter) : "table";
CORD loop = CORD_all("ARRAY_INCREF(", table, ".entries);\n"
"for (int64_t i = 0; i < ",table,".entries.length; ++i) {\n");
- if (for_->index) {
- loop = CORD_all(loop, compile_type(key_t), " ", compile(env, for_->index), " = *(", compile_type(key_t), "*)(",
- table,".entries.data + i*", table, ".entries.stride);\n");
+ CORD key = CORD_EMPTY, value = CORD_EMPTY;
+ if (for_->vars) {
+ if (for_->vars->next) {
+ if (for_->vars->next->next)
+ code_err(for_->vars->next->next->ast, "This is too many variables for this loop");
+
+ key = compile(env, for_->vars->ast);
+ value = compile(env, for_->vars->next->ast);
+ } else {
+ key = compile(env, for_->vars->ast);
+ }
+ }
+
+ if (key) {
+ loop = CORD_all(loop, compile_type(key_t), " ", key, " = *(", compile_type(key_t), "*)(",
+ table,".entries.data + i*", table, ".entries.stride);\n");
+ }
+ if (value) {
size_t value_offset = type_size(key_t);
if (type_align(value_t) > 1 && value_offset % type_align(value_t))
value_offset += type_align(value_t) - (value_offset % type_align(value_t)); // padding
- loop = CORD_all(loop, compile_type(value_t), " ", compile(env, for_->value), " = *(", compile_type(value_t), "*)(",
+ loop = CORD_all(loop, compile_type(value_t), " ", value, " = *(", compile_type(value_t), "*)(",
table,".entries.data + i*", table, ".entries.stride + ", heap_strf("%zu", value_offset), ");\n");
- } else {
- loop = CORD_all(loop, compile_type(key_t), " ", compile(env, for_->value), " = *(", compile_type(key_t), "*)(",
- table,".entries.data + i*", table, ".entries.stride);\n");
}
loop = CORD_all(loop, body, "\n}");
if (for_->empty)
@@ -836,19 +866,9 @@ CORD compile_statement(env_t *env, ast_t *ast)
return loop;
}
case IntType: {
- CORD value = compile(env, for_->value);
+ CORD value = for_->vars ? compile(env, for_->vars->ast) : "i";
CORD n = compile(env, for_->iter);
- CORD index = for_->index ? compile(env, for_->index) : CORD_EMPTY;
- if (for_->empty && index) {
- return CORD_all(
- "{\n"
- "int64_t n = ", n, ";\n"
- "if (n > 0) {\n"
- "for (int64_t ", index, " = 1, ", value, "; (", value, "=", index,") <= n; ++", index, ") {\n"
- "\t", body, "\n}"
- "\n} else ", compile_statement(env, for_->empty),
- stop, "\n}");
- } else if (for_->empty) {
+ if (for_->empty) {
return CORD_all(
"{\n"
"int64_t n = ", n, ";\n"
@@ -859,13 +879,6 @@ CORD compile_statement(env_t *env, ast_t *ast)
"\n} else ", compile_statement(env, for_->empty),
stop,
"\n}");
- } else if (index) {
- return CORD_all(
- "for (int64_t ", value, ", ", index, " = 1, n = ", n, "; (", value, "=", index,") <= n; ++", value, ") {\n"
- "\t", body,
- "\n}",
- stop,
- "\n");
} else {
return CORD_all(
"for (int64_t ", value, " = 1, n = ", compile(env, for_->iter), "; ", value, " <= n; ++", value, ") {\n"
@@ -875,6 +888,44 @@ CORD compile_statement(env_t *env, ast_t *ast)
"\n");
}
}
+ case FunctionType: case ClosureType: {
+ CORD code = "{\n";
+ auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
+ arg_t *next_arg = fn->args;
+ for (ast_list_t *var = for_->vars; var; var = var->next) {
+ const char *name = Match(var->ast, Var)->name;
+ type_t *t = Match(get_arg_type(env, next_arg), PointerType)->pointed;
+ code = CORD_all(code, compile_declaration(t, CORD_cat("$", name)), ";\n");
+ }
+
+ code = CORD_all(code, compile_declaration(iter_t, "next"), " = ", compile(env, for_->iter), ";\n");
+
+ CORD next_fn;
+ if (iter_t->tag == ClosureType) {
+ type_t *fn_t = Match(iter_t, ClosureType)->fn;
+ arg_t *closure_fn_args = NULL;
+ for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next)
+ closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args);
+ closure_fn_args = new(arg_t, .name="userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args);
+ REVERSE_LIST(closure_fn_args);
+ CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret));
+ next_fn = CORD_all("((", fn_type_code, ")next.fn)");
+ } else {
+ next_fn = "next";
+ }
+
+ code = CORD_all(code, "for(; ", next_fn, "(");
+ for (ast_list_t *var = for_->vars; var; var = var->next) {
+ const char *name = Match(var->ast, Var)->name;
+ code = CORD_all(code, "&$", name);
+ if (var->next || iter_t->tag == ClosureType)
+ code = CORD_all(code, ", ");
+ }
+ if (iter_t->tag == ClosureType)
+ code = CORD_all(code, "next.userdata");
+ code = CORD_all(code, "); ) {\n\t", body, "}\n", stop, "\n}\n");
+ return code;
+ }
default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t);
}
}
@@ -914,7 +965,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
assert(env->comprehension_var);
if (comp->expr->tag == Comprehension) { // Nested comprehension
ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr;
- ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
return compile_statement(env, loop);
} else if (comp->expr->tag == TableEntry) { // Table comprehension
auto e = Match(comp->expr, TableEntry);
@@ -922,14 +973,14 @@ CORD compile_statement(env_t *env, ast_t *ast)
.args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
- ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
return compile_statement(env, loop);
} else { // Array comprehension
ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)),
.args=new(arg_ast_t, .value=comp->expr));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
- ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
return compile_statement(env, loop);
}
}
@@ -1038,7 +1089,7 @@ env_t *with_enum_scope(env_t *env, type_t *t)
return env;
}
-static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args)
+CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args)
{
table_t used_args = {};
CORD code = CORD_EMPTY;
@@ -1976,7 +2027,7 @@ CORD compile(env_t *env, ast_t *ast)
ast_t *item = FakeAST(Var, "$iter_value");
set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="iter_value"));
ast_t *body = FakeAST(InlineCCode, CORD_all("reduction = $$i == 1 ? iter_value : ", compile(scope, reduction->combination), ";"));
- ast_t *loop = FakeAST(For, .index=i, .value=item, .iter=reduction->iter, .body=body, .empty=empty);
+ ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=i, .next=new(ast_list_t, .ast=item)), .iter=reduction->iter, .body=body, .empty=empty);
code = CORD_all(code, compile_statement(scope, loop), "\nreduction;})");
return code;
}
diff --git a/environment.c b/environment.c
index b239dd3c..65130c37 100644
--- a/environment.c
+++ b/environment.c
@@ -286,48 +286,73 @@ env_t *for_scope(env_t *env, ast_t *ast)
auto for_ = Match(ast, For);
type_t *iter_t = get_type(env, for_->iter);
env_t *scope = fresh_scope(env);
- const char *value = Match(for_->value, Var)->name;
- if (for_->index) {
- const char *index = Match(for_->index, Var)->name;
- switch (iter_t->tag) {
- case ArrayType: {
- type_t *item_t = Match(iter_t, ArrayType)->item_type;
- set_binding(scope, index, new(binding_t, .type=Type(IntType, .bits=64), .code=CORD_cat("$", index)));
- set_binding(scope, value, new(binding_t, .type=item_t, .code=CORD_cat("$", value)));
- return scope;
+ switch (iter_t->tag) {
+ case ArrayType: {
+ type_t *item_t = Match(iter_t, ArrayType)->item_type;
+ const char *vars[2] = {};
+ int64_t num_vars = 0;
+ for (ast_list_t *var = for_->vars; var; var = var->next) {
+ if (num_vars >= 2)
+ code_err(var->ast, "This is too many variables for this loop");
+ vars[num_vars++] = Match(var->ast, Var)->name;
}
- case TableType: {
- type_t *key_t = Match(iter_t, TableType)->key_type;
- type_t *value_t = Match(iter_t, TableType)->value_type;
- set_binding(scope, index, new(binding_t, .type=key_t, .code=CORD_cat("$", index)));
- set_binding(scope, value, new(binding_t, .type=value_t, .code=CORD_cat("$", value)));
- return scope;
- }
- case IntType: {
- set_binding(scope, index, new(binding_t, .type=Type(IntType, .bits=64), .code=CORD_cat("$", index)));
- set_binding(scope, value, new(binding_t, .type=iter_t, .code=CORD_cat("$", value)));
- return scope;
- }
- default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t);
+ if (num_vars == 1) {
+ set_binding(scope, vars[0], new(binding_t, .type=item_t, .code=CORD_cat("$", vars[0])));
+ } else if (num_vars == 2) {
+ set_binding(scope, vars[0], new(binding_t, .type=Type(IntType, .bits=64), .code=CORD_cat("$", vars[0])));
+ set_binding(scope, vars[1], new(binding_t, .type=item_t, .code=CORD_cat("$", vars[1])));
}
- } else {
- switch (iter_t->tag) {
- case ArrayType: {
- type_t *item_t = Match(iter_t, ArrayType)->item_type;
- set_binding(scope, value, new(binding_t, .type=item_t, .code=CORD_cat("$", value)));
- return scope;
+ return scope;
+ }
+ case TableType: {
+ const char *vars[2] = {};
+ int64_t num_vars = 0;
+ for (ast_list_t *var = for_->vars; var; var = var->next) {
+ if (num_vars >= 2)
+ code_err(var->ast, "This is too many variables for this loop");
+ vars[num_vars++] = Match(var->ast, Var)->name;
}
- case TableType: {
- type_t *key_t = Match(iter_t, TableType)->key_type;
- set_binding(scope, value, new(binding_t, .type=key_t, .code=CORD_cat("$", value)));
- return scope;
+
+ type_t *key_t = Match(iter_t, TableType)->key_type;
+ if (num_vars == 1) {
+ set_binding(scope, vars[0], new(binding_t, .type=key_t, .code=CORD_cat("$", vars[0])));
+ } else if (num_vars == 2) {
+ set_binding(scope, vars[0], new(binding_t, .type=key_t, .code=CORD_cat("$", vars[0])));
+ type_t *value_t = Match(iter_t, TableType)->value_type;
+ set_binding(scope, vars[1], new(binding_t, .type=value_t, .code=CORD_cat("$", vars[1])));
}
- case IntType: {
- set_binding(scope, value, new(binding_t, .type=iter_t, .code=CORD_cat("$", value)));
- return scope;
+ return scope;
+ }
+ case IntType: {
+ if (for_->vars) {
+ if (for_->vars->next)
+ code_err(for_->vars->next->ast, "This is too many variables for this loop");
+ const char *var = Match(for_->vars->ast, Var)->name;
+ set_binding(scope, var, new(binding_t, .type=Type(IntType, .bits=64), .code=CORD_cat("$", var)));
}
- default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t);
+ return scope;
+ }
+ case FunctionType: case ClosureType: {
+ auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
+ arg_t *next_arg = fn->args;
+ for (ast_list_t *var = for_->vars; var; var = var->next) {
+ if (next_arg == NULL)
+ code_err(var->ast, "This is too many variables for this iterator function");
+ const char *name = Match(var->ast, Var)->name;
+ type_t *t = get_arg_type(env, next_arg);
+ if (t->tag != PointerType)
+ code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t);
+ auto ptr = Match(t, PointerType);
+ if (!ptr->is_stack || ptr->is_readonly)
+ code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t);
+ set_binding(scope, name, new(binding_t, .type=ptr->pointed, .code=CORD_cat("$", name)));
+ next_arg = next_arg->next;
}
+ if (next_arg)
+ code_err(ast, "There are not enough variables given for this loop with an iterator that has type %T", iter_t);
+ return scope;
+ }
+ default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t);
}
}
diff --git a/environment.h b/environment.h
index a06f3d42..be87b857 100644
--- a/environment.h
+++ b/environment.h
@@ -28,7 +28,8 @@ typedef struct deferral_s {
typedef struct loop_ctx_s {
struct loop_ctx_s *next;
- const char *loop_name, *key_name, *value_name;
+ const char *loop_name;
+ ast_list_t *loop_vars;
deferral_t *deferred;
CORD skip_label, stop_label;
} loop_ctx_t;
diff --git a/parse.c b/parse.c
index f549c257..71f85e66 100644
--- a/parse.c
+++ b/parse.c
@@ -858,22 +858,25 @@ ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *expr) {
whitespace(&pos);
if (!match_word(&pos, "for")) return NULL;
- ast_t *key = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'");
- whitespace(&pos);
- ast_t *value = NULL;
- if (match(&pos, ",")) {
- value = expect(ctx, pos-1, &pos, parse_var, "I expected a variable after this comma");
- } else {
- value = key;
- key = NULL;
+ ast_list_t *vars = NULL;
+ for (;;) {
+ ast_t *var = optional(ctx, &pos, parse_var);
+ if (var)
+ vars = new(ast_list_t, .ast=var, .next=vars);
+
+ spaces(&pos);
+ if (!match(&pos, ","))
+ break;
}
+ REVERSE_LIST(vars);
+
expect_str(ctx, start, &pos, "in", "I expected an 'in' for this 'for'");
ast_t *iter = expect(ctx, start, &pos, parse_expr, "I expected an iterable value for this 'for'");
whitespace(&pos);
ast_t *filter = NULL;
if (match_word(&pos, "if"))
filter = expect(ctx, pos-2, &pos, parse_expr, "I expected a condition for this 'if'");
- return NewAST(ctx->file, start, pos, Comprehension, .expr=expr, .key=key, .value=value, .iter=iter, .filter=filter);
+ return NewAST(ctx->file, start, pos, Comprehension, .expr=expr, .vars=vars, .iter=iter, .filter=filter);
}
PARSER(parse_if) {
@@ -968,13 +971,21 @@ PARSER(parse_for) {
const char *start = pos;
if (!match_word(&pos, "for")) return NULL;
int64_t starting_indent = get_indent(ctx, pos);
- ast_t *index = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'");
spaces(&pos);
- ast_t *value = NULL;
- if (match(&pos, ",")) {
- value = expect(ctx, pos-1, &pos, parse_var, "I expected a variable after this comma");
+ ast_list_t *vars = NULL;
+ for (;;) {
+ ast_t *var = optional(ctx, &pos, parse_var);
+ if (var)
+ vars = new(ast_list_t, .ast=var, .next=vars);
+
+ spaces(&pos);
+ if (!match(&pos, ","))
+ break;
}
+
+ spaces(&pos);
expect_str(ctx, start, &pos, "in", "I expected an 'in' for this 'for'");
+
ast_t *iter = expect(ctx, start, &pos, parse_expr, "I expected an iterable value for this 'for'");
ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'for'");
@@ -985,7 +996,8 @@ PARSER(parse_for) {
pos = else_start;
empty = expect(ctx, pos, &pos, parse_block, "I expected a body for this 'else'");
}
- return NewAST(ctx->file, start, pos, For, .index=value ? index : NULL, .value=value ? value : index, .iter=iter, .body=body, .empty=empty);
+ REVERSE_LIST(vars);
+ return NewAST(ctx->file, start, pos, For, .vars=vars, .iter=iter, .body=body, .empty=empty);
}
PARSER(parse_do) {
diff --git a/typecheck.c b/typecheck.c
index 43e384f4..ba04ad53 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -509,7 +509,8 @@ type_t *get_type(env_t *env, ast_t *ast)
env_t *scope = env;
while (item_ast->tag == Comprehension) {
auto comp = Match(item_ast, Comprehension);
- scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
+ scope = for_scope(
+ scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
item_ast = comp->expr;
}
type_t *t2 = get_type(scope, item_ast);
@@ -541,7 +542,8 @@ type_t *get_type(env_t *env, ast_t *ast)
env_t *scope = env;
while (entry_ast->tag == Comprehension) {
auto comp = Match(entry_ast, Comprehension);
- scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
+ scope = for_scope(
+ scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
entry_ast = comp->expr;
}
@@ -573,7 +575,7 @@ type_t *get_type(env_t *env, ast_t *ast)
}
case Comprehension: {
auto comp = Match(ast, Comprehension);
- env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
+ env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
if (comp->expr->tag == Comprehension) {
return get_type(scope, comp->expr);
} else if (comp->expr->tag == TableEntry) {