Improve comprehensions for both arrays and tables

This commit is contained in:
Bruce Hill 2024-03-17 14:46:36 -04:00
parent 5c2bb00baf
commit 9932841530
7 changed files with 138 additions and 73 deletions

2
ast.c
View File

@ -116,6 +116,8 @@ CORD ast_to_cord(ast_t *ast)
ast_to_cord(data.fallback), ast_to_cord(data.default_value),
ast_list_to_cord(data.entries))
T(TableEntry, "(%r => %r)", ast_to_cord(data.key), ast_to_cord(data.value))
T(Comprehension, "(expr=%r, key=%r, value=%r, iter=%r, filter=%r)", ast_to_cord(data.expr),
ast_to_cord(data.key), ast_to_cord(data.value), ast_to_cord(data.iter), ast_to_cord(data.filter))
T(FunctionDef, "(name=%r, args=%r, ret=%r, body=%r)", ast_to_cord(data.name),
arg_list_to_cord(data.args), type_ast_to_cord(data.ret_type), ast_to_cord(data.body))
T(Lambda, "(args=%r, body=%r)", arg_list_to_cord(data.args), ast_to_cord(data.body))

5
ast.h
View File

@ -96,7 +96,7 @@ typedef enum {
BinaryOp, UpdateAssign,
Length, Not, Negative, HeapAllocate, StackReference,
Min, Max,
Array, Table, TableEntry,
Array, Table, TableEntry, Comprehension,
FunctionDef, Lambda,
FunctionCall, MethodCall,
Block,
@ -173,6 +173,9 @@ struct ast_s {
struct {
ast_t *key, *value;
} TableEntry;
struct {
ast_t *expr, *key, *value, *iter, *filter;
} Comprehension;
struct {
ast_t *name;
arg_ast_t *args;

View File

@ -43,6 +43,8 @@ void *Table_get_raw(table_t t, const void *key, const TypeInfo *type);
void *Table_entry(table_t t, int64_t n);
void *Table_reserve(table_t *t, const void *key, const void *value, const TypeInfo *type);
void Table_set(table_t *t, const void *key, const void *value, const TypeInfo *type);
#define Table_set_value(t, key_expr, value_expr, type) ({ __typeof(key_expr) $k = key_expr; __typeof(value_expr) $v = value_expr; \
Table_set(t, &$k, &$v, type); })
void Table_remove(table_t *t, const void *key, const TypeInfo *type);
void Table_clear(table_t *t);
void Table_mark_copy_on_write(table_t *t);

117
compile.c
View File

@ -1153,7 +1153,7 @@ CORD compile(env_t *env, ast_t *ast)
int64_t n = 0;
for (ast_list_t *item = array->items; item; item = item->next) {
++n;
if (item->ast->tag == For)
if (item->ast->tag == Comprehension)
goto array_comprehension;
}
@ -1171,13 +1171,14 @@ CORD compile(env_t *env, ast_t *ast)
env_t *scope = fresh_scope(env);
set_binding(scope, "$arr", new(binding_t, .type=array_type, .code="$arr"));
for (ast_list_t *item = array->items; item; item = item->next) {
if (item->ast->tag == For) {
auto for_ = Match(item->ast, For);
env_t *body_scope = for_scope(scope, item->ast);
ast_t *for2 = WrapAST(item->ast, For, .index=for_->index, .value=for_->value, .iter=for_->iter,
.body=WrapAST(for_->body, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
.args=new(arg_ast_t, .value=for_->body)));
code = CORD_all(code, "\n", compile_statement(body_scope, for2));
if (item->ast->tag == Comprehension) {
auto comp = Match(item->ast, Comprehension);
ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
.args=new(arg_ast_t, .value=comp->expr));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
ast_t *loop = WrapAST(item->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
code = CORD_all(code, "\n", compile_statement(scope, loop));
} else {
CORD insert = compile_statement(
scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
@ -1199,36 +1200,84 @@ CORD compile(env_t *env, ast_t *ast)
code = CORD_all(code, ".default_value=$heap(", compile(env, table->default_value),"),");
return CORD_cat(code, "}");
}
type_t *table_type = get_type(env, ast);
type_t *key_t = Match(table_type, TableType)->key_type;
type_t *value_t = Match(table_type, TableType)->value_type;
CORD code = CORD_all("$Table(",
compile_type(key_t), ", ",
compile_type(value_t), ", ",
compile_type_info(env, key_t), ", ",
compile_type_info(env, value_t));
if (table->fallback)
code = CORD_all(code, ", /*fallback:*/ $heap(", compile(env, table->fallback), ")");
else
code = CORD_all(code, ", /*fallback:*/ NULL");
if (table->default_value)
code = CORD_all(code, ", /*default:*/ $heap(", compile(env, table->default_value), ")");
else
code = CORD_all(code, ", /*default:*/ NULL");
size_t n = 0;
for (ast_list_t *entry = table->entries; entry; entry = entry->next)
++n;
CORD_appendf(&code, ", %zu", n);
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
auto e = Match(entry->ast, TableEntry);
code = CORD_all(code, ",\n\t{", compile(env, e->key), ", ", compile(env, e->value), "}");
if (entry->ast->tag == Comprehension)
goto table_comprehension;
}
return CORD_cat(code, ")");
{ // No comprehension:
CORD code = CORD_all("$Table(",
compile_type(key_t), ", ",
compile_type(value_t), ", ",
compile_type_info(env, key_t), ", ",
compile_type_info(env, value_t));
if (table->fallback)
code = CORD_all(code, ", /*fallback:*/ $heap(", compile(env, table->fallback), ")");
else
code = CORD_all(code, ", /*fallback:*/ NULL");
if (table->default_value)
code = CORD_all(code, ", /*default:*/ $heap(", compile(env, table->default_value), ")");
else
code = CORD_all(code, ", /*default:*/ NULL");
size_t n = 0;
for (ast_list_t *entry = table->entries; entry; entry = entry->next)
++n;
CORD_appendf(&code, ", %zu", n);
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
auto e = Match(entry->ast, TableEntry);
code = CORD_all(code, ",\n\t{", compile(env, e->key), ", ", compile(env, e->value), "}");
}
return CORD_cat(code, ")");
}
table_comprehension:
{
CORD code = "({ table_t $t = {";
if (table->fallback)
code = CORD_all(code, ".fallback=$heap(", compile(env, table->fallback), "), ");
if (table->default_value)
code = CORD_all(code, ".default_value=$heap(", compile(env, table->default_value), "), ");
code = CORD_cat(code, "};");
env_t *scope = fresh_scope(env);
set_binding(scope, "$t", new(binding_t, .type=table_type, .code="$t"));
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
if (entry->ast->tag == Comprehension) {
auto comp = Match(entry->ast, Comprehension);
auto e = Match(comp->expr, TableEntry);
ast_t *body = WrapAST(comp->expr, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$t")),
.args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
ast_t *loop = WrapAST(entry->ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
code = CORD_all(code, "\n", compile_statement(scope, loop));
} else {
auto e = Match(entry->ast, TableEntry);
CORD set = compile_statement(
scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
.args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))));
code = CORD_all(code, "\n", set);
}
}
code = CORD_cat(code, " $t; })");
return code;
}
}
case Comprehension: {
auto comp = Match(ast, Comprehension);
ast_t *collection = comp->expr->tag == TableEntry ? WrapAST(ast, Table, .entries=new(ast_list_t, .ast=ast))
: WrapAST(ast, Array, .items=new(ast_list_t, .ast=ast));
return compile(env, collection);
}
case Lambda: {
auto lambda = Match(ast, Lambda);
@ -1344,9 +1393,9 @@ CORD compile(env_t *env, ast_t *ast)
compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "set")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
arg_t *arg_spec = new(arg_t, .name="key", .type=Type(PointerType, .pointed=table->key_type, .is_stack=true, .is_readonly=true),
.next=new(arg_t, .name="value", .type=Type(PointerType, .pointed=table->value_type, .is_stack=true, .is_readonly=true)));
return CORD_all("Table_set(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
arg_t *arg_spec = new(arg_t, .name="key", .type=table->key_type,
.next=new(arg_t, .name="value", .type=table->value_type));
return CORD_all("Table_set_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "remove")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);

46
parse.c
View File

@ -68,7 +68,7 @@ static ast_t *parse_fncall_suffix(parse_ctx_t *ctx, ast_t *fn, bool is_extern);
static ast_t *parse_method_call_suffix(parse_ctx_t *ctx, ast_t *self);
static ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs);
static ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs);
static ast_t *parse_for_suffix(parse_ctx_t *ctx, ast_t *lhs);
static ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *lhs);
static arg_ast_t *parse_args(parse_ctx_t *ctx, const char **pos, bool allow_unnamed);
static PARSER(parse_for);
static PARSER(parse_while);
@ -617,7 +617,7 @@ PARSER(parse_array) {
for (;;) {
ast_t *item = optional(ctx, &pos, parse_extended_expr);
if (!item) break;
ast_t *suffixed = parse_for_suffix(ctx, item);
ast_t *suffixed = parse_comprehension_suffix(ctx, item);
if (suffixed) {
item = suffixed;
pos = suffixed->end;
@ -661,20 +661,12 @@ PARSER(parse_table) {
whitespace(&pos);
if (!match(&pos, "=>")) return NULL;
ast_t *value = expect(ctx, pos-1, &pos, parse_expr, "I couldn't parse the value for this table entry");
ast_t *entry = NewAST(ctx->file, entry_start, pos, TableEntry, .key=key, .value=value);
for (bool progress = true; progress; ) {
ast_t *new_entry;
progress = (false
|| (new_entry=parse_index_suffix(ctx, entry))
|| (new_entry=parse_field_suffix(ctx, entry))
|| (new_entry=parse_method_call_suffix(ctx, entry))
|| (new_entry=parse_fncall_suffix(ctx, entry, NORMAL_FUNCTION))
);
if (progress) entry = new_entry;
ast_t *suffixed = parse_comprehension_suffix(ctx, entry);
if (suffixed) {
entry = suffixed;
pos = suffixed->end;
}
pos = entry->end;
entries = new(ast_list_t, .ast=entry, .next=entries);
if (!match_separator(&pos))
break;
@ -791,34 +783,30 @@ ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs) {
return NewAST(ctx->file, start, pos, Index, .indexed=lhs, .index=index, .unchecked=unchecked);
}
ast_t *parse_for_suffix(parse_ctx_t *ctx, ast_t *lhs) {
ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *expr) {
// <expr> for [<index>,]<var> in <iter> [if <condition>]
if (!lhs) return NULL;
const char *start = lhs->start;
const char *pos = lhs->end;
if (!expr) return NULL;
const char *start = expr->start;
const char *pos = expr->end;
whitespace(&pos);
if (!match_word(&pos, "for")) return NULL;
ast_t *index = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'");
ast_t *key = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'");
whitespace(&pos);
ast_t *value = NULL;
if (match(&pos, ",")) {
value = expect(ctx, pos-1, &pos, parse_var, "I expected a variable after this comma");
} else {
value = index;
index = NULL;
value = key;
key = NULL;
}
expect_str(ctx, start, &pos, "in", "I expected an 'in' for this 'for'");
ast_t *iter = expect(ctx, start, &pos, parse_expr, "I expected an iterable value for this 'for'");
whitespace(&pos);
ast_t *body = lhs;
if (match_word(&pos, "if")) {
ast_t *condition = expect(ctx, pos-2, &pos, parse_expr, "I expected a condition for this 'if'");
body = NewAST(ctx->file, body->start, condition->end, Block,
.statements=new(ast_list_t, .ast=WrapAST(condition, If, .condition=FakeAST(Not, condition), .body=FakeAST(Skip)),
.next=new(ast_list_t, .ast=body)));
}
return NewAST(ctx->file, start, pos, For, .index=index, .value=value, .iter=iter, .body=body);
ast_t *filter = NULL;
if (match_word(&pos, "if"))
filter = expect(ctx, pos-2, &pos, parse_expr, "I expected a condition for this 'if'");
return NewAST(ctx->file, start, pos, Comprehension, .expr=expr, .key=key, .value=value, .iter=iter, .filter=filter);
}
PARSER(parse_if) {

View File

@ -49,3 +49,8 @@ for k,v in t2
t2_str ++= "({k}=>{v})"
>> t2_str
= "(three=>3)"
>> {i=>10*i for i in 5}
= {1=>10, 2=>20, 3=>30, 4=>40, 5=>50}
>> {i=>10*i for i in 5 if i mod 2 != 0}
= {1=>10, 3=>30, 5=>50}

View File

@ -308,9 +308,11 @@ type_t *get_type(env_t *env, ast_t *ast)
} else if (array->items) {
for (ast_list_t *item = array->items; item; item = item->next) {
type_t *t2;
if (item->ast->tag == For) {
env_t *scope = for_scope(env, item->ast);
t2 = get_type(scope, Match(item->ast, For)->body);
if (item->ast->tag == Comprehension) {
auto comp = Match(item->ast, Comprehension);
env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
t2 = get_type(scope, comp->expr);
} else {
t2 = get_type(env, item->ast);
}
@ -338,19 +340,30 @@ type_t *get_type(env_t *env, ast_t *ast)
if (table->default_value)
value_type = get_type(env, table->default_value);
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
auto table_entry = Match(entry->ast, TableEntry);
type_t *key_t = get_type(env, table_entry->key);
type_t *key_t, *value_t;
if (entry->ast->tag == Comprehension) {
auto comp = Match(entry->ast, Comprehension);
env_t *scope = for_scope(env, FakeAST(For, .iter=comp->iter, .index=comp->key, .value=comp->value));
if (comp->expr->tag != TableEntry)
code_err(comp->expr, "I expected this table comprehension to have a key/value entry");
key_t = get_type(scope, Match(comp->expr, TableEntry)->key);
value_t = get_type(scope, Match(comp->expr, TableEntry)->value);
} else {
auto e = Match(entry->ast, TableEntry);
key_t = get_type(env, e->key);
value_t = get_type(env, e->value);
}
type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t;
if (!key_merged)
code_err(table_entry->key,
code_err(entry->ast,
"This table entry has type %T, which is different from earlier table entries which have type %T",
key_t, key_type);
key_type = key_merged;
type_t *value_t = get_type(env, table_entry->value);
type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t;
if (!val_merged)
code_err(table_entry->value,
code_err(entry->ast,
"This table entry has type %T, which is different from earlier table entries which have type %T",
value_t, value_type);
value_type = val_merged;
@ -361,7 +374,10 @@ type_t *get_type(env_t *env, ast_t *ast)
return Type(TableType, .key_type=key_type, .value_type=value_type);
}
case TableEntry: {
code_err(ast, "This should not be typechecked directly");
code_err(ast, "Table entries should not be typechecked directly");
}
case Comprehension: {
code_err(ast, "Comprehensions should not be typechecked directly");
}
case FieldAccess: {
auto access = Match(ast, FieldAccess);