Support nested comprehensions

This commit is contained in:
Bruce Hill 2024-03-17 15:26:25 -04:00
parent 1647fb4bed
commit a0faef7102
3 changed files with 59 additions and 48 deletions

View File

@ -656,6 +656,30 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
return CORD_cat(code, "}");
}
case Comprehension: {
auto comp = Match(ast, Comprehension);
assert(env->comprehension_var);
if (comp->expr->tag == Comprehension) { // Nested comprehension
ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr;
ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
return compile_statement(env, loop);
} else if (comp->expr->tag == TableEntry) { // Table comprehension
auto e = Match(comp->expr, TableEntry);
ast_t *body = WrapAST(ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)),
.args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
return compile_statement(env, loop);
} else { // Array comprehension
ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)),
.args=new(arg_ast_t, .value=comp->expr));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
return compile_statement(env, loop);
}
}
case InlineCCode: return Match(ast, InlineCCode)->code;
default:
return CORD_asprintf("(void)%r;", compile(env, ast));
@ -1167,26 +1191,22 @@ CORD compile(env_t *env, ast_t *ast)
array_comprehension:
{
CORD code = "({ array_t $arr = {};";
env_t *scope = fresh_scope(env);
set_binding(scope, "$arr", new(binding_t, .type=array_type, .code="$arr"));
static int64_t comp_num = 1;
scope->comprehension_var = heap_strf("$arr$%ld", comp_num++);
CORD code = CORD_all("({ array_t ", scope->comprehension_var, " = {};");
set_binding(scope, scope->comprehension_var, new(binding_t, .type=array_type, .code=scope->comprehension_var));
for (ast_list_t *item = array->items; item; item = item->next) {
if (item->ast->tag == Comprehension) {
auto comp = Match(item->ast, Comprehension);
ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
.args=new(arg_ast_t, .value=comp->expr));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
ast_t *loop = WrapAST(item->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
code = CORD_all(code, "\n", compile_statement(scope, loop));
code = CORD_all(code, "\n", compile_statement(scope, item->ast));
} else {
CORD insert = compile_statement(
scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)),
.args=new(arg_ast_t, .value=item->ast)));
code = CORD_all(code, "\n", insert);
}
}
code = CORD_cat(code, " $arr; })");
code = CORD_all(code, " ", scope->comprehension_var, "; })");
return code;
}
}
@ -1240,7 +1260,11 @@ CORD compile(env_t *env, ast_t *ast)
table_comprehension:
{
CORD code = "({ table_t $t = {";
static int64_t comp_num = 1;
env_t *scope = fresh_scope(env);
scope->comprehension_var = heap_strf("$table$%ld", comp_num++);
CORD code = CORD_all("({ table_t ", scope->comprehension_var, " = {");
if (table->fallback)
code = CORD_all(code, ".fallback=$heap(", compile(env, table->fallback), "), ");
@ -1248,36 +1272,25 @@ CORD compile(env_t *env, ast_t *ast)
code = CORD_all(code, ".default_value=$heap(", compile(env, table->default_value), "), ");
code = CORD_cat(code, "};");
env_t *scope = fresh_scope(env);
set_binding(scope, "$t", new(binding_t, .type=table_type, .code="$t"));
set_binding(scope, scope->comprehension_var, new(binding_t, .type=table_type, .code=scope->comprehension_var));
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
if (entry->ast->tag == Comprehension) {
auto comp = Match(entry->ast, Comprehension);
auto e = Match(comp->expr, TableEntry);
ast_t *body = WrapAST(comp->expr, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$t")),
.args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
ast_t *loop = WrapAST(entry->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
code = CORD_all(code, "\n", compile_statement(scope, loop));
code = CORD_all(code, "\n", compile_statement(scope, entry->ast));
} else {
auto e = Match(entry->ast, TableEntry);
CORD set = compile_statement(
scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)),
.args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))));
code = CORD_all(code, "\n", set);
}
}
code = CORD_cat(code, " $t; })");
code = CORD_all(code, " ", scope->comprehension_var, "; })");
return code;
}
}
case Comprehension: {
auto comp = Match(ast, Comprehension);
ast_t *collection = comp->expr->tag == TableEntry ? WrapAST(ast, Table, .entries=new(ast_list_t, .ast=ast))
: WrapAST(ast, Array, .items=new(ast_list_t, .ast=ast));
return compile(env, collection);
code_err(ast, "Comprehensions cannot be compiled as expressions");
}
case Lambda: {
auto lambda = Match(ast, Lambda);

View File

@ -35,6 +35,7 @@ typedef struct {
fn_ctx_t *fn_ctx;
loop_ctx_t *loop_ctx;
CORD scope_prefix;
const char *comprehension_var;
} env_t;
typedef struct {

View File

@ -307,15 +307,14 @@ type_t *get_type(env_t *env, ast_t *ast)
item_type = parse_type_ast(env, array->type);
} else if (array->items) {
for (ast_list_t *item = array->items; item; item = item->next) {
type_t *t2;
if (item->ast->tag == Comprehension) {
auto comp = Match(item->ast, Comprehension);
env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
t2 = get_type(scope, comp->expr);
} else {
t2 = get_type(env, item->ast);
ast_t *item_ast = item->ast;
env_t *scope = env;
while (item_ast->tag == Comprehension) {
auto comp = Match(item_ast, Comprehension);
scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
item_ast = comp->expr;
}
type_t *t2 = get_type(scope, item_ast);
type_t *merged = item_type ? type_or_type(item_type, t2) : t2;
if (!merged)
code_err(item->ast,
@ -340,20 +339,18 @@ type_t *get_type(env_t *env, ast_t *ast)
if (table->default_value)
value_type = get_type(env, table->default_value);
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
type_t *key_t, *value_t;
if (entry->ast->tag == Comprehension) {
auto comp = Match(entry->ast, Comprehension);
env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
if (comp->expr->tag != TableEntry)
code_err(comp->expr, "I expected this table comprehension to have a key/value entry");
key_t = get_type(scope, Match(comp->expr, TableEntry)->key);
value_t = get_type(scope, Match(comp->expr, TableEntry)->value);
} else {
auto e = Match(entry->ast, TableEntry);
key_t = get_type(env, e->key);
value_t = get_type(env, e->value);
ast_t *entry_ast = entry->ast;
env_t *scope = env;
while (entry_ast->tag == Comprehension) {
auto comp = Match(entry_ast, Comprehension);
scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
entry_ast = comp->expr;
}
auto e = Match(entry_ast, TableEntry);
type_t *key_t = get_type(scope, e->key);
type_t *value_t = get_type(scope, e->value);
type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t;
if (!key_merged)
code_err(entry->ast,