Support reductions for comparison operators like == and <

This commit is contained in:
Bruce Hill 2024-11-02 22:34:35 -04:00
parent 0b7a0dd043
commit 92a593b80f
5 changed files with 114 additions and 45 deletions

115
compile.c
View File

@ -3153,49 +3153,82 @@ CORD compile(env_t *env, ast_t *ast)
auto reduction = Match(ast, Reduction);
type_t *optional_t = get_type(env, ast);
type_t *t = Match(optional_t, OptionalType)->type;
CORD code = CORD_all(
"({ // Reduction:\n",
compile_declaration(t, "reduction"), ";\n"
"Bool_t has_value = no;\n"
);
env_t *scope = fresh_scope(env);
set_binding(scope, "$reduction", new(binding_t, .type=t, .code="reduction"));
ast_t *item = FakeAST(Var, "$iter_value");
ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
env_t *body_scope = for_scope(scope, loop);
binop_e op = reduction->combination->tag == BinaryOp ? Match(reduction->combination, BinaryOp)->op : BINOP_UNKNOWN;
// For the special case of (or)/(and), we need to early out if we can:
CORD early_out = CORD_EMPTY;
if (reduction->combination->tag == BinaryOp) {
auto binop = Match(reduction->combination, BinaryOp);
if (t->tag != BoolType && (binop->op == BINOP_EQ || binop->op == BINOP_NE
|| binop->op == BINOP_LT || binop->op == BINOP_LE
|| binop->op == BINOP_GT || binop->op == BINOP_GE))
code_err(ast, "Reductions are not supported for this type of infix operator");
else if ((t->tag != IntType || Match(t, IntType)->bits != TYPE_IBITS32) && binop->op == BINOP_CMP)
code_err(ast, "<> reductions are only supported for Int32 values");
else if (t->tag == BoolType && binop->op == BINOP_AND)
early_out = "if (!reduction) break;";
else if (t->tag == BoolType && binop->op == BINOP_OR)
early_out = "if (reduction) break;";
else if (t->tag == OptionalType && binop->op == BINOP_AND)
early_out = CORD_all("if (", check_null(t, "reduction"), ") break;");
else if (t->tag == OptionalType && binop->op == BINOP_OR)
early_out = CORD_all("if (!", check_null(t, "reduction"), ") break;");
if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) {
// Chained comparisons like ==, <, etc.
type_t *iter_t = get_type(env, reduction->iter);
type_t *item_t = get_iterated_type(iter_t);
if (!item_t) code_err(reduction->iter, "I couldn't figure out how to iterate over this type: %T", iter_t);
CORD code = CORD_all(
"({ // Reduction:\n",
compile_declaration(item_t, "prev"), ";\n"
"OptionalBool_t result = NULL_BOOL;\n"
);
env_t *scope = fresh_scope(env);
set_binding(scope, "$reduction", new(binding_t, .type=item_t, .code="prev"));
ast_t *item = FakeAST(Var, "$iter_value");
ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
env_t *body_scope = for_scope(scope, loop);
body->__data.InlineCCode.code = CORD_all(
"if (result == NULL_BOOL) {\n"
" prev = ", compile(body_scope, item), ";\n"
" result = yes;\n"
"} else {\n"
" if (", compile(body_scope, reduction->combination), ") {\n",
" prev = ", compile(body_scope, item), ";\n",
" } else {\n"
" result = no;\n",
" break;\n",
" }\n",
"}\n");
code = CORD_all(code, compile_statement(scope, loop), "\nresult;})");
return code;
} else {
// Accumulator-style reductions like +, ++, *, etc.
CORD code = CORD_all(
"({ // Reduction:\n",
compile_declaration(t, "reduction"), ";\n"
"Bool_t has_value = no;\n"
);
env_t *scope = fresh_scope(env);
set_binding(scope, "$reduction", new(binding_t, .type=t, .code="reduction"));
ast_t *item = FakeAST(Var, "$iter_value");
ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body);
env_t *body_scope = for_scope(scope, loop);
// For the special case of (or)/(and), we need to early out if we can:
CORD early_out = CORD_EMPTY;
if (op == BINOP_CMP) {
if (t->tag != IntType || Match(t, IntType)->bits != TYPE_IBITS32)
code_err(ast, "<> reductions are only supported for Int32 values");
} else if (op == BINOP_AND) {
if (t->tag == BoolType)
early_out = "if (!reduction) break;";
else if (t->tag == OptionalType)
early_out = CORD_all("if (", check_null(t, "reduction"), ") break;");
} else if (op == BINOP_OR) {
if (t->tag == BoolType)
early_out = "if (reduction) break;";
else if (t->tag == OptionalType)
early_out = CORD_all("if (!", check_null(t, "reduction"), ") break;");
}
body->__data.InlineCCode.code = CORD_all(
"if (!has_value) {\n"
" reduction = ", compile(body_scope, item), ";\n"
" has_value = yes;\n"
"} else {\n"
" reduction = ", compile(body_scope, reduction->combination), ";\n",
early_out,
"}\n");
code = CORD_all(code, compile_statement(scope, loop), "\nhas_value ? ", promote_to_optional(t, "reduction"), " : ", compile_null(t), ";})");
return code;
}
body->__data.InlineCCode.code = CORD_all(
"if (!has_value) {\n"
" reduction = ", compile(body_scope, item), ";\n"
" has_value = yes;\n"
"} else {\n"
" reduction = ", compile(body_scope, reduction->combination), ";\n",
early_out,
"}\n");
code = CORD_all(code, compile_statement(scope, loop), "\nhas_value ? ", promote_to_optional(t, "reduction"), " : ", compile_null(t), ";})");
return code;
}
case FieldAccess: {
auto f = Match(ast, FieldAccess);

View File

@ -27,8 +27,17 @@ func main():
= Foo(x=0, y=-999)
!! (or) and (and) have early out behavior:
>> (or: i == 3 for i in 9999999999999999999999999999)!
= yes
>> (or: i == 3 for i in 9999999999999999999999999999)!
= yes
>> (and: i < 10 for i in 9999999999999999999999999999)!
= no
>> (and: i < 10 for i in 9999999999999999999999999999)!
= no
>> (<=: [1, 2, 2, 3, 4])!
= yes
>> (<=: [:Int])
= !Bool
>> (<=: [5, 4, 3, 2, 1])!
= no

View File

@ -1033,6 +1033,12 @@ type_t *get_type(env_t *env, ast_t *ast)
auto reduction = Match(ast, Reduction);
type_t *iter_t = get_type(env, reduction->iter);
if (reduction->combination && reduction->combination->tag == BinaryOp) {
binop_e op = Match(reduction->combination, BinaryOp)->op;
if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE)
return Type(OptionalType, .type=Type(BoolType));
}
type_t *value_t;
type_t *iter_value_t = value_type(iter_t);
switch (iter_value_t->tag) {

20
types.c
View File

@ -626,4 +626,24 @@ type_t *get_field_type(type_t *t, const char *field_name)
}
}
PUREFUNC type_t *get_iterated_type(type_t *t)
{
type_t *iter_value_t = value_type(t);
switch (iter_value_t->tag) {
case BigIntType: case IntType: return iter_value_t; break;
case ArrayType: return Match(iter_value_t, ArrayType)->item_type; break;
case SetType: return Match(iter_value_t, SetType)->item_type; break;
case TableType: return NULL;
case FunctionType: case ClosureType: {
// Iterator function
auto fn = iter_value_t->tag == ClosureType ?
Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType);
if (fn->args || fn->ret->tag != OptionalType)
return NULL;
return Match(fn->ret, OptionalType)->type;
}
default: return NULL;
}
}
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0

View File

@ -153,5 +153,6 @@ PUREFUNC size_t type_size(type_t *t);
PUREFUNC size_t type_align(type_t *t);
PUREFUNC size_t unpadded_struct_size(type_t *t);
type_t *get_field_type(type_t *t, const char *field_name);
PUREFUNC type_t *get_iterated_type(type_t *t);
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0