diff options
| -rw-r--r-- | compile.c | 8 | ||||
| -rw-r--r-- | typecheck.c | 25 | ||||
| -rw-r--r-- | types.c | 19 | ||||
| -rw-r--r-- | types.h | 2 |
4 files changed, 28 insertions, 26 deletions
@@ -2021,6 +2021,7 @@ CORD compile(env_t *env, ast_t *ast) CORD code = CORD_all( "({ // Reduction:\n", compile_declaration(t, "reduction"), ";\n" + "Bool_t is_first = yes;\n" ); env_t *scope = fresh_scope(env); ast_t *result = FakeAST(Var, "$reduction"); @@ -2040,11 +2041,10 @@ CORD compile(env_t *env, ast_t *ast) Text$quoted(ast->file->filename, false), (long)(reduction->iter->start - reduction->iter->file->text), (long)(reduction->iter->end - reduction->iter->file->text))); } - ast_t *i = FakeAST(Var, "$i"); ast_t *item = FakeAST(Var, "$iter_value"); - set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="iter_value")); - ast_t *body = FakeAST(InlineCCode, CORD_all("reduction = $$i == 1 ? iter_value : ", compile(scope, reduction->combination), ";")); - ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=i, .next=new(ast_list_t, .ast=item)), .iter=reduction->iter, .body=body, .empty=empty); + set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="$$iter_value")); + ast_t *body = FakeAST(InlineCCode, CORD_all("if (is_first) {\nreduction = $$iter_value;\nis_first = no;\n} else {\nreduction = ", compile(scope, reduction->combination), ";\n}\n")); + ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body, .empty=empty); code = CORD_all(code, compile_statement(scope, loop), "\nreduction;})"); return code; } diff --git a/typecheck.c b/typecheck.c index c41c0836..2639da73 100644 --- a/typecheck.c +++ b/typecheck.c @@ -882,7 +882,30 @@ type_t *get_type(env_t *env, ast_t *ast) case Reduction: { auto reduction = Match(ast, Reduction); type_t *iter_t = get_type(env, reduction->iter); - type_t *value_t = iteration_value_type(iter_t); + + type_t *value_t; + type_t *iter_value_t = value_type(iter_t); + switch (iter_value_t->tag) { + case IntType: value_t = iter_value_t; break; + case ArrayType: value_t = Match(iter_value_t, ArrayType)->item_type; break; + case TableType: value_t = Match(iter_value_t, TableType)->key_type; break; + case FunctionType: case ClosureType: { + auto fn = iter_value_t->tag == ClosureType ? + Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType); + if (!fn->args || fn->args->next) + code_err(reduction->iter, "I expected this iterable to have exactly one argument, not %T", iter_value_t); + type_t *arg_type = get_arg_type(env, fn->args); + if (arg_type->tag != PointerType) + code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type); + auto ptr = Match(arg_type, PointerType); + if (!ptr->is_stack || ptr->is_optional || ptr->is_readonly) + code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type); + value_t = ptr->pointed; + break; + } + default: code_err(reduction->iter, "I don't know how to do a reduction over %T values", iter_t); + } + env_t *scope = fresh_scope(env); set_binding(scope, "$reduction", new(binding_t, .type=value_t, .code="reduction")); set_binding(scope, "$iter_value", new(binding_t, .type=value_t, .code="iter_value")); @@ -515,23 +515,4 @@ type_t *get_field_type(type_t *t, const char *field_name) } } -type_t *iteration_key_type(type_t *iterable) -{ - switch (iterable->tag) { - case IntType: case ArrayType: return Type(IntType, .bits=64); - case TableType: return Match(iterable, TableType)->key_type; - default: return NULL; - } -} - -type_t *iteration_value_type(type_t *iterable) -{ - switch (iterable->tag) { - case IntType: return iterable; - case ArrayType: return Match(iterable, ArrayType)->item_type; - case TableType: return Match(iterable, TableType)->value_type; - default: return NULL; - } -} - // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 @@ -136,8 +136,6 @@ bool can_have_cycles(type_t *t); type_t *replace_type(type_t *t, type_t *target, type_t *replacement); size_t type_size(type_t *t); size_t type_align(type_t *t); -type_t *iteration_key_type(type_t *iterable); -type_t *iteration_value_type(type_t *iterable); type_t *get_field_type(type_t *t, const char *field_name); // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 |
