aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compile.c8
-rw-r--r--typecheck.c25
-rw-r--r--types.c19
-rw-r--r--types.h2
4 files changed, 28 insertions, 26 deletions
diff --git a/compile.c b/compile.c
index cc848e80..f56604d8 100644
--- a/compile.c
+++ b/compile.c
@@ -2021,6 +2021,7 @@ CORD compile(env_t *env, ast_t *ast)
CORD code = CORD_all(
"({ // Reduction:\n",
compile_declaration(t, "reduction"), ";\n"
+ "Bool_t is_first = yes;\n"
);
env_t *scope = fresh_scope(env);
ast_t *result = FakeAST(Var, "$reduction");
@@ -2040,11 +2041,10 @@ CORD compile(env_t *env, ast_t *ast)
Text$quoted(ast->file->filename, false), (long)(reduction->iter->start - reduction->iter->file->text),
(long)(reduction->iter->end - reduction->iter->file->text)));
}
- ast_t *i = FakeAST(Var, "$i");
ast_t *item = FakeAST(Var, "$iter_value");
- set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="iter_value"));
- ast_t *body = FakeAST(InlineCCode, CORD_all("reduction = $$i == 1 ? iter_value : ", compile(scope, reduction->combination), ";"));
- ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=i, .next=new(ast_list_t, .ast=item)), .iter=reduction->iter, .body=body, .empty=empty);
+ set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="$$iter_value"));
+ ast_t *body = FakeAST(InlineCCode, CORD_all("if (is_first) {\nreduction = $$iter_value;\nis_first = no;\n} else {\nreduction = ", compile(scope, reduction->combination), ";\n}\n"));
+ ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body, .empty=empty);
code = CORD_all(code, compile_statement(scope, loop), "\nreduction;})");
return code;
}
diff --git a/typecheck.c b/typecheck.c
index c41c0836..2639da73 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -882,7 +882,30 @@ type_t *get_type(env_t *env, ast_t *ast)
case Reduction: {
auto reduction = Match(ast, Reduction);
type_t *iter_t = get_type(env, reduction->iter);
- type_t *value_t = iteration_value_type(iter_t);
+
+ type_t *value_t;
+ type_t *iter_value_t = value_type(iter_t);
+ switch (iter_value_t->tag) {
+ case IntType: value_t = iter_value_t; break;
+ case ArrayType: value_t = Match(iter_value_t, ArrayType)->item_type; break;
+ case TableType: value_t = Match(iter_value_t, TableType)->key_type; break;
+ case FunctionType: case ClosureType: {
+ auto fn = iter_value_t->tag == ClosureType ?
+ Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType);
+ if (!fn->args || fn->args->next)
+ code_err(reduction->iter, "I expected this iterable to have exactly one argument, not %T", iter_value_t);
+ type_t *arg_type = get_arg_type(env, fn->args);
+ if (arg_type->tag != PointerType)
+ code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type);
+ auto ptr = Match(arg_type, PointerType);
+ if (!ptr->is_stack || ptr->is_optional || ptr->is_readonly)
+ code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type);
+ value_t = ptr->pointed;
+ break;
+ }
+ default: code_err(reduction->iter, "I don't know how to do a reduction over %T values", iter_t);
+ }
+
env_t *scope = fresh_scope(env);
set_binding(scope, "$reduction", new(binding_t, .type=value_t, .code="reduction"));
set_binding(scope, "$iter_value", new(binding_t, .type=value_t, .code="iter_value"));
diff --git a/types.c b/types.c
index 85ae9a3c..20049583 100644
--- a/types.c
+++ b/types.c
@@ -515,23 +515,4 @@ type_t *get_field_type(type_t *t, const char *field_name)
}
}
-type_t *iteration_key_type(type_t *iterable)
-{
- switch (iterable->tag) {
- case IntType: case ArrayType: return Type(IntType, .bits=64);
- case TableType: return Match(iterable, TableType)->key_type;
- default: return NULL;
- }
-}
-
-type_t *iteration_value_type(type_t *iterable)
-{
- switch (iterable->tag) {
- case IntType: return iterable;
- case ArrayType: return Match(iterable, ArrayType)->item_type;
- case TableType: return Match(iterable, TableType)->value_type;
- default: return NULL;
- }
-}
-
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0
diff --git a/types.h b/types.h
index 3debb575..668152c4 100644
--- a/types.h
+++ b/types.h
@@ -136,8 +136,6 @@ bool can_have_cycles(type_t *t);
type_t *replace_type(type_t *t, type_t *target, type_t *replacement);
size_t type_size(type_t *t);
size_t type_align(type_t *t);
-type_t *iteration_key_type(type_t *iterable);
-type_t *iteration_value_type(type_t *iterable);
type_t *get_field_type(type_t *t, const char *field_name);
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0