diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-03-17 15:26:25 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-03-17 15:26:25 -0400 |
| commit | a0faef71028f78bde11233d76673ce8da2f5f783 (patch) | |
| tree | b5924af3a992536052993f85dd439a9eb9e98d3f | |
| parent | 1647fb4bed03fe5e3d0332049c73fb21a48ef05f (diff) | |
Support nested comprehensions
| -rw-r--r-- | compile.c | 69 | ||||
| -rw-r--r-- | environment.h | 1 | ||||
| -rw-r--r-- | typecheck.c | 37 |
3 files changed, 59 insertions, 48 deletions
@@ -656,6 +656,30 @@ CORD compile_statement(env_t *env, ast_t *ast) } return CORD_cat(code, "}"); } + case Comprehension: { + auto comp = Match(ast, Comprehension); + assert(env->comprehension_var); + if (comp->expr->tag == Comprehension) { // Nested comprehension + ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr; + ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + return compile_statement(env, loop); + } else if (comp->expr->tag == TableEntry) { // Table comprehension + auto e = Match(comp->expr, TableEntry); + ast_t *body = WrapAST(ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)), + .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))); + if (comp->filter) + body = WrapAST(body, If, .condition=comp->filter, .body=body); + ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + return compile_statement(env, loop); + } else { // Array comprehension + ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)), + .args=new(arg_ast_t, .value=comp->expr)); + if (comp->filter) + body = WrapAST(body, If, .condition=comp->filter, .body=body); + ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + return compile_statement(env, loop); + } + } case InlineCCode: return Match(ast, InlineCCode)->code; default: return CORD_asprintf("(void)%r;", compile(env, ast)); @@ -1167,26 +1191,22 @@ CORD compile(env_t *env, ast_t *ast) array_comprehension: { - CORD code = "({ array_t $arr = {};"; env_t *scope = fresh_scope(env); - set_binding(scope, "$arr", new(binding_t, .type=array_type, .code="$arr")); + static int64_t comp_num = 1; + scope->comprehension_var = heap_strf("$arr$%ld", comp_num++); + CORD code = CORD_all("({ array_t ", scope->comprehension_var, " = {};"); + set_binding(scope, scope->comprehension_var, new(binding_t, .type=array_type, .code=scope->comprehension_var)); for (ast_list_t *item = array->items; item; item = item->next) { if (item->ast->tag == Comprehension) { - auto comp = Match(item->ast, Comprehension); - ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")), - .args=new(arg_ast_t, .value=comp->expr)); - if (comp->filter) - body = WrapAST(body, If, .condition=comp->filter, .body=body); - ast_t *loop = WrapAST(item->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); - code = CORD_all(code, "\n", compile_statement(scope, loop)); + code = CORD_all(code, "\n", compile_statement(scope, item->ast)); } else { CORD insert = compile_statement( - scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")), + scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)), .args=new(arg_ast_t, .value=item->ast))); code = CORD_all(code, "\n", insert); } } - code = CORD_cat(code, " $arr; })"); + code = CORD_all(code, " ", scope->comprehension_var, "; })"); return code; } } @@ -1240,7 +1260,11 @@ CORD compile(env_t *env, ast_t *ast) table_comprehension: { - CORD code = "({ table_t $t = {"; + static int64_t comp_num = 1; + env_t *scope = fresh_scope(env); + scope->comprehension_var = heap_strf("$table$%ld", comp_num++); + + CORD code = CORD_all("({ table_t ", scope->comprehension_var, " = {"); if (table->fallback) code = CORD_all(code, ".fallback=$heap(", compile(env, table->fallback), "), "); @@ -1248,36 +1272,25 @@ CORD compile(env_t *env, ast_t *ast) code = CORD_all(code, ".default_value=$heap(", compile(env, table->default_value), "), "); code = CORD_cat(code, "};"); - env_t *scope = fresh_scope(env); - set_binding(scope, "$t", new(binding_t, .type=table_type, .code="$t")); + set_binding(scope, scope->comprehension_var, new(binding_t, .type=table_type, .code=scope->comprehension_var)); for (ast_list_t *entry = table->entries; entry; entry = entry->next) { if (entry->ast->tag == Comprehension) { - auto comp = Match(entry->ast, Comprehension); - auto e = Match(comp->expr, TableEntry); - ast_t *body = WrapAST(comp->expr, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$t")), - .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))); - if (comp->filter) - body = WrapAST(body, If, .condition=comp->filter, .body=body); - ast_t *loop = WrapAST(entry->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); - code = CORD_all(code, "\n", compile_statement(scope, loop)); + code = CORD_all(code, "\n", compile_statement(scope, entry->ast)); } else { auto e = Match(entry->ast, TableEntry); CORD set = compile_statement( - scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$arr")), + scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)), .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)))); code = CORD_all(code, "\n", set); } } - code = CORD_cat(code, " $t; })"); + code = CORD_all(code, " ", scope->comprehension_var, "; })"); return code; } } case Comprehension: { - auto comp = Match(ast, Comprehension); - ast_t *collection = comp->expr->tag == TableEntry ? WrapAST(ast, Table, .entries=new(ast_list_t, .ast=ast)) - : WrapAST(ast, Array, .items=new(ast_list_t, .ast=ast)); - return compile(env, collection); + code_err(ast, "Comprehensions cannot be compiled as expressions"); } case Lambda: { auto lambda = Match(ast, Lambda); diff --git a/environment.h b/environment.h index 7beb7e12..49d15a19 100644 --- a/environment.h +++ b/environment.h @@ -35,6 +35,7 @@ typedef struct { fn_ctx_t *fn_ctx; loop_ctx_t *loop_ctx; CORD scope_prefix; + const char *comprehension_var; } env_t; typedef struct { diff --git a/typecheck.c b/typecheck.c index 228029b8..2d212f76 100644 --- a/typecheck.c +++ b/typecheck.c @@ -307,15 +307,14 @@ type_t *get_type(env_t *env, ast_t *ast) item_type = parse_type_ast(env, array->type); } else if (array->items) { for (ast_list_t *item = array->items; item; item = item->next) { - type_t *t2; - if (item->ast->tag == Comprehension) { - auto comp = Match(item->ast, Comprehension); - - env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); - t2 = get_type(scope, comp->expr); - } else { - t2 = get_type(env, item->ast); + ast_t *item_ast = item->ast; + env_t *scope = env; + while (item_ast->tag == Comprehension) { + auto comp = Match(item_ast, Comprehension); + scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); + item_ast = comp->expr; } + type_t *t2 = get_type(scope, item_ast); type_t *merged = item_type ? type_or_type(item_type, t2) : t2; if (!merged) code_err(item->ast, @@ -340,20 +339,18 @@ type_t *get_type(env_t *env, ast_t *ast) if (table->default_value) value_type = get_type(env, table->default_value); for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - type_t *key_t, *value_t; - if (entry->ast->tag == Comprehension) { - auto comp = Match(entry->ast, Comprehension); - env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); - if (comp->expr->tag != TableEntry) - code_err(comp->expr, "I expected this table comprehension to have a key/value entry"); - key_t = get_type(scope, Match(comp->expr, TableEntry)->key); - value_t = get_type(scope, Match(comp->expr, TableEntry)->value); - } else { - auto e = Match(entry->ast, TableEntry); - key_t = get_type(env, e->key); - value_t = get_type(env, e->value); + ast_t *entry_ast = entry->ast; + env_t *scope = env; + while (entry_ast->tag == Comprehension) { + auto comp = Match(entry_ast, Comprehension); + scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); + entry_ast = comp->expr; } + auto e = Match(entry_ast, TableEntry); + type_t *key_t = get_type(scope, e->key); + type_t *value_t = get_type(scope, e->value); + type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; if (!key_merged) code_err(entry->ast, |
