aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-11-02 22:34:35 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-11-02 22:34:35 -0400
commit92a593b80fe935eb21615dc45b4d7868b254bec6 (patch)
tree029e43ff46abc4bd86000c5a90900511ea99fcf3
parent0b7a0dd043a4c7ccfc924d618508d1edc0115e2f (diff)
Support reductions for comparison operators like == and <
-rw-r--r--compile.c119
-rw-r--r--test/reductions.tm17
-rw-r--r--typecheck.c6
-rw-r--r--types.c20
-rw-r--r--types.h1
5 files changed, 116 insertions, 47 deletions
diff --git a/compile.c b/compile.c
index ff3cdf72..d947fe26 100644
--- a/compile.c
+++ b/compile.c
@@ -3153,49 +3153,82 @@ CORD compile(env_t *env, ast_t *ast)
auto reduction = Match(ast, Reduction);
type_t *optional_t = get_type(env, ast);
type_t *t = Match(optional_t, OptionalType)->type;
- CORD code = CORD_all(
- "({ // Reduction:\n",
- compile_declaration(t, "reduction"), ";\n"
- "Bool_t has_value = no;\n"
- );
- env_t *scope = fresh_scope(env);
- set_binding(scope, "$reduction", new(binding_t, .type=t, .code="reduction"));
- ast_t *item = FakeAST(Var, "$iter_value");
- ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
- ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
- env_t *body_scope = for_scope(scope, loop);
-
- // For the special case of (or)/(and), we need to early out if we can:
- CORD early_out = CORD_EMPTY;
- if (reduction->combination->tag == BinaryOp) {
- auto binop = Match(reduction->combination, BinaryOp);
- if (t->tag != BoolType && (binop->op == BINOP_EQ || binop->op == BINOP_NE
- || binop->op == BINOP_LT || binop->op == BINOP_LE
- || binop->op == BINOP_GT || binop->op == BINOP_GE))
- code_err(ast, "Reductions are not supported for this type of infix operator");
- else if ((t->tag != IntType || Match(t, IntType)->bits != TYPE_IBITS32) && binop->op == BINOP_CMP)
- code_err(ast, "<> reductions are only supported for Int32 values");
- else if (t->tag == BoolType && binop->op == BINOP_AND)
- early_out = "if (!reduction) break;";
- else if (t->tag == BoolType && binop->op == BINOP_OR)
- early_out = "if (reduction) break;";
- else if (t->tag == OptionalType && binop->op == BINOP_AND)
- early_out = CORD_all("if (", check_null(t, "reduction"), ") break;");
- else if (t->tag == OptionalType && binop->op == BINOP_OR)
- early_out = CORD_all("if (!", check_null(t, "reduction"), ") break;");
- }
-
- body->__data.InlineCCode.code = CORD_all(
- "if (!has_value) {\n"
- " reduction = ", compile(body_scope, item), ";\n"
- " has_value = yes;\n"
- "} else {\n"
- " reduction = ", compile(body_scope, reduction->combination), ";\n",
- early_out,
- "}\n");
-
- code = CORD_all(code, compile_statement(scope, loop), "\nhas_value ? ", promote_to_optional(t, "reduction"), " : ", compile_null(t), ";})");
- return code;
+ binop_e op = reduction->combination->tag == BinaryOp ? Match(reduction->combination, BinaryOp)->op : BINOP_UNKNOWN;
+
+ if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) {
+ // Chained comparisons like ==, <, etc.
+ type_t *iter_t = get_type(env, reduction->iter);
+ type_t *item_t = get_iterated_type(iter_t);
+ if (!item_t) code_err(reduction->iter, "I couldn't figure out how to iterate over this type: %T", iter_t);
+ CORD code = CORD_all(
+ "({ // Reduction:\n",
+ compile_declaration(item_t, "prev"), ";\n"
+ "OptionalBool_t result = NULL_BOOL;\n"
+ );
+ env_t *scope = fresh_scope(env);
+ set_binding(scope, "$reduction", new(binding_t, .type=item_t, .code="prev"));
+ ast_t *item = FakeAST(Var, "$iter_value");
+ ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
+ ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
+ env_t *body_scope = for_scope(scope, loop);
+
+ body->__data.InlineCCode.code = CORD_all(
+ "if (result == NULL_BOOL) {\n"
+ " prev = ", compile(body_scope, item), ";\n"
+ " result = yes;\n"
+ "} else {\n"
+ " if (", compile(body_scope, reduction->combination), ") {\n",
+ " prev = ", compile(body_scope, item), ";\n",
+ " } else {\n"
+ " result = no;\n",
+ " break;\n",
+ " }\n",
+ "}\n");
+ code = CORD_all(code, compile_statement(scope, loop), "\nresult;})");
+ return code;
+ } else {
+ // Accumulator-style reductions like +, ++, *, etc.
+ CORD code = CORD_all(
+ "({ // Reduction:\n",
+ compile_declaration(t, "reduction"), ";\n"
+ "Bool_t has_value = no;\n"
+ );
+ env_t *scope = fresh_scope(env);
+ set_binding(scope, "$reduction", new(binding_t, .type=t, .code="reduction"));
+ ast_t *item = FakeAST(Var, "$iter_value");
+ ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
+ ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
+ env_t *body_scope = for_scope(scope, loop);
+
+ // For the special case of (or)/(and), we need to early out if we can:
+ CORD early_out = CORD_EMPTY;
+ if (op == BINOP_CMP) {
+ if (t->tag != IntType || Match(t, IntType)->bits != TYPE_IBITS32)
+ code_err(ast, "<> reductions are only supported for Int32 values");
+ } else if (op == BINOP_AND) {
+ if (t->tag == BoolType)
+ early_out = "if (!reduction) break;";
+ else if (t->tag == OptionalType)
+ early_out = CORD_all("if (", check_null(t, "reduction"), ") break;");
+ } else if (op == BINOP_OR) {
+ if (t->tag == BoolType)
+ early_out = "if (reduction) break;";
+ else if (t->tag == OptionalType)
+ early_out = CORD_all("if (!", check_null(t, "reduction"), ") break;");
+ }
+
+ body->__data.InlineCCode.code = CORD_all(
+ "if (!has_value) {\n"
+ " reduction = ", compile(body_scope, item), ";\n"
+ " has_value = yes;\n"
+ "} else {\n"
+ " reduction = ", compile(body_scope, reduction->combination), ";\n",
+ early_out,
+ "}\n");
+
+ code = CORD_all(code, compile_statement(scope, loop), "\nhas_value ? ", promote_to_optional(t, "reduction"), " : ", compile_null(t), ";})");
+ return code;
+ }
}
case FieldAccess: {
auto f = Match(ast, FieldAccess);
diff --git a/test/reductions.tm b/test/reductions.tm
index 7bfe212a..a844081f 100644
--- a/test/reductions.tm
+++ b/test/reductions.tm
@@ -27,8 +27,17 @@ func main():
= Foo(x=0, y=-999)
!! (or) and (and) have early out behavior:
- >> (or: i == 3 for i in 9999999999999999999999999999)!
- = yes
+ >> (or: i == 3 for i in 9999999999999999999999999999)!
+ = yes
- >> (and: i < 10 for i in 9999999999999999999999999999)!
- = no
+ >> (and: i < 10 for i in 9999999999999999999999999999)!
+ = no
+
+ >> (<=: [1, 2, 2, 3, 4])!
+ = yes
+
+ >> (<=: [:Int])
+ = !Bool
+
+ >> (<=: [5, 4, 3, 2, 1])!
+ = no
diff --git a/typecheck.c b/typecheck.c
index bc3c7982..bff7d080 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -1033,6 +1033,12 @@ type_t *get_type(env_t *env, ast_t *ast)
auto reduction = Match(ast, Reduction);
type_t *iter_t = get_type(env, reduction->iter);
+ if (reduction->combination && reduction->combination->tag == BinaryOp) {
+ binop_e op = Match(reduction->combination, BinaryOp)->op;
+ if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE)
+ return Type(OptionalType, .type=Type(BoolType));
+ }
+
type_t *value_t;
type_t *iter_value_t = value_type(iter_t);
switch (iter_value_t->tag) {
diff --git a/types.c b/types.c
index bcb0e9f3..ddb3f076 100644
--- a/types.c
+++ b/types.c
@@ -626,4 +626,24 @@ type_t *get_field_type(type_t *t, const char *field_name)
}
}
+PUREFUNC type_t *get_iterated_type(type_t *t)
+{
+ type_t *iter_value_t = value_type(t);
+ switch (iter_value_t->tag) {
+ case BigIntType: case IntType: return iter_value_t; break;
+ case ArrayType: return Match(iter_value_t, ArrayType)->item_type; break;
+ case SetType: return Match(iter_value_t, SetType)->item_type; break;
+ case TableType: return NULL;
+ case FunctionType: case ClosureType: {
+ // Iterator function
+ auto fn = iter_value_t->tag == ClosureType ?
+ Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType);
+ if (fn->args || fn->ret->tag != OptionalType)
+ return NULL;
+ return Match(fn->ret, OptionalType)->type;
+ }
+ default: return NULL;
+ }
+}
+
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0
diff --git a/types.h b/types.h
index 96e0b6c5..c983dc24 100644
--- a/types.h
+++ b/types.h
@@ -153,5 +153,6 @@ PUREFUNC size_t type_size(type_t *t);
PUREFUNC size_t type_align(type_t *t);
PUREFUNC size_t unpadded_struct_size(type_t *t);
type_t *get_field_type(type_t *t, const char *field_name);
+PUREFUNC type_t *get_iterated_type(type_t *t);
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0