From fdc3eadba25aff7894419e483519e73150be33d4 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Thu, 14 Mar 2024 02:37:56 -0400 Subject: [PATCH] Array comprehensions --- [Review note: unlike the original commit, this patch also deletes the leftover `env_t *scope = fresh_scope(env);` in the IntType branch of `case For` in compile.c. That declaration shadowed the scope built by for_scope(), so the loop body in that branch was compiled in a scope that for_scope() had not populated. The compile.c hunk headers and the diffstat totals are adjusted for the extra deleted line. Explanatory comments were appended to existing added lines and one garbled comment was repaired, so the post-image blob hashes on the `index` lines no longer match.] builtins/array.h | 1 + compile.c | 72 ++++++++++++++++++++++++++++-------------------- environment.c | 53 +++++++++++++++++++++++++++++++++++- environment.h | 1 + parse.c | 36 ++++++++++++++++++++++++ typecheck.c | 8 +++++- 6 files changed, 139 insertions(+), 32 deletions(-) diff --git a/builtins/array.h b/builtins/array.h index d36db57..c2c0b8b 100644 --- a/builtins/array.h +++ b/builtins/array.h @@ -46,6 +46,7 @@ $ARRAY_DECREF($arr); \ } +#define Array__insert_value(arr, item_expr, index, type) ({ __typeof(item_expr) $item = item_expr; Array__insert(arr, &$item, index, type); }) void Array__insert(array_t *arr, const void *item, int64_t index, const TypeInfo *type); void Array__insert_all(array_t *arr, array_t to_insert, int64_t index, const TypeInfo *type); void Array__remove(array_t *arr, int64_t index, int64_t count, const TypeInfo *type); diff --git a/compile.c b/compile.c index 11afe0d..19ff056 100644 --- a/compile.c +++ b/compile.c @@ -699,15 +699,44 @@ CORD compile(env_t *env, ast_t *ast) if (!array->items) return "(array_t){.length=0}"; - int64_t n = 0; - for (ast_list_t *item = array->items; item; item = item->next) - ++n; + type_t *array_t = get_type(env, ast); - type_t *item_type = Match(get_type(env, ast), ArrayType)->item_type; + int64_t n = 0; + for (ast_list_t *item = array->items; item; item = item->next) { + ++n; + if (item->ast->tag == For) + goto array_comprehension; + } + + type_t *item_type = Match(array_t, ArrayType)->item_type; CORD code = CORD_all("$TypedArrayN(", compile_type(item_type), CORD_asprintf(", %ld", n)); for (ast_list_t *item = array->items; item; item = item->next) code = CORD_all(code, ", ", compile(env, item->ast)); return CORD_cat(code, ")"); + + array_comprehension: + { + CORD code = "({ array_t $arr = {};"; + env_t *scope = fresh_scope(env); 
set_binding(scope, "$arr", new(binding_t, .type=array_t, .code="$arr")); + for (ast_list_t *item = array->items; item; item = item->next) { + if (item->ast->tag == For) { + auto for_ = Match(item->ast, For); + env_t *body_scope = for_scope(scope, item->ast); + ast_t *for2 = WrapAST(item->ast, For, .index=for_->index, .value=for_->value, .iter=for_->iter, + .body=WrapAST(for_->body, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")), + .args=new(arg_ast_t, .value=for_->body))); + code = CORD_all(code, "\n", compile_statement(body_scope, for2)); + } else { + CORD insert = compile_statement( + scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, "$arr")), + .args=new(arg_ast_t, .value=item->ast))); + code = CORD_all(code, "\n", insert); + } + } + code = CORD_cat(code, " $arr; })"); + return code; + } } case Table: { auto table = Match(ast, Table); @@ -902,9 +931,9 @@ CORD compile(env_t *env, ast_t *ast) if (streq(call->name, "insert")) { type_t *item_t = Match(self_value_t, ArrayType)->item_type; CORD self = compile_to_pointer_depth(env, call->self, 1, false); - arg_t *arg_spec = new(arg_t, .name="item", .type=Type(PointerType, .pointed=item_t, .is_stack=true, .is_readonly=true), + arg_t *arg_spec = new(arg_t, .name="item", .type=item_t, .next=new(arg_t, .name="at", .type=Type(IntType, .bits=64), .default_val=FakeAST(Int, .i=0, .bits=64))); - return CORD_all("Array__insert(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", + return CORD_all("Array__insert_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")"); } else if (streq(call->name, "insert_all")) { CORD self = compile_to_pointer_depth(env, call->self, 1, false); @@ -1088,28 +1117,21 @@ CORD compile(env_t *env, ast_t *ast) case For: { auto for_ = Match(ast, For); type_t *iter_t = get_type(env, for_->iter); + env_t *scope = for_scope(env, ast); /* loop variables are bound here once; the per-type cases below must reuse this scope rather than declare their own */ switch 
(iter_t->tag) { case ArrayType: { type_t *item_t = Match(iter_t, ArrayType)->item_type; - env_t *scope = fresh_scope(env); CORD index = for_->index ? compile(env, for_->index) : "$i"; - if (for_->index) - set_binding(scope, CORD_to_const_char_star(index), new(binding_t, .type=Type(IntType, .bits=64))); CORD value = compile(env, for_->value); - set_binding(scope, CORD_to_const_char_star(value), new(binding_t, .type=item_t)); return CORD_all("$ARRAY_FOREACH(", compile(env, for_->iter), ", ", index, ", ", compile_type(item_t), ", ", value, ", ", compile(scope, for_->body), ", ", for_->empty ? compile(env, for_->empty) : "{}", ")"); } case TableType: { type_t *key_t = Match(iter_t, TableType)->key_type; type_t *value_t = Match(iter_t, TableType)->value_type; - env_t *scope = fresh_scope(env); - CORD key, value; if (for_->index) { - key = compile(env, for_->index); - value = compile(env, for_->value); - set_binding(scope, CORD_to_const_char_star(key), new(binding_t, .type=key_t)); - set_binding(scope, CORD_to_const_char_star(value), new(binding_t, .type=value_t)); + CORD key = compile(env, for_->index); + CORD value = compile(env, for_->value); size_t value_offset = type_size(key_t); if (type_align(value_t) > 1 && value_offset % type_align(value_t)) @@ -1118,25 +1140,15 @@ CORD compile(env_t *env, ast_t *ast) compile_type(value_t), ", ", value, ", ", heap_strf("%zu", value_offset), ", ", compile(scope, for_->body), ", ", for_->empty ? compile(env, for_->empty) : "{}", ")"); } else { - key = compile(env, for_->value); - set_binding(scope, CORD_to_const_char_star(key), new(binding_t, .type=key_t)); + CORD key = compile(env, for_->value); return CORD_all("$ARRAY_FOREACH((", compile(env, for_->iter), ").entries, $i, ", compile_type(key_t), ", ", key, ", ", compile(scope, for_->body), ", ", for_->empty ? 
compile(env, for_->empty) : "{}", ")"); } } case IntType: { - type_t *item_t = iter_t; - env_t *scope = fresh_scope(env); CORD value = compile(env, for_->value); - set_binding(scope, CORD_to_const_char_star(value), new(binding_t, .type=item_t, .code=value)); - CORD n = compile(env, for_->iter); - CORD index = CORD_EMPTY; - if (for_->index) { - index = compile(env, for_->index); - set_binding(scope, CORD_to_const_char_star(index), new(binding_t, .type=Type(IntType, .bits=64), .code=index)); - } - + CORD index = for_->index ? compile(env, for_->index) : CORD_EMPTY; if (for_->empty && index) { return CORD_all( "{\n" @@ -1158,11 +1170,11 @@ CORD compile(env_t *env, ast_t *ast) } else if (index) { return CORD_all( "for (int64_t ", value, ", ", index, " = 1, $n = ", n, "; (", value, "=", index,") <= $n; ++", value, ")\n" - "\t", compile(scope, for_->body), "\n"); + "\t", compile_statement(scope, for_->body), "\n"); } else { return CORD_all( "for (int64_t ", value, " = 1, $n = ", compile(env, for_->iter), "; ", value, " <= $n; ++", value, ")\n" - "\t", compile(scope, for_->body), "\n"); + "\t", compile_statement(scope, for_->body), "\n"); } } default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t); diff --git a/environment.c b/environment.c index 3ba7cdd..3a8ac50 100644 --- a/environment.c +++ b/environment.c @@ -201,6 +201,56 @@ env_t *fresh_scope(env_t *env) return scope; } +env_t *for_scope(env_t *env, ast_t *ast) /* Returns a scope obtained from fresh_scope(env) in which the For node's value variable (and index variable, when present) are bound according to the type of the iterated expression: arrays bind index as a 64-bit Int and value as the item type; tables bind index/value as the key/value types, or value as the key type when there is no index; Int iteration binds index as a 64-bit Int and value as the iterated Int type. Any other iterated type is reported through code_err. */ +{ + auto for_ = Match(ast, For); + type_t *iter_t = get_type(env, for_->iter); + env_t *scope = fresh_scope(env); + const char *value = Match(for_->value, Var)->name; + if (for_->index) { + const char *index = Match(for_->index, Var)->name; + switch (iter_t->tag) { + case ArrayType: { + type_t *item_t = Match(iter_t, ArrayType)->item_type; + set_binding(scope, index, new(binding_t, .type=Type(IntType, .bits=64), .code=index)); + set_binding(scope, value, new(binding_t, .type=item_t, .code=value)); + return scope; + } + case 
TableType: { + type_t *key_t = Match(iter_t, TableType)->key_type; + type_t *value_t = Match(iter_t, TableType)->value_type; + set_binding(scope, index, new(binding_t, .type=key_t, .code=index)); + set_binding(scope, value, new(binding_t, .type=value_t, .code=value)); + return scope; + } + case IntType: { + set_binding(scope, index, new(binding_t, .type=Type(IntType, .bits=64), .code=index)); + set_binding(scope, value, new(binding_t, .type=iter_t, .code=value)); + return scope; + } + default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t); + } + } else { + switch (iter_t->tag) { + case ArrayType: { + type_t *item_t = Match(iter_t, ArrayType)->item_type; + set_binding(scope, value, new(binding_t, .type=item_t, .code=value)); + return scope; + } + case TableType: { + type_t *key_t = Match(iter_t, TableType)->key_type; + set_binding(scope, value, new(binding_t, .type=key_t, .code=value)); + return scope; + } + case IntType: { + set_binding(scope, value, new(binding_t, .type=iter_t, .code=value)); + return scope; + } + default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t); + } + } +} + env_t *namespace_env(env_t *env, const char *namespace_name) { env_t *ns_env = new(env_t); @@ -268,7 +318,8 @@ binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name) void set_binding(env_t *env, const char *name, binding_t *binding) { - Table_str_set(env->locals, name, binding); + if (name && binding) + Table_str_set(env->locals, name, binding); } void compiler_err(file_t *f, const char *start, const char *end, const char *fmt, ...) 
diff --git a/environment.h b/environment.h index 2ab7a9d..8abc889 100644 --- a/environment.h +++ b/environment.h @@ -38,6 +38,7 @@ typedef struct { env_t *new_compilation_unit(void); env_t *global_scope(env_t *env); env_t *fresh_scope(env_t *env); +env_t *for_scope(env_t *env, ast_t *ast); env_t *namespace_env(env_t *env, const char *namespace_name); __attribute__((noreturn)) void compiler_err(file_t *f, const char *start, const char *end, const char *fmt, ...); diff --git a/parse.c b/parse.c index ef0a89a..d44afda 100644 --- a/parse.c +++ b/parse.c @@ -68,6 +68,7 @@ static ast_t *parse_fncall_suffix(parse_ctx_t *ctx, ast_t *fn, bool is_extern); static ast_t *parse_method_call_suffix(parse_ctx_t *ctx, ast_t *self); static ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs); static ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs); +static ast_t *parse_for_suffix(parse_ctx_t *ctx, ast_t *lhs); static arg_ast_t *parse_args(parse_ctx_t *ctx, const char **pos, bool allow_unnamed); static PARSER(parse_for); static PARSER(parse_while); @@ -616,6 +617,11 @@ PARSER(parse_array) { for (;;) { ast_t *item = optional(ctx, &pos, parse_extended_expr); if (!item) break; + ast_t *suffixed = parse_for_suffix(ctx, item); + if (suffixed) { + item = suffixed; + pos = suffixed->end; + } items = new(ast_list_t, .ast=item, .next=items); if (!match_separator(&pos)) break; @@ -785,6 +791,36 @@ ast_t *parse_index_suffix(parse_ctx_t *ctx, ast_t *lhs) { return NewAST(ctx->file, start, pos, Index, .indexed=lhs, .index=index, .unchecked=unchecked); } +ast_t *parse_for_suffix(parse_ctx_t *ctx, ast_t *lhs) { + // <item> for [<index>,]<value> in <iter> [if <condition>] + if (!lhs) return NULL; + const char *start = lhs->start; + const char *pos = lhs->end; + whitespace(&pos); + if (!match_word(&pos, "for")) return NULL; + + ast_t *index = expect(ctx, start, &pos, parse_var, "I expected an iteration variable for this 'for'"); + whitespace(&pos); + ast_t *value = NULL; + if (match(&pos, ",")) { + value = expect(ctx, pos-1, 
&pos, parse_var, "I expected a variable after this comma"); + } else { + value = index; + index = NULL; + } + expect_str(ctx, start, &pos, "in", "I expected an 'in' for this 'for'"); + ast_t *iter = expect(ctx, start, &pos, parse_expr, "I expected an iterable value for this 'for'"); + whitespace(&pos); + ast_t *body = lhs; + if (match_word(&pos, "if")) { + ast_t *condition = expect(ctx, pos-2, &pos, parse_expr, "I expected a condition for this 'if'"); + body = NewAST(ctx->file, body->start, condition->end, Block, + .statements=new(ast_list_t, .ast=WrapAST(condition, If, .condition=condition, .else_body=FakeAST(Skip)), + .next=new(ast_list_t, .ast=body))); + } + return NewAST(ctx->file, start, pos, For, .index=index, .value=value, .iter=iter, .body=body); +} + PARSER(parse_if) { // if [then] [else ] const char *start = pos; diff --git a/typecheck.c b/typecheck.c index 769d868..33e8bfe 100644 --- a/typecheck.c +++ b/typecheck.c @@ -307,7 +307,13 @@ type_t *get_type(env_t *env, ast_t *ast) item_type = parse_type_ast(env, array->type); } else if (array->items) { for (ast_list_t *item = array->items; item; item = item->next) { - type_t *t2 = get_type(env, item->ast); + type_t *t2; + if (item->ast->tag == For) { + env_t *scope = for_scope(env, item->ast); + t2 = get_type(scope, Match(item->ast, For)->body); + } else { + t2 = get_type(env, item->ast); + } type_t *merged = item_type ? type_or_type(item_type, t2) : t2; if (!merged) code_err(item->ast,