From a0faef71028f78bde11233d76673ce8da2f5f783 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sun, 17 Mar 2024 15:26:25 -0400 Subject: Support nested comprehensions --- typecheck.c | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) (limited to 'typecheck.c') diff --git a/typecheck.c b/typecheck.c index 228029b8..2d212f76 100644 --- a/typecheck.c +++ b/typecheck.c @@ -307,15 +307,14 @@ type_t *get_type(env_t *env, ast_t *ast) item_type = parse_type_ast(env, array->type); } else if (array->items) { for (ast_list_t *item = array->items; item; item = item->next) { - type_t *t2; - if (item->ast->tag == Comprehension) { - auto comp = Match(item->ast, Comprehension); - - env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); - t2 = get_type(scope, comp->expr); - } else { - t2 = get_type(env, item->ast); + ast_t *item_ast = item->ast; + env_t *scope = env; + while (item_ast->tag == Comprehension) { + auto comp = Match(item_ast, Comprehension); + scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); + item_ast = comp->expr; } + type_t *t2 = get_type(scope, item_ast); type_t *merged = item_type ? type_or_type(item_type, t2) : t2; if (!merged) code_err(item->ast, @@ -340,20 +339,18 @@ type_t *get_type(env_t *env, ast_t *ast) if (table->default_value) value_type = get_type(env, table->default_value); for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - type_t *key_t, *value_t; - if (entry->ast->tag == Comprehension) { - auto comp = Match(entry->ast, Comprehension); - env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); - if (comp->expr->tag != TableEntry) - code_err(comp->expr, "I expected this table comprehension to have a key/value entry"); - key_t = get_type(scope, Match(comp->expr, TableEntry)->key); - value_t = get_type(scope, Match(comp->expr, TableEntry)->value); - } else { - auto e = Match(entry->ast, TableEntry); - key_t = get_type(env, e->key); - value_t = get_type(env, e->value); + ast_t *entry_ast = entry->ast; + env_t *scope = env; + while (entry_ast->tag == Comprehension) { + auto comp = Match(entry_ast, Comprehension); + scope = for_scope(scope, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); + entry_ast = comp->expr; } + auto e = Match(entry_ast, TableEntry); + type_t *key_t = get_type(scope, e->key); + type_t *value_t = get_type(scope, e->value); + type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; if (!key_merged) code_err(entry->ast, -- cgit v1.2.3