aboutsummaryrefslogtreecommitdiff
path: root/compile.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-03-14 02:37:56 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-03-14 02:37:56 -0400
commitfdc3eadba25aff7894419e483519e73150be33d4 (patch)
treeae0bf68e1bfa501fd9010b66d2211b0b1ef59a23 /compile.c
parent130ddc8ea04060ec52d9a2fd03da8c9662d32f9c (diff)
Array comprehensions
Diffstat (limited to 'compile.c')
-rw-r--r--compile.c67
1 files changed, 40 insertions, 27 deletions
diff --git a/compile.c b/compile.c
index 11afe0de..19ff0563 100644
--- a/compile.c
+++ b/compile.c
@@ -699,15 +699,44 @@ CORD compile(env_t *env, ast_t *ast)
if (!array->items)
return "(array_t){.length=0}";
+ type_t *array_t = get_type(env, ast);
+
int64_t n = 0;
- for (ast_list_t *item = array->items; item; item = item->next)
+ for (ast_list_t *item = array->items; item; item = item->next) {
++n;
+ if (item->ast->tag == For)
+ goto array_comprehension;
+ }
- type_t *item_type = Match(get_type(env, ast), ArrayType)->item_type;
+ type_t *item_type = Match(array_t, ArrayType)->item_type;
CORD code = CORD_all("$TypedArrayN(", compile_type(item_type), CORD_asprintf(", %ld", n));
for (ast_list_t *item = array->items; item; item = item->next)
code = CORD_all(code, ", ", compile(env, item->ast));
return CORD_cat(code, ")");
+
+ array_comprehension:
+ {
+ CORD code = "({ array_t $arr = {};";
+ env_t *scope = fresh_scope(env);
+ set_binding(scope, "$arr", new(binding_t, .type=array_t, .code="$arr"));
+ for (ast_list_t *item = array->items; item; item = item->next) {
+ if (item->ast->tag == For) {
+ auto for_ = Match(item->ast, For);
+ env_t *body_scope = for_scope(scope, item->ast);
+ ast_t *for2 = WrapAST(item->ast, For, .index=for_->index, .value=for_->value, .iter=for_->iter,
+ .body=WrapAST(for_->body, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
+ .args=new(arg_ast_t, .value=for_->body)));
+ code = CORD_all(code, "\n", compile_statement(body_scope, for2));
+ } else {
+ CORD insert = compile_statement(
+ scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")),
+ .args=new(arg_ast_t, .value=item->ast)));
+ code = CORD_all(code, "\n", insert);
+ }
+ }
+ code = CORD_cat(code, " $arr; })");
+ return code;
+ }
}
case Table: {
auto table = Match(ast, Table);
@@ -902,9 +931,9 @@ CORD compile(env_t *env, ast_t *ast)
if (streq(call->name, "insert")) {
type_t *item_t = Match(self_value_t, ArrayType)->item_type;
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
- arg_t *arg_spec = new(arg_t, .name="item", .type=Type(PointerType, .pointed=item_t, .is_stack=true, .is_readonly=true),
+ arg_t *arg_spec = new(arg_t, .name="item", .type=item_t,
.next=new(arg_t, .name="at", .type=Type(IntType, .bits=64), .default_val=FakeAST(Int, .i=0, .bits=64)));
- return CORD_all("Array__insert(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
+ return CORD_all("Array__insert_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "insert_all")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
@@ -1088,28 +1117,21 @@ CORD compile(env_t *env, ast_t *ast)
case For: {
auto for_ = Match(ast, For);
type_t *iter_t = get_type(env, for_->iter);
+ env_t *scope = for_scope(env, ast);
switch (iter_t->tag) {
case ArrayType: {
type_t *item_t = Match(iter_t, ArrayType)->item_type;
- env_t *scope = fresh_scope(env);
CORD index = for_->index ? compile(env, for_->index) : "$i";
- if (for_->index)
- set_binding(scope, CORD_to_const_char_star(index), new(binding_t, .type=Type(IntType, .bits=64)));
CORD value = compile(env, for_->value);
- set_binding(scope, CORD_to_const_char_star(value), new(binding_t, .type=item_t));
return CORD_all("$ARRAY_FOREACH(", compile(env, for_->iter), ", ", index, ", ", compile_type(item_t), ", ", value, ", ",
compile(scope, for_->body), ", ", for_->empty ? compile(env, for_->empty) : "{}", ")");
}
case TableType: {
type_t *key_t = Match(iter_t, TableType)->key_type;
type_t *value_t = Match(iter_t, TableType)->value_type;
- env_t *scope = fresh_scope(env);
- CORD key, value;
if (for_->index) {
- key = compile(env, for_->index);
- value = compile(env, for_->value);
- set_binding(scope, CORD_to_const_char_star(key), new(binding_t, .type=key_t));
- set_binding(scope, CORD_to_const_char_star(value), new(binding_t, .type=value_t));
+ CORD key = compile(env, for_->index);
+ CORD value = compile(env, for_->value);
size_t value_offset = type_size(key_t);
if (type_align(value_t) > 1 && value_offset % type_align(value_t))
@@ -1118,25 +1140,16 @@ CORD compile(env_t *env, ast_t *ast)
compile_type(value_t), ", ", value, ", ", heap_strf("%zu", value_offset),
", ", compile(scope, for_->body), ", ", for_->empty ? compile(env, for_->empty) : "{}", ")");
} else {
- key = compile(env, for_->value);
- set_binding(scope, CORD_to_const_char_star(key), new(binding_t, .type=key_t));
+ CORD key = compile(env, for_->value);
return CORD_all("$ARRAY_FOREACH((", compile(env, for_->iter), ").entries, $i, ", compile_type(key_t), ", ", key, ", ",
compile(scope, for_->body), ", ", for_->empty ? compile(env, for_->empty) : "{}", ")");
}
}
case IntType: {
- type_t *item_t = iter_t;
env_t *scope = fresh_scope(env);
CORD value = compile(env, for_->value);
- set_binding(scope, CORD_to_const_char_star(value), new(binding_t, .type=item_t, .code=value));
-
CORD n = compile(env, for_->iter);
- CORD index = CORD_EMPTY;
- if (for_->index) {
- index = compile(env, for_->index);
- set_binding(scope, CORD_to_const_char_star(index), new(binding_t, .type=Type(IntType, .bits=64), .code=index));
- }
-
+ CORD index = for_->index ? compile(env, for_->index) : CORD_EMPTY;
if (for_->empty && index) {
return CORD_all(
"{\n"
@@ -1158,11 +1171,11 @@ CORD compile(env_t *env, ast_t *ast)
} else if (index) {
return CORD_all(
"for (int64_t ", value, ", ", index, " = 1, $n = ", n, "; (", value, "=", index,") <= $n; ++", value, ")\n"
- "\t", compile(scope, for_->body), "\n");
+ "\t", compile_statement(scope, for_->body), "\n");
} else {
return CORD_all(
"for (int64_t ", value, " = 1, $n = ", compile(env, for_->iter), "; ", value, " <= $n; ++", value, ")\n"
- "\t", compile(scope, for_->body), "\n");
+ "\t", compile_statement(scope, for_->body), "\n");
}
}
default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t);