aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-03-17 14:46:36 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-03-17 14:46:36 -0400
commit993284153011006a1164b4b1f6bb1522e5131cb0 (patch)
treeb34793189363b6b314d2e69f2c7b1ea6249576c2
parent5c2bb00bafa4ad6e004f171687b9f21a824695c6 (diff)
Improve comprehensions for both arrays and tables
-rw-r--r--ast.c2
-rw-r--r--ast.h5
-rw-r--r--builtins/table.h2
-rw-r--r--compile.c115
-rw-r--r--parse.c46
-rw-r--r--test/tables.tm5
-rw-r--r--typecheck.c34
7 files changed, 137 insertions, 72 deletions
diff --git a/ast.c b/ast.c
index 9a8bc3c4..59ba8db2 100644
--- a/ast.c
+++ b/ast.c
@@ -116,6 +116,8 @@ CORD ast_to_cord(ast_t *ast)
ast_to_cord(data.fallback), ast_to_cord(data.default_value),
ast_list_to_cord(data.entries))
T(TableEntry, "(%r => %r)", ast_to_cord(data.key), ast_to_cord(data.value))
+ T(Comprehension, "(expr=%r, key=%r, value=%r, iter=%r, filter=%r)", ast_to_cord(data.expr),
+ ast_to_cord(data.key), ast_to_cord(data.value), ast_to_cord(data.iter), ast_to_cord(data.filter))
T(FunctionDef, "(name=%r, args=%r, ret=%r, body=%r)", ast_to_cord(data.name),
arg_list_to_cord(data.args), type_ast_to_cord(data.ret_type), ast_to_cord(data.body))
T(Lambda, "(args=%r, body=%r)", arg_list_to_cord(data.args), ast_to_cord(data.body))
diff --git a/ast.h b/ast.h
index 3ee7f61d..4807c7b6 100644
--- a/ast.h
+++ b/ast.h
@@ -96,7 +96,7 @@ typedef enum {
BinaryOp, UpdateAssign,
Length, Not, Negative, HeapAllocate, StackReference,
Min, Max,
- Array, Table, TableEntry,
+ Array, Table, TableEntry, Comprehension,
FunctionDef, Lambda,
FunctionCall, MethodCall,
Block,
@@ -174,6 +174,9 @@ struct ast_s {
ast_t *key, *value;
} TableEntry;
struct {
+ ast_t *expr, *key, *value, *iter, *filter;
+ } Comprehension;
+ struct {
ast_t *name;
arg_ast_t *args;
type_ast_t *ret_type;
diff --git a/builtins/table.h b/builtins/table.h
index 49f2517e..6577782d 100644
--- a/builtins/table.h
+++ b/builtins/table.h
@@ -43,6 +43,8 @@ void *Table_get_raw(table_t t, const void *key, const TypeInfo *type);
void *Table_entry(table_t t, int64_t n);
void *Table_reserve(table_t *t, const void *key, const void *value, const TypeInfo *type);
void Table_set(table_t *t, const void *key, const void *value, const TypeInfo *type);
+#define Table_set_value(t, key_expr, value_expr, type) ({ __typeof(key_expr) $k = key_expr; __typeof(value_expr) $v = value_expr; \
+ Table_set(t, &$k, &$v, type); })
void Table_remove(table_t *t, const void *key, const TypeInfo *type);
void Table_clear(table_t *t);
void Table_mark_copy_on_write(table_t *t);
diff --git a/compile.c b/compile.c
index 8aad6ae2..1032f61e 100644
--- a/compile.c
+++ b/compile.c
@@ -1153,7 +1153,7 @@ CORD compile(env_t *env, ast_t *ast)
int64_t n = 0;
for (ast_list_t *item = array->items; item; item = item->next) {
++n;
- if (item->ast->tag == For)
+ if (item->ast->tag == Comprehension)
goto array_comprehension;
}
@@ -1171,13 +1171,14 @@ CORD compile(env_t *env, ast_t *ast)
env_t *scope = fresh_scope(env);
set_binding(scope, "$arr", new(binding_t, .type=array_type, .code="$arr"));
for (ast_list_t *item = array->items; item; item = item->next) {
- if (item->ast->tag == For) {
- auto for_ = Match(item->ast, For);
- env_t *body_scope = for_scope(scope, item->ast);
- ast_t *for2 = WrapAST(item->ast, For, .index=for_->index, .value=for_->value, .iter=for_->iter,
- .body=WrapAST(for_->body, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
- .args=new(arg_ast_t, .value=for_->body)));
- code = CORD_all(code, "\n", compile_statement(body_scope, for2));
+ if (item->ast->tag == Comprehension) {
+ auto comp = Match(item->ast, Comprehension);
+ ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
+ .args=new(arg_ast_t, .value=comp->expr));
+ if (comp->filter)
+ body = WrapAST(body, If, .condition=comp->filter, .body=body);
+ ast_t *loop = WrapAST(item->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ code = CORD_all(code, "\n", compile_statement(scope, loop));
} else {
CORD insert = compile_statement(
scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
@@ -1199,37 +1200,85 @@ CORD compile(env_t *env, ast_t *ast)
code = CORD_all(code, ".default_value=$heap(", compile(env, table->default_value),"),");
return CORD_cat(code, "}");
}
-
+
type_t *table_type = get_type(env, ast);
type_t *key_t = Match(table_type, TableType)->key_type;
type_t *value_t = Match(table_type, TableType)->value_type;
- CORD code = CORD_all("$Table(",
- compile_type(key_t), ", ",
- compile_type(value_t), ", ",
- compile_type_info(env, key_t), ", ",
- compile_type_info(env, value_t));
- if (table->fallback)
- code = CORD_all(code, ", /*fallback:*/ $heap(", compile(env, table->fallback), ")");
- else
- code = CORD_all(code, ", /*fallback:*/ NULL");
- if (table->default_value)
- code = CORD_all(code, ", /*default:*/ $heap(", compile(env, table->default_value), ")");
- else
- code = CORD_all(code, ", /*default:*/ NULL");
+ for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
+ if (entry->ast->tag == Comprehension)
+ goto table_comprehension;
+ }
+
+ { // No comprehension:
+ CORD code = CORD_all("$Table(",
+ compile_type(key_t), ", ",
+ compile_type(value_t), ", ",
+ compile_type_info(env, key_t), ", ",
+ compile_type_info(env, value_t));
+ if (table->fallback)
+ code = CORD_all(code, ", /*fallback:*/ $heap(", compile(env, table->fallback), ")");
+ else
+ code = CORD_all(code, ", /*fallback:*/ NULL");
- size_t n = 0;
- for (ast_list_t *entry = table->entries; entry; entry = entry->next)
- ++n;
- CORD_appendf(&code, ", %zu", n);
+ if (table->default_value)
+ code = CORD_all(code, ", /*default:*/ $heap(", compile(env, table->default_value), ")");
+ else
+ code = CORD_all(code, ", /*default:*/ NULL");
- for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
- auto e = Match(entry->ast, TableEntry);
- code = CORD_all(code, ",\n\t{", compile(env, e->key), ", ", compile(env, e->value), "}");
+ size_t n = 0;
+ for (ast_list_t *entry = table->entries; entry; entry = entry->next)
+ ++n;
+ CORD_appendf(&code, ", %zu", n);
+
+ for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
+ auto e = Match(entry->ast, TableEntry);
+ code = CORD_all(code, ",\n\t{", compile(env, e->key), ", ", compile(env, e->value), "}");
+ }
+ return CORD_cat(code, ")");
+ }
+
+ table_comprehension:
+ {
+ CORD code = "({ table_t $t = {";
+ if (table->fallback)
+ code = CORD_all(code, ".fallback=$heap(", compile(env, table->fallback), "), ");
+
+ if (table->default_value)
+ code = CORD_all(code, ".default_value=$heap(", compile(env, table->default_value), "), ");
+ code = CORD_cat(code, "};");
+
+ env_t *scope = fresh_scope(env);
+ set_binding(scope, "$t", new(binding_t, .type=table_type, .code="$t"));
+ for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
+ if (entry->ast->tag == Comprehension) {
+ auto comp = Match(entry->ast, Comprehension);
+ auto e = Match(comp->expr, TableEntry);
+ ast_t *body = WrapAST(comp->expr, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$t")),
+ .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)));
+ if (comp->filter)
+ body = WrapAST(body, If, .condition=comp->filter, .body=body);
+ ast_t *loop = WrapAST(entry->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ code = CORD_all(code, "\n", compile_statement(scope, loop));
+ } else {
+ auto e = Match(entry->ast, TableEntry);
+ CORD set = compile_statement(
+ scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
+ .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))));
+ code = CORD_all(code, "\n", set);
+ }
+ }
+ code = CORD_cat(code, " $t; })");
+ return code;
}
- return CORD_cat(code, ")");
}
+ case Comprehension: {
+ auto comp = Match(ast, Comprehension);
+ ast_t *collection = comp->expr->tag == TableEntry ? WrapAST(ast, Table, .entries=new(ast_list_t, .ast=ast))
+ : WrapAST(ast, Array, .items=new(ast_list_t, .ast=ast));
+ return compile(env, collection);
+ }
case Lambda: {
auto lambda = Match(ast, Lambda);
static int64_t lambda_number = 1;
@@ -1344,9 +1393,9 @@ CORD compile(env_t *env, ast_t *ast)
compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "set")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
- arg_t *arg_spec = new(arg_t, .name="key", .type=Type(PointerType, .pointed=table->key_type, .is_stack=true, .is_readonly=true),
- .next=new(arg_t, .name="value", .type=Type(PointerType, .pointed=table->value_type, .is_stack=true, .is_readonly=true)));
- return CORD_all("Table_set(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
+ arg_t *arg_spec = new(arg_t, .name="key", .type=table->key_type,
+ .next=new(arg_t, .name="value", .type=table->value_type));
+ return CORD_all("Table_set_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "remove")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
diff --git a/parse.c b/parse.c
index 676e09f2..dd993ed9 100644
--- a/parse.c
+++ b/parse.c
@@ -68,7 +68,7 @@ static ast_t *parse_fncall_suffix(parse_ctx_t *ctx, ast_t *fn, bool is_extern);
static ast_t *parse_method_call_suffix(parse_ctx_t *ctx, ast_t *self);
static ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs);
static ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs);
-static ast_t *parse_for_suffix(parse_ctx_t *ctx, ast_t *lhs);
+static ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *lhs);
static arg_ast_t *parse_args(parse_ctx_t *ctx, const char **pos, bool allow_unnamed);
static PARSER(parse_for);
static PARSER(parse_while);
@@ -617,7 +617,7 @@ PARSER(parse_array) {
for (;;) {
ast_t *item = optional(ctx, &pos, parse_extended_expr);
if (!item) break;
- ast_t *suffixed = parse_for_suffix(ctx, item);
+ ast_t *suffixed = parse_comprehension_suffix(ctx, item);
if (suffixed) {
item = suffixed;
pos = suffixed->end;
@@ -661,20 +661,12 @@ PARSER(parse_table) {
whitespace(&pos);
if (!match(&pos, "=>")) return NULL;
ast_t *value = expect(ctx, pos-1, &pos, parse_expr, "I couldn't parse the value for this table entry");
-
ast_t *entry = NewAST(ctx->file, entry_start, pos, TableEntry, .key=key, .value=value);
- for (bool progress = true; progress; ) {
- ast_t *new_entry;
- progress = (false
- || (new_entry=parse_index_suffix(ctx, entry))
- || (new_entry=parse_field_suffix(ctx, entry))
- || (new_entry=parse_method_call_suffix(ctx, entry))
- || (new_entry=parse_fncall_suffix(ctx, entry, NORMAL_FUNCTION))
- );
- if (progress) entry = new_entry;
+ ast_t *suffixed = parse_comprehension_suffix(ctx, entry);
+ if (suffixed) {
+ entry = suffixed;
+ pos = suffixed->end;
}
- pos = entry->end;
-
entries = new(ast_list_t, .ast=entry, .next=entries);
if (!match_separator(&pos))
break;
@@ -791,34 +783,30 @@ ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs) {
return NewAST(ctx->file, start, pos, Index, .indexed=lhs, .index=index, .unchecked=unchecked);
}
-ast_t *parse_for_suffix(parse_ctx_t *ctx, ast_t *lhs) {
+ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *expr) {
// <expr> for [<index>,]<var> in <iter> [if <condition>]
- if (!lhs) return NULL;
- const char *start = lhs->start;
- const char *pos = lhs->end;
+ if (!expr) return NULL;
+ const char *start = expr->start;
+ const char *pos = expr->end;
whitespace(&pos);
if (!match_word(&pos, "for")) return NULL;
- ast_t *index = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'");
+ ast_t *key = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'");
whitespace(&pos);
ast_t *value = NULL;
if (match(&pos, ",")) {
value = expect(ctx, pos-1, &pos, parse_var, "I expected a variable after this comma");
} else {
- value = index;
- index = NULL;
+ value = key;
+ key = NULL;
}
expect_str(ctx, start, &pos, "in", "I expected an 'in' for this 'for'");
ast_t *iter = expect(ctx, start, &pos, parse_expr, "I expected an iterable value for this 'for'");
whitespace(&pos);
- ast_t *body = lhs;
- if (match_word(&pos, "if")) {
- ast_t *condition = expect(ctx, pos-2, &pos, parse_expr, "I expected a condition for this 'if'");
- body = NewAST(ctx->file, body->start, condition->end, Block,
- .statements=new(ast_list_t, .ast=WrapAST(condition, If, .condition=FakeAST(Not, condition), .body=FakeAST(Skip)),
- .next=new(ast_list_t, .ast=body)));
- }
- return NewAST(ctx->file, start, pos, For, .index=index, .value=value, .iter=iter, .body=body);
+ ast_t *filter = NULL;
+ if (match_word(&pos, "if"))
+ filter = expect(ctx, pos-2, &pos, parse_expr, "I expected a condition for this 'if'");
+ return NewAST(ctx->file, start, pos, Comprehension, .expr=expr, .key=key, .value=value, .iter=iter, .filter=filter);
}
PARSER(parse_if) {
diff --git a/test/tables.tm b/test/tables.tm
index 53b720c9..27b176bf 100644
--- a/test/tables.tm
+++ b/test/tables.tm
@@ -49,3 +49,8 @@ for k,v in t2
t2_str ++= "({k}=>{v})"
>> t2_str
= "(three=>3)"
+
+>> {i=>10*i for i in 5}
+= {1=>10, 2=>20, 3=>30, 4=>40, 5=>50}
+>> {i=>10*i for i in 5 if i mod 2 != 0}
+= {1=>10, 3=>30, 5=>50}
diff --git a/typecheck.c b/typecheck.c
index 33e8bfee..228029b8 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -308,9 +308,11 @@ type_t *get_type(env_t *env, ast_t *ast)
} else if (array->items) {
for (ast_list_t *item = array->items; item; item = item->next) {
type_t *t2;
- if (item->ast->tag == For) {
- env_t *scope = for_scope(env, item->ast);
- t2 = get_type(scope, Match(item->ast, For)->body);
+ if (item->ast->tag == Comprehension) {
+ auto comp = Match(item->ast, Comprehension);
+
+ env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
+ t2 = get_type(scope, comp->expr);
} else {
t2 = get_type(env, item->ast);
}
@@ -338,19 +340,30 @@ type_t *get_type(env_t *env, ast_t *ast)
if (table->default_value)
value_type = get_type(env, table->default_value);
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
- auto table_entry = Match(entry->ast, TableEntry);
- type_t *key_t = get_type(env, table_entry->key);
+ type_t *key_t, *value_t;
+ if (entry->ast->tag == Comprehension) {
+ auto comp = Match(entry->ast, Comprehension);
+ env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
+ if (comp->expr->tag != TableEntry)
+ code_err(comp->expr, "I expected this table comprehension to have a key/value entry");
+ key_t = get_type(scope, Match(comp->expr, TableEntry)->key);
+ value_t = get_type(scope, Match(comp->expr, TableEntry)->value);
+ } else {
+ auto e = Match(entry->ast, TableEntry);
+ key_t = get_type(env, e->key);
+ value_t = get_type(env, e->value);
+ }
+
type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t;
if (!key_merged)
- code_err(table_entry->key,
+ code_err(entry->ast,
"This table entry has type %T, which is different from earlier table entries which have type %T",
key_t, key_type);
key_type = key_merged;
- type_t *value_t = get_type(env, table_entry->value);
type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t;
if (!val_merged)
- code_err(table_entry->value,
+ code_err(entry->ast,
"This table entry has type %T, which is different from earlier table entries which have type %T",
value_t, value_type);
value_type = val_merged;
@@ -361,7 +374,10 @@ type_t *get_type(env_t *env, ast_t *ast)
return Type(TableType, .key_type=key_type, .value_type=value_type);
}
case TableEntry: {
- code_err(ast, "This should not be typechecked directly");
+ code_err(ast, "Table entries should not be typechecked directly");
+ }
+ case Comprehension: {
+ code_err(ast, "Comprehensions should not be typechecked directly");
}
case FieldAccess: {
auto access = Match(ast, FieldAccess);