aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compile.c69
-rw-r--r--environment.h1
-rw-r--r--typecheck.c37
3 files changed, 59 insertions, 48 deletions
diff --git a/compile.c b/compile.c
index 1032f61e..2f2bc4e3 100644
--- a/compile.c
+++ b/compile.c
@@ -656,6 +656,30 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
return CORD_cat(code, "}");
}
+ case Comprehension: {
+ auto comp = Match(ast, Comprehension);
+ assert(env->comprehension_var);
+ if (comp->expr->tag == Comprehension) { // Nested comprehension
+ ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr;
+ ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ return compile_statement(env, loop);
+ } else if (comp->expr->tag == TableEntry) { // Table comprehension
+ auto e = Match(comp->expr, TableEntry);
+ ast_t *body = WrapAST(ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)),
+ .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)));
+ if (comp->filter)
+ body = WrapAST(body, If, .condition=comp->filter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ return compile_statement(env, loop);
+ } else { // Array comprehension
+ ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)),
+ .args=new(arg_ast_t, .value=comp->expr));
+ if (comp->filter)
+ body = WrapAST(body, If, .condition=comp->filter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ return compile_statement(env, loop);
+ }
+ }
case InlineCCode: return Match(ast, InlineCCode)->code;
default:
return CORD_asprintf("(void)%r;", compile(env, ast));
@@ -1167,26 +1191,22 @@ CORD compile(env_t *env, ast_t *ast)
array_comprehension:
{
- CORD code = "({ array_t $arr = {};";
env_t *scope = fresh_scope(env);
- set_binding(scope, "$arr", new(binding_t, .type=array_type, .code="$arr"));
+ static int64_t comp_num = 1;
+ scope->comprehension_var = heap_strf("$arr$%ld", comp_num++);
+ CORD code = CORD_all("({ array_t ", scope->comprehension_var, " = {};");
+ set_binding(scope, scope->comprehension_var, new(binding_t, .type=array_type, .code=scope->comprehension_var));
for (ast_list_t *item = array->items; item; item = item->next) {
if (item->ast->tag == Comprehension) {
- auto comp = Match(item->ast, Comprehension);
- ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
- .args=new(arg_ast_t, .value=comp->expr));
- if (comp->filter)
- body = WrapAST(body, If, .condition=comp->filter, .body=body);
- ast_t *loop = WrapAST(item->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
- code = CORD_all(code, "\n", compile_statement(scope, loop));
+ code = CORD_all(code, "\n", compile_statement(scope, item->ast));
} else {
CORD insert = compile_statement(
- scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
+ scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)),
.args=new(arg_ast_t, .value=item->ast)));
code = CORD_all(code, "\n", insert);
}
}
- code = CORD_cat(code, " $arr; })");
+ code = CORD_all(code, " ", scope->comprehension_var, "; })");
return code;
}
}
@@ -1240,7 +1260,11 @@ CORD compile(env_t *env, ast_t *ast)
table_comprehension:
{
- CORD code = "({ table_t $t = {";
+ static int64_t comp_num = 1;
+ env_t *scope = fresh_scope(env);
+ scope->comprehension_var = heap_strf("$table$%ld", comp_num++);
+
+ CORD code = CORD_all("({ table_t ", scope->comprehension_var, " = {");
if (table->fallback)
code = CORD_all(code, ".fallback=$heap(", compile(env, table->fallback), "), ");
@@ -1248,36 +1272,25 @@ CORD compile(env_t *env, ast_t *ast)
code = CORD_all(code, ".default_value=$heap(", compile(env, table->default_value), "), ");
code = CORD_cat(code, "};");
- env_t *scope = fresh_scope(env);
- set_binding(scope, "$t", new(binding_t, .type=table_type, .code="$t"));
+ set_binding(scope, scope->comprehension_var, new(binding_t, .type=table_type, .code=scope->comprehension_var));
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
if (entry->ast->tag == Comprehension) {
- auto comp = Match(entry->ast, Comprehension);
- auto e = Match(comp->expr, TableEntry);
- ast_t *body = WrapAST(comp->expr, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$t")),
- .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)));
- if (comp->filter)
- body = WrapAST(body, If, .condition=comp->filter, .body=body);
- ast_t *loop = WrapAST(entry->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
- code = CORD_all(code, "\n", compile_statement(scope, loop));
+ code = CORD_all(code, "\n", compile_statement(scope, entry->ast));
} else {
auto e = Match(entry->ast, TableEntry);
CORD set = compile_statement(
- scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
+ scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)),
.args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))));
code = CORD_all(code, "\n", set);
}
}
- code = CORD_cat(code, " $t; })");
+ code = CORD_all(code, " ", scope->comprehension_var, "; })");
return code;
}
}
case Comprehension: {
- auto comp = Match(ast, Comprehension);
- ast_t *collection = comp->expr->tag == TableEntry ? WrapAST(ast, Table, .entries=new(ast_list_t, .ast=ast))
- : WrapAST(ast, Array, .items=new(ast_list_t, .ast=ast));
- return compile(env, collection);
+ code_err(ast, "Comprehensions cannot be compiled as expressions");
}
case Lambda: {
auto lambda = Match(ast, Lambda);
diff --git a/environment.h b/environment.h
index 7beb7e12..49d15a19 100644
--- a/environment.h
+++ b/environment.h
@@ -35,6 +35,7 @@ typedef struct {
fn_ctx_t *fn_ctx;
loop_ctx_t *loop_ctx;
CORD scope_prefix;
+ const char *comprehension_var;
} env_t;
typedef struct {
diff --git a/typecheck.c b/typecheck.c
index 228029b8..2d212f76 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -307,15 +307,14 @@ type_t *get_type(env_t *env, ast_t *ast)
item_type = parse_type_ast(env, array->type);
} else if (array->items) {
for (ast_list_t *item = array->items; item; item = item->next) {
- type_t *t2;
- if (item->ast->tag == Comprehension) {
- auto comp = Match(item->ast, Comprehension);
-
- env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
- t2 = get_type(scope, comp->expr);
- } else {
- t2 = get_type(env, item->ast);
+ ast_t *item_ast = item->ast;
+ env_t *scope = env;
+ while (item_ast->tag == Comprehension) {
+ auto comp = Match(item_ast, Comprehension);
+ scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
+ item_ast = comp->expr;
}
+ type_t *t2 = get_type(scope, item_ast);
type_t *merged = item_type ? type_or_type(item_type, t2) : t2;
if (!merged)
code_err(item->ast,
@@ -340,20 +339,18 @@ type_t *get_type(env_t *env, ast_t *ast)
if (table->default_value)
value_type = get_type(env, table->default_value);
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
- type_t *key_t, *value_t;
- if (entry->ast->tag == Comprehension) {
- auto comp = Match(entry->ast, Comprehension);
- env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
- if (comp->expr->tag != TableEntry)
- code_err(comp->expr, "I expected this table comprehension to have a key/value entry");
- key_t = get_type(scope, Match(comp->expr, TableEntry)->key);
- value_t = get_type(scope, Match(comp->expr, TableEntry)->value);
- } else {
- auto e = Match(entry->ast, TableEntry);
- key_t = get_type(env, e->key);
- value_t = get_type(env, e->value);
+ ast_t *entry_ast = entry->ast;
+ env_t *scope = env;
+ while (entry_ast->tag == Comprehension) {
+ auto comp = Match(entry_ast, Comprehension);
+ scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
+ entry_ast = comp->expr;
}
+ auto e = Match(entry_ast, TableEntry);
+ type_t *key_t = get_type(scope, e->key);
+ type_t *value_t = get_type(scope, e->value);
+
type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t;
if (!key_merged)
code_err(entry->ast,