diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-09-12 14:27:13 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-09-12 14:27:13 -0400 |
| commit | 46a2aa2ffc71820767f0cdaead84c26dc240c893 (patch) | |
| tree | da00733058f5b9b4ef03339d630076880845d9ad | |
| parent | 10795782c674df12fc70ea3aeeaa2f62158b6cbd (diff) | |
Fix up comprehensions so set comprehensions work and everything is a bit
more clean
| -rw-r--r-- | ast.h | 3 | ||||
| -rw-r--r-- | compile.c | 125 | ||||
| -rw-r--r-- | environment.h | 2 | ||||
| -rw-r--r-- | parse.c | 2 | ||||
| -rw-r--r-- | test/sets.tm | 3 | ||||
| -rw-r--r-- | typecheck.c | 5 |
6 files changed, 76 insertions, 64 deletions
@@ -303,7 +303,8 @@ struct ast_s { } LinkerDirective; struct { CORD code; - type_ast_t *type; + struct type_s *type; + type_ast_t *type_ast; } InlineCCode; } __data; }; @@ -16,6 +16,8 @@ #include "typecheck.h" #include "builtins/util.h" +typedef ast_t* (*comprehension_body_t)(ast_t*, ast_t*); + static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_depth, bool needs_incref); static env_t *with_enum_scope(env_t *env, type_t *t); static CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); @@ -1317,29 +1319,22 @@ CORD compile_statement(env_t *env, ast_t *ast) return CORD_all("{\n", compile_inline_block(env, ast), "}\n"); } case Comprehension: { + if (!env->comprehension_action) + code_err(ast, "I don't know what to do with this comprehension!"); auto comp = Match(ast, Comprehension); - assert(env->comprehension_var); if (comp->expr->tag == Comprehension) { // Nested comprehension ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr; ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); return compile_statement(env, loop); - } else if (comp->expr->tag == TableEntry) { // Table comprehension - auto e = Match(comp->expr, TableEntry); - ast_t *body = WrapAST(ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)), - .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))); - if (comp->filter) - body = WrapAST(body, If, .condition=comp->filter, .body=body); - ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); - return compile_statement(env, loop); - } else { // Array or Set comprehension - // TODO: support set comprehensions - ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)), - .args=new(arg_ast_t, .value=comp->expr)); - if (comp->filter) - body = WrapAST(body, If, .condition=comp->filter, .body=body); - ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); - return compile_statement(env, loop); } + + // Array/Set/Table comprehension: + comprehension_body_t get_body = (void*)env->comprehension_action->fn; + ast_t *body = get_body(comp->expr, env->comprehension_action->userdata); + if (comp->filter) + body = WrapAST(comp->expr, If, .condition=comp->filter, .body=body); + ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body); + return compile_statement(env, loop); } case Extern: return CORD_EMPTY; case InlineCCode: return Match(ast, InlineCCode)->code; @@ -1765,6 +1760,23 @@ CORD compile_null(type_t *t) } } +static ast_t *add_to_table_comprehension(ast_t *entry, ast_t *subject) +{ + auto e = Match(entry, TableEntry); + return WrapAST(entry, MethodCall, .name="set", .self=subject, + .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))); +} + +static ast_t *add_to_array_comprehension(ast_t *item, ast_t *subject) +{ + return WrapAST(item, MethodCall, .name="insert", .self=subject, .args=new(arg_ast_t, .value=item)); +} + +static ast_t *add_to_set_comprehension(ast_t *item, ast_t *subject) +{ + return WrapAST(item, MethodCall, .name="add", .self=subject, .args=new(arg_ast_t, .value=item)); +} + CORD compile(env_t *env, ast_t *ast) { switch (ast->tag) { @@ -2216,20 +2228,20 @@ CORD compile(env_t *env, ast_t *ast) { env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); static int64_t comp_num = 1; - scope->comprehension_var = heap_strf("arr$%ld", comp_num++); - CORD code = CORD_all("({ Array_t ", scope->comprehension_var, " = {};"); - set_binding(scope, scope->comprehension_var, new(binding_t, .type=array_type, .code=scope->comprehension_var)); + const char *comprehension_name = heap_strf("arr$%ld", comp_num++); + ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), + .type=Type(PointerType, .pointed=array_type, .is_stack=true)); + Closure_t comp_action = {.fn=add_to_array_comprehension, .userdata=comprehension_var}; + scope->comprehension_action = &comp_action; + CORD code = CORD_all("({ Array_t ", comprehension_name, " = {};"); + // set_binding(scope, comprehension_name, new(binding_t, .type=array_type, .code=comprehension_name)); for (ast_list_t *item = array->items; item; item = item->next) { - if (item->ast->tag == Comprehension) { + if (item->ast->tag == Comprehension) code = CORD_all(code, "\n", compile_statement(scope, item->ast)); - } else { - CORD insert = compile_statement( - scope, WrapAST(item->ast, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)), - .args=new(arg_ast_t, .value=item->ast))); - code = CORD_all(code, "\n", insert); - } + else + code = CORD_all(code, compile_statement(env, add_to_array_comprehension(item->ast, comprehension_var))); } - code = CORD_all(code, " ", scope->comprehension_var, "; })"); + code = CORD_all(code, " ", comprehension_name, "; })"); return code; } } @@ -2295,27 +2307,25 @@ CORD compile(env_t *env, ast_t *ast) { static int64_t comp_num = 1; env_t *scope = fresh_scope(env); - scope->comprehension_var = heap_strf("table$%ld", comp_num++); + const char *comprehension_name = heap_strf("table$%ld", comp_num++); + ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), + .type=Type(PointerType, .pointed=table_type, .is_stack=true)); - CORD code = CORD_all("({ Table_t ", scope->comprehension_var, " = {"); + CORD code = CORD_all("({ Table_t ", comprehension_name, " = {"); if (table->fallback) code = CORD_all(code, ".fallback=heap(", compile(env, table->fallback), "), "); code = CORD_cat(code, "};"); - set_binding(scope, scope->comprehension_var, new(binding_t, .type=table_type, .code=scope->comprehension_var)); + Closure_t comp_action = {.fn=add_to_table_comprehension, .userdata=comprehension_var}; + scope->comprehension_action = &comp_action; for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - if (entry->ast->tag == Comprehension) { + if (entry->ast->tag == Comprehension) code = CORD_all(code, "\n", compile_statement(scope, entry->ast)); - } else { - auto e = Match(entry->ast, TableEntry); - CORD set = compile_statement( - scope, WrapAST(entry->ast, MethodCall, .name="set", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)), - .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)))); - code = CORD_all(code, "\n", set); - } + else + code = CORD_all(code, compile_statement(env, add_to_table_comprehension(entry->ast, comprehension_var))); } - code = CORD_all(code, " ", scope->comprehension_var, "; })"); + code = CORD_all(code, " ", comprehension_name, "; })"); return code; } @@ -2328,7 +2338,9 @@ CORD compile(env_t *env, ast_t *ast) type_t *set_type = get_type(env, ast); type_t *item_type = Match(set_type, SetType)->item_type; + size_t n = 0; for (ast_list_t *item = set->items; item; item = item->next) { + ++n; if (item->ast->tag == Comprehension) goto set_comprehension; } @@ -2337,12 +2349,7 @@ CORD compile(env_t *env, ast_t *ast) CORD code = CORD_all("Set(", compile_type(item_type), ", ", compile_type_info(env, item_type)); - - size_t n = 0; - for (ast_list_t *item = set->items; item; item = item->next) - ++n; CORD_appendf(&code, ", %zu", n); - env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : env; for (ast_list_t *item = set->items; item; item = item->next) { code = CORD_all(code, ", ", compile_to_type(scope, item->ast, item_type)); @@ -2354,21 +2361,19 @@ CORD compile(env_t *env, ast_t *ast) { static int64_t comp_num = 1; env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); - scope->comprehension_var = heap_strf("set$%ld", comp_num++); - - CORD code = CORD_all("({ Table_t ", scope->comprehension_var, " = {};"); - set_binding(scope, scope->comprehension_var, new(binding_t, .type=set_type, .code=scope->comprehension_var)); + const char *comprehension_name = heap_strf("set$%ld", comp_num++); + ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), + .type=Type(PointerType, .pointed=set_type, .is_stack=true)); + CORD code = CORD_all("({ Table_t ", comprehension_name, " = {};"); + Closure_t comp_action = {.fn=add_to_set_comprehension, .userdata=comprehension_var}; + scope->comprehension_action = &comp_action; for (ast_list_t *item = set->items; item; item = item->next) { - if (item->ast->tag == Comprehension) { + if (item->ast->tag == Comprehension) code = CORD_all(code, "\n", compile_statement(scope, item->ast)); - } else { - CORD add_item = compile_statement( - scope, WrapAST(item->ast, MethodCall, .name="add", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)), - .args=new(arg_ast_t, .value=item->ast))); - code = CORD_all(code, "\n", add_item); - } + else + code = CORD_all(code, compile_statement(env, add_to_set_comprehension(item->ast, comprehension_var))); } - code = CORD_all(code, " ", scope->comprehension_var, "; })"); + code = CORD_all(code, " ", comprehension_name, "; })"); return code; } @@ -2550,7 +2555,7 @@ CORD compile(env_t *env, ast_t *ast) ast_t *default_cmp = FakeAST(InlineCCode, .code=CORD_all("((Closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})"), - .type=NewTypeAST(NULL, NULL, NULL, FunctionTypeAST)); + .type=Type(ClosureType, .fn=fn_t)); arg_t *arg_spec = new(arg_t, .name="item", .type=item_t, .next=new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t), .default_val=default_cmp)); CORD arg_code = compile_arguments(env, ast, arg_spec, call->args); @@ -2563,7 +2568,7 @@ CORD compile(env_t *env, ast_t *ast) ast_t *default_cmp = FakeAST(InlineCCode, .code=CORD_all("((Closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})"), - .type=NewTypeAST(NULL, NULL, NULL, FunctionTypeAST)); + .type=Type(ClosureType, .fn=fn_t)); arg_t *arg_spec = new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t), .default_val=default_cmp); CORD arg_code = compile_arguments(env, ast, arg_spec, call->args); return CORD_all("Array$heap_pop_value(", self, ", ", arg_code, ", ", padded_item_size, ", ", compile_type(item_t), ")"); @@ -2575,7 +2580,7 @@ CORD compile(env_t *env, ast_t *ast) ast_t *default_cmp = FakeAST(InlineCCode, .code=CORD_all("((Closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})"), - .type=NewTypeAST(NULL, NULL, NULL, FunctionTypeAST)); + .type=Type(ClosureType, .fn=fn_t)); arg_t *arg_spec = new(arg_t, .name="target", .type=item_t, .next=new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t), .default_val=default_cmp)); CORD arg_code = compile_arguments(env, ast, arg_spec, call->args); diff --git a/environment.h b/environment.h index 1e13d5ab..2dfd4bfb 100644 --- a/environment.h +++ b/environment.h @@ -53,7 +53,7 @@ typedef struct env_s { deferral_t *deferred; CORD *libname; // Pointer to currently compiling library name (if any) namespace_t *namespace; - const char *comprehension_var; + Closure_t *comprehension_action; } env_t; typedef struct { @@ -2283,7 +2283,7 @@ PARSER(parse_inline_c) { parser_err(ctx, start, pos, "This inline C needs to have a type after it"); type = expect(ctx, start, &pos, parse_type, "I couldn't parse the type for this extern"); } - return NewAST(ctx->file, start, pos, InlineCCode, .code=c_code, .type=type); + return NewAST(ctx->file, start, pos, InlineCCode, .code=c_code, .type_ast=type); } PARSER(parse_doctest) { diff --git a/test/sets.tm b/test/sets.tm index bfec068e..a6a9d57b 100644 --- a/test/sets.tm +++ b/test/sets.tm @@ -34,3 +34,6 @@ func main(): >> t1:remove_all(t2) >> t1 = {10, 20} + + >> {3, i for i in 5} + = {3, 1, 2, 4, 5} diff --git a/typecheck.c b/typecheck.c index 192a335b..4952fbd7 100644 --- a/typecheck.c +++ b/typecheck.c @@ -1250,7 +1250,10 @@ type_t *get_type(env_t *env, ast_t *ast) case While: case For: return Type(VoidType); case InlineCCode: { - type_ast_t *type_ast = Match(ast, InlineCCode)->type; + auto inline_code = Match(ast, InlineCCode); + if (inline_code->type) + return inline_code->type; + type_ast_t *type_ast = inline_code->type_ast; return type_ast ? parse_type_ast(env, type_ast) : Type(VoidType); } case Unknown: code_err(ast, "I can't figure out the type of: %W", ast); |
