diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-03-17 14:46:36 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-03-17 14:46:36 -0400 |
| commit | 993284153011006a1164b4b1f6bb1522e5131cb0 (patch) | |
| tree | b34793189363b6b314d2e69f2c7b1ea6249576c2 | |
| parent | 5c2bb00bafa4ad6e004f171687b9f21a824695c6 (diff) | |
Improve comprehensions for both arrays and tables
| -rw-r--r-- | ast.c | 2 | ||||
| -rw-r--r-- | ast.h | 5 | ||||
| -rw-r--r-- | builtins/table.h | 2 | ||||
| -rw-r--r-- | compile.c | 115 | ||||
| -rw-r--r-- | parse.c | 46 | ||||
| -rw-r--r-- | test/tables.tm | 5 | ||||
| -rw-r--r-- | typecheck.c | 34 |
7 files changed, 137 insertions, 72 deletions
@@ -116,6 +116,8 @@ CORD ast_to_cord(ast_t *ast) ast_to_cord(data.fallback), ast_to_cord(data.default_value), ast_list_to_cord(data.entries)) T(TableEntry, "(%r => %r)", ast_to_cord(data.key), ast_to_cord(data.value)) + T(Comprehension, "(expr=%r, key=%r, value=%r, iter=%r, filter=%r)", ast_to_cord(data.expr), + ast_to_cord(data.key), ast_to_cord(data.value), ast_to_cord(data.iter), ast_to_cord(data.filter)) T(FunctionDef, "(name=%r, args=%r, ret=%r, body=%r)", ast_to_cord(data.name), arg_list_to_cord(data.args), type_ast_to_cord(data.ret_type), ast_to_cord(data.body)) T(Lambda, "(args=%r, body=%r)", arg_list_to_cord(data.args), ast_to_cord(data.body)) @@ -96,7 +96,7 @@ typedef enum { BinaryOp, UpdateAssign, Length, Not, Negative, HeapAllocate, StackReference, Min, Max, - Array, Table, TableEntry, + Array, Table, TableEntry, Comprehension, FunctionDef, Lambda, FunctionCall, MethodCall, Block, @@ -174,6 +174,9 @@ struct ast_s { ast_t *key, *value; } TableEntry; struct { + ast_t *expr, *key, *value, *iter, *filter; + } Comprehension; + struct { ast_t *name; arg_ast_t *args; type_ast_t *ret_type; diff --git a/builtins/table.h b/builtins/table.h index 49f2517e..6577782d 100644 --- a/builtins/table.h +++ b/builtins/table.h @@ -43,6 +43,8 @@ void *Table_get_raw(table_t t, const void *key, const TypeInfo *type); void *Table_entry(table_t t, int64_t n); void *Table_reserve(table_t *t, const void *key, const void *value, const TypeInfo *type); void Table_set(table_t *t, const void *key, const void *value, const TypeInfo *type); +#define Table_set_value(t, key_expr, value_expr, type) ({ __typeof(key_expr) $k = key_expr; __typeof(value_expr) $v = value_expr; \ + Table_set(t, &$k, &$v, type); }) void Table_remove(table_t *t, const void *key, const TypeInfo *type); void Table_clear(table_t *t); void Table_mark_copy_on_write(table_t *t); @@ -1153,7 +1153,7 @@ CORD compile(env_t *env, ast_t *ast) int64_t n = 0; for (ast_list_t *item = array->items; item; item = item->next) { ++n; - if (item->ast->tag == For) + if (item->ast->tag == Comprehension) goto array_comprehension; } @@ -1171,13 +1171,14 @@ CORD compile(env_t *env, ast_t *ast) env_t *scope = fresh_scope(env); set_binding(scope, "$arr", new(binding_t, .type=array_type, .code="$arr")); for (ast_list_t *item = array->items; item; item = item->next) { - if (item->ast->tag == For) { - auto for_ = Match(item->ast, For); - env_t *body_scope = for_scope(scope, item->ast); - ast_t *for2 = WrapAST(item->ast, For, .index=for_->index, .value=for_->value, .iter=for_->iter, - .body=WrapAST(for_->body, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")), - .args=new(arg_ast_t, .value=for_->body))); - code = CORD_all(code, "\n", compile_statement(body_scope, for2)); + if (item->ast->tag == Comprehension) { + auto comp = Match(item->ast, Comprehension); + ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")), + .args=new(arg_ast_t, .value=comp->expr)); + if (comp->filter) + body = WrapAST(body, If, .condition=comp->filter, .body=body); + ast_t *loop = WrapAST(item->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + code = CORD_all(code, "\n", compile_statement(scope, loop)); } else { CORD insert = compile_statement( scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")), @@ -1199,37 +1200,85 @@ CORD compile(env_t *env, ast_t *ast) code = CORD_all(code, ".default_value=$heap(", compile(env, table->default_value),"),"); return CORD_cat(code, "}"); } - + type_t *table_type = get_type(env, ast); type_t *key_t = Match(table_type, TableType)->key_type; type_t *value_t = Match(table_type, TableType)->value_type; - CORD code = CORD_all("$Table(", - compile_type(key_t), ", ", - compile_type(value_t), ", ", - compile_type_info(env, key_t), ", ", - compile_type_info(env, value_t)); - if (table->fallback) - code = CORD_all(code, ", /*fallback:*/ $heap(", compile(env, table->fallback), ")"); - else - code = CORD_all(code, ", /*fallback:*/ NULL"); - if (table->default_value) - code = CORD_all(code, ", /*default:*/ $heap(", compile(env, table->default_value), ")"); - else - code = CORD_all(code, ", /*default:*/ NULL"); + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + if (entry->ast->tag == Comprehension) + goto table_comprehension; + } + + { // No comprehension: + CORD code = CORD_all("$Table(", + compile_type(key_t), ", ", + compile_type(value_t), ", ", + compile_type_info(env, key_t), ", ", + compile_type_info(env, value_t)); + if (table->fallback) + code = CORD_all(code, ", /*fallback:*/ $heap(", compile(env, table->fallback), ")"); + else + code = CORD_all(code, ", /*fallback:*/ NULL"); - size_t n = 0; - for (ast_list_t *entry = table->entries; entry; entry = entry->next) - ++n; - CORD_appendf(&code, ", %zu", n); + if (table->default_value) + code = CORD_all(code, ", /*default:*/ $heap(", compile(env, table->default_value), ")"); + else + code = CORD_all(code, ", /*default:*/ NULL"); - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - auto e = Match(entry->ast, TableEntry); - code = CORD_all(code, ",\n\t{", compile(env, e->key), ", ", compile(env, e->value), "}"); + size_t n = 0; + for (ast_list_t *entry = table->entries; entry; entry = entry->next) + ++n; + CORD_appendf(&code, ", %zu", n); + + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + auto e = Match(entry->ast, TableEntry); + code = CORD_all(code, ",\n\t{", compile(env, e->key), ", ", compile(env, e->value), "}"); + } + return CORD_cat(code, ")"); + } + + table_comprehension: + { + CORD code = "({ table_t $t = {"; + if (table->fallback) + code = CORD_all(code, ".fallback=$heap(", compile(env, table->fallback), "), "); + + if (table->default_value) + code = CORD_all(code, ".default_value=$heap(", compile(env, table->default_value), "), "); + code = CORD_cat(code, "};"); + + env_t *scope = fresh_scope(env); + set_binding(scope, "$t", new(binding_t, .type=table_type, .code="$t")); + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + if (entry->ast->tag == Comprehension) { + auto comp = Match(entry->ast, Comprehension); + auto e = Match(comp->expr, TableEntry); + ast_t *body = WrapAST(comp->expr, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$t")), + .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))); + if (comp->filter) + body = WrapAST(body, If, .condition=comp->filter, .body=body); + ast_t *loop = WrapAST(entry->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body); + code = CORD_all(code, "\n", compile_statement(scope, loop)); + } else { + auto e = Match(entry->ast, TableEntry); + CORD set = compile_statement( + scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$arr")), + .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)))); + code = CORD_all(code, "\n", set); + } + } + code = CORD_cat(code, " $t; })"); + return code; } - return CORD_cat(code, ")"); } + case Comprehension: { + auto comp = Match(ast, Comprehension); + ast_t *collection = comp->expr->tag == TableEntry ? WrapAST(ast, Table, .entries=new(ast_list_t, .ast=ast)) + : WrapAST(ast, Array, .items=new(ast_list_t, .ast=ast)); + return compile(env, collection); + } case Lambda: { auto lambda = Match(ast, Lambda); static int64_t lambda_number = 1; @@ -1344,9 +1393,9 @@ CORD compile(env_t *env, ast_t *ast) compile_type_info(env, self_value_t), ")"); } else if (streq(call->name, "set")) { CORD self = compile_to_pointer_depth(env, call->self, 1, false); - arg_t *arg_spec = new(arg_t, .name="key", .type=Type(PointerType, .pointed=table->key_type, .is_stack=true, .is_readonly=true), - .next=new(arg_t, .name="value", .type=Type(PointerType, .pointed=table->value_type, .is_stack=true, .is_readonly=true))); - return CORD_all("Table_set(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", + arg_t *arg_spec = new(arg_t, .name="key", .type=table->key_type, + .next=new(arg_t, .name="value", .type=table->value_type)); + return CORD_all("Table_set_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")"); } else if (streq(call->name, "remove")) { CORD self = compile_to_pointer_depth(env, call->self, 1, false); @@ -68,7 +68,7 @@ static ast_t *parse_fncall_suffix(parse_ctx_t *ctx, ast_t *fn, bool is_extern); static ast_t *parse_method_call_suffix(parse_ctx_t *ctx, ast_t *self); static ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs); static ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs); -static ast_t *parse_for_suffix(parse_ctx_t *ctx, ast_t *lhs); +static ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *lhs); static arg_ast_t *parse_args(parse_ctx_t *ctx, const char **pos, bool allow_unnamed); static PARSER(parse_for); static PARSER(parse_while); @@ -617,7 +617,7 @@ PARSER(parse_array) { for (;;) { ast_t *item = optional(ctx, &pos, parse_extended_expr); if (!item) break; - ast_t *suffixed = parse_for_suffix(ctx, item); + ast_t *suffixed = parse_comprehension_suffix(ctx, item); if (suffixed) { item = suffixed; pos = suffixed->end; @@ -661,20 +661,12 @@ PARSER(parse_table) { whitespace(&pos); if (!match(&pos, "=>")) return NULL; ast_t *value = expect(ctx, pos-1, &pos, parse_expr, "I couldn't parse the value for this table entry"); - ast_t *entry = NewAST(ctx->file, entry_start, pos, TableEntry, .key=key, .value=value); - for (bool progress = true; progress; ) { - ast_t *new_entry; - progress = (false - || (new_entry=parse_index_suffix(ctx, entry)) - || (new_entry=parse_field_suffix(ctx, entry)) - || (new_entry=parse_method_call_suffix(ctx, entry)) - || (new_entry=parse_fncall_suffix(ctx, entry, NORMAL_FUNCTION)) - ); - if (progress) entry = new_entry; + ast_t *suffixed = parse_comprehension_suffix(ctx, entry); + if (suffixed) { + entry = suffixed; + pos = suffixed->end; } - pos = entry->end; - entries = new(ast_list_t, .ast=entry, .next=entries); if (!match_separator(&pos)) break; @@ -791,34 +783,30 @@ ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs) { return NewAST(ctx->file, start, pos, Index, .indexed=lhs, .index=index, .unchecked=unchecked); } -ast_t *parse_for_suffix(parse_ctx_t *ctx, ast_t *lhs) { +ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *expr) { // <expr> for [<index>,]<var> in <iter> [if <condition>] - if (!lhs) return NULL; - const char *start = lhs->start; - const char *pos = lhs->end; + if (!expr) return NULL; + const char *start = expr->start; + const char *pos = expr->end; whitespace(&pos); if (!match_word(&pos, "for")) return NULL; - ast_t *index = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'"); + ast_t *key = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'"); whitespace(&pos); ast_t *value = NULL; if (match(&pos, ",")) { value = expect(ctx, pos-1, &pos, parse_var, "I expected a variable after this comma"); } else { - value = index; - index = NULL; + value = key; + key = NULL; } expect_str(ctx, start, &pos, "in", "I expected an 'in' for this 'for'"); ast_t *iter = expect(ctx, start, &pos, parse_expr, "I expected an iterable value for this 'for'"); whitespace(&pos); - ast_t *body = lhs; - if (match_word(&pos, "if")) { - ast_t *condition = expect(ctx, pos-2, &pos, parse_expr, "I expected a condition for this 'if'"); - body = NewAST(ctx->file, body->start, condition->end, Block, - .statements=new(ast_list_t, .ast=WrapAST(condition, If, .condition=FakeAST(Not, condition), .body=FakeAST(Skip)), - .next=new(ast_list_t, .ast=body))); - } - return NewAST(ctx->file, start, pos, For, .index=index, .value=value, .iter=iter, .body=body); + ast_t *filter = NULL; + if (match_word(&pos, "if")) + filter = expect(ctx, pos-2, &pos, parse_expr, "I expected a condition for this 'if'"); + return NewAST(ctx->file, start, pos, Comprehension, .expr=expr, .key=key, .value=value, .iter=iter, .filter=filter); } PARSER(parse_if) { diff --git a/test/tables.tm b/test/tables.tm index 53b720c9..27b176bf 100644 --- a/test/tables.tm +++ b/test/tables.tm @@ -49,3 +49,8 @@ for k,v in t2 t2_str ++= "({k}=>{v})" >> t2_str = "(three=>3)" + +>> {i=>10*i for i in 5} += {1=>10, 2=>20, 3=>30, 4=>40, 5=>50} +>> {i=>10*i for i in 5 if i mod 2 != 0} += {1=>10, 3=>30, 5=>50} diff --git a/typecheck.c b/typecheck.c index 33e8bfee..228029b8 100644 --- a/typecheck.c +++ b/typecheck.c @@ -308,9 +308,11 @@ type_t *get_type(env_t *env, ast_t *ast) } else if (array->items) { for (ast_list_t *item = array->items; item; item = item->next) { type_t *t2; - if (item->ast->tag == For) { - env_t *scope = for_scope(env, item->ast); - t2 = get_type(scope, Match(item->ast, For)->body); + if (item->ast->tag == Comprehension) { + auto comp = Match(item->ast, Comprehension); + + env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); + t2 = get_type(scope, comp->expr); } else { t2 = get_type(env, item->ast); } @@ -338,19 +340,30 @@ type_t *get_type(env_t *env, ast_t *ast) if (table->default_value) value_type = get_type(env, table->default_value); for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - auto table_entry = Match(entry->ast, TableEntry); - type_t *key_t = get_type(env, table_entry->key); + type_t *key_t, *value_t; + if (entry->ast->tag == Comprehension) { + auto comp = Match(entry->ast, Comprehension); + env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value)); + if (comp->expr->tag != TableEntry) + code_err(comp->expr, "I expected this table comprehension to have a key/value entry"); + key_t = get_type(scope, Match(comp->expr, TableEntry)->key); + value_t = get_type(scope, Match(comp->expr, TableEntry)->value); + } else { + auto e = Match(entry->ast, TableEntry); + key_t = get_type(env, e->key); + value_t = get_type(env, e->value); + } + type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; if (!key_merged) - code_err(table_entry->key, + code_err(entry->ast, "This table entry has type %T, which is different from earlier table entries which have type %T", key_t, key_type); key_type = key_merged; - type_t *value_t = get_type(env, table_entry->value); type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t; if (!val_merged) - code_err(table_entry->value, + code_err(entry->ast, "This table entry has type %T, which is different from earlier table entries which have type %T", value_t, value_type); value_type = val_merged; @@ -361,7 +374,10 @@ type_t *get_type(env_t *env, ast_t *ast) return Type(TableType, .key_type=key_type, .value_type=value_type); } case TableEntry: { - code_err(ast, "This should not be typechecked directly"); + code_err(ast, "Table entries should not be typechecked directly"); + } + case Comprehension: { + code_err(ast, "Comprehensions should not be typechecked directly"); } case FieldAccess: { auto access = Match(ast, FieldAccess); |
