diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-03-05 14:46:01 -0500 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-03-05 14:46:01 -0500 |
| commit | 38d5245a9af5bb2aa5baee5df71d8a80fd55dd07 (patch) | |
| tree | 7633f0396819cdd35692c922d3e8f047f1bc3c4f | |
| parent | 558c8588ee2aa772442837be16a7ed19a36cc753 (diff) | |
Fix up some min/max stuff
| -rw-r--r-- | builtins/functions.c | 4 | ||||
| -rw-r--r-- | compile.c | 47 | ||||
| -rw-r--r-- | parse.c | 21 | ||||
| -rw-r--r-- | structs.c | 9 | ||||
| -rw-r--r-- | tomo.h | 19 | ||||
| -rw-r--r-- | typecheck.c | 4 |
6 files changed, 61 insertions, 43 deletions
diff --git a/builtins/functions.c b/builtins/functions.c index 8277fbc2..21f02052 100644 --- a/builtins/functions.c +++ b/builtins/functions.c @@ -89,6 +89,10 @@ public int32_t generic_compare(const void *x, const void *y, const TypeInfo *typ return type->CustomInfo.compare(x, y, type); default: compare_data: + { + int diff = memcmp((void*)x, (void*)y, type->size); + printf("GOT DIFF: %d\n", diff); + } return (int32_t)memcmp((void*)x, (void*)y, type->size); } } @@ -518,11 +518,36 @@ CORD compile(env_t *env, ast_t *ast) } return CORD_cat(code, "\n}"); } - case Min: { - return CORD_asprintf("min(%r, %r)", compile(env, Match(ast, Min)->lhs), compile(env, Match(ast, Min)->rhs)); - } - case Max: { - return CORD_asprintf("max(%r, %r)", compile(env, Match(ast, Max)->lhs), compile(env, Match(ast, Max)->rhs)); + case Min: case Max: { + type_t *t = get_type(env, ast); + ast_t *key = ast->tag == Min ? Match(ast, Min)->key : Match(ast, Max)->key; + ast_t *lhs = ast->tag == Min ? Match(ast, Min)->lhs : Match(ast, Max)->lhs; + ast_t *rhs = ast->tag == Min ? Match(ast, Min)->rhs : Match(ast, Max)->rhs; + const char *key_name = ast->tag == Min ? "_min_" : "_max_"; + if (key == NULL) key = FakeAST(Var, key_name); + + env_t *expr_env = fresh_scope(env); + set_binding(expr_env, key_name, new(binding_t, .type=t, .code="$ternary$lhs")); + CORD lhs_key = compile(expr_env, key); + + set_binding(expr_env, key_name, new(binding_t, .type=t, .code="$ternary$rhs")); + CORD rhs_key = compile(expr_env, key); + + type_t *key_t = get_type(expr_env, key); + CORD comparison; + if (key_t->tag == IntType || key_t->tag == NumType || key_t->tag == BoolType || key_t->tag == PointerType) + comparison = CORD_all("((", lhs_key, ")", (ast->tag == Min ? "<=" : ">="), "(", rhs_key, "))"); + else if (key_t->tag == TextType) + comparison = CORD_all("CORD_cmp(", lhs_key, ", ", rhs_key, ")", (ast->tag == Min ? "<=" : ">="), "0"); + else + comparison = CORD_all("generic_compare($stack(", lhs_key, "), $stack(", rhs_key, "), ", compile_type_info(env, key_t), ")", + (ast->tag == Min ? "<=" : ">="), "0"); + + return CORD_all( + "({\n", + compile_type(t), " $ternary$lhs = ", compile(env, lhs), ", $ternary$rhs = ", compile(env, rhs), ";\n", + comparison, " ? $ternary$lhs : $ternary$rhs;\n" + "})"); } // Min, Max, case Array: { @@ -794,11 +819,11 @@ CORD compile(env_t *env, ast_t *ast) type_t *t = get_type(env, ast); CORD code = CORD_all( "({ // Reduction:\n", - compile_type(t), " $lhs;\n" + compile_type(t), " $reduction;\n" ); env_t *scope = fresh_scope(env); - ast_t *result = FakeAST(Var, "$lhs"); - set_binding(scope, "$lhs", new(binding_t, .type=t)); + ast_t *result = FakeAST(Var, "$reduction"); + set_binding(scope, "$reduction", new(binding_t, .type=t)); ast_t *empty = NULL; if (reduction->fallback) { type_t *fallback_type = get_type(scope, reduction->fallback); @@ -815,14 +840,14 @@ CORD compile(env_t *env, ast_t *ast) (long)(reduction->iter->end - reduction->iter->file->text))); } ast_t *i = FakeAST(Var, "$i"); - ast_t *item = FakeAST(Var, "$rhs"); + ast_t *item = FakeAST(Var, "$iter_value"); ast_t *body = FakeAST( If, .condition=FakeAST(BinaryOp, .lhs=i, .op=BINOP_EQ, .rhs=FakeAST(Int, .i=1, .bits=64)), .body=FakeAST(Assign, .targets=new(ast_list_t, .ast=result), .values=new(ast_list_t, .ast=item)), .else_body=FakeAST(Assign, .targets=new(ast_list_t, .ast=result), .values=new(ast_list_t, .ast=reduction->combination))); ast_t *loop = FakeAST(For, .index=i, .value=item, .iter=reduction->iter, .body=body, .empty=empty); - set_binding(scope, "$rhs", new(binding_t, .type=t)); - code = CORD_all(code, compile(scope, loop), "\n$lhs;})"); + set_binding(scope, "$iter_value", new(binding_t, .type=t)); + code = CORD_all(code, compile(scope, loop), "\n$reduction;})"); return code; } case Skip: { @@ -734,23 +734,24 @@ PARSER(parse_reduction) { if (op == BINOP_UNKNOWN) return NULL; ast_t *combination; - ast_t *lhs = NewAST(ctx->file, pos, pos, Var, .name="$lhs"); - ast_t *rhs = NewAST(ctx->file, pos, pos, Var, .name="$rhs"); + ast_t *lhs = NewAST(ctx->file, pos, pos, Var, .name="$reduction"); + ast_t *rhs = NewAST(ctx->file, pos, pos, Var, .name="$iter_value"); if (op == BINOP_MIN || op == BINOP_MAX) { + ast_t *key = NewAST(ctx->file, pos, pos, Var, .name=(op == BINOP_MIN ? "_min_" : "_max_")); for (bool progress = true; progress; ) { ast_t *new_term; progress = (false - || (new_term=parse_index_suffix(ctx, rhs)) - || (new_term=parse_field_suffix(ctx, rhs)) - || (new_term=parse_fncall_suffix(ctx, rhs, NORMAL_FUNCTION)) + || (new_term=parse_index_suffix(ctx, key)) + || (new_term=parse_field_suffix(ctx, key)) + || (new_term=parse_fncall_suffix(ctx, key, NORMAL_FUNCTION)) ); - if (progress) rhs = new_term; + if (progress) key = new_term; } - if (rhs->tag == Var) rhs = NULL; - else pos = rhs->end; + if (key->tag == Var) key = NULL; + else pos = key->end; combination = op == BINOP_MIN ? - NewAST(ctx->file, combo_start, pos, Min, .lhs=lhs, .rhs=lhs, .key=rhs) - : NewAST(ctx->file, combo_start, pos, Max, .lhs=lhs, .rhs=lhs, .key=rhs); + NewAST(ctx->file, combo_start, pos, Min, .lhs=lhs, .rhs=rhs, .key=key) + : NewAST(ctx->file, combo_start, pos, Max, .lhs=lhs, .rhs=rhs, .key=key); } else { combination = NewAST(ctx->file, combo_start, pos, BinaryOp, .op=op, .lhs=lhs, .rhs=rhs); } @@ -147,17 +147,16 @@ void compile_struct_def(env_t *env, ast_t *ast) CORD typeinfo = CORD_asprintf("public const TypeInfo %s = {%zu, %zu, {.tag=CustomInfo, .CustomInfo={", def->name, type_size(t), type_align(t)); - typeinfo = CORD_all(typeinfo, ".as_text=(void*)", def->name, "$as_text, "); - env->code->funcs = CORD_all(env->code->funcs, compile_str_method(env, ast)); + typeinfo = CORD_all(typeinfo, ".as_text=(void*)", def->name, "$as_text, .compare=(void*)", def->name, "$compare, "); + env->code->funcs = CORD_all(env->code->funcs, compile_str_method(env, ast), compile_compare_method(env, ast)); if (!t || !is_plain_data(env, t)) { env->code->funcs = CORD_all( - env->code->funcs, compile_equals_method(env, ast), compile_compare_method(env, ast), + env->code->funcs, compile_equals_method(env, ast), compile_hash_method(env, ast)); typeinfo = CORD_all( typeinfo, ".equal=(void*)", def->name, "$equal, " - ".hash=(void*)", def->name, "$hash, " - ".compare=(void*)", def->name, "$compare"); + ".hash=(void*)", def->name, "$hash"); } typeinfo = CORD_cat(typeinfo, "}}};\n"); env->code->typeinfos = CORD_all(env->code->typeinfos, typeinfo); @@ -57,20 +57,9 @@ CORD as_cord(void *x, bool use_color, const char *fmt, ...); #define xor(x, y) _Generic(x, bool: (bool)((x) ^ (y)), default: ((x) ^ (y))) #define mod(x, n) ((x) % (n)) #define mod1(x, n) (((x) % (n)) + (__typeof(x))1) -#define $cmp(x, y) (_Generic(x, CORD: CORD_cmp(x, y), char*: strcmp(x, y), const char*: strcmp(x, y), default: (x > 0) - (y > 0))) -#define $lt(x, y) (bool)(_Generic(x, int8_t: x < y, int16_t: x < y, int32_t: x < y, int64_t: x < y, float: x < y, double: x < y, bool: x < y, \ - default: $cmp(x, y) < 0)) -#define $le(x, y) (bool)(_Generic(x, int8_t: x <= y, int16_t: x <= y, int32_t: x <= y, int64_t: x <= y, float: x <= y, double: x <= y, bool: x <= y, \ - default: $cmp(x, y) <= 0)) -#define $ge(x, y) (bool)(_Generic(x, int8_t: x >= y, int16_t: x >= y, int32_t: x >= y, int64_t: x >= y, float: x >= y, double: x >= y, bool: x >= y, \ - default: $cmp(x, y) >= 0)) -#define $gt(x, y) (bool)(_Generic(x, int8_t: x > y, int16_t: x > y, int32_t: x > y, int64_t: x > y, float: x > y, double: x > y, bool: x > y, \ - default: $cmp(x, y) > 0)) -#define $eq(x, y) (bool)(_Generic(x, int8_t: x == y, int16_t: x == y, int32_t: x == y, int64_t: x == y, float: x == y, double: x == y, bool: x == y, \ - default: $cmp(x, y) == 0)) -#define $ne(x, y) (bool)(_Generic(x, int8_t: x != y, int16_t: x != y, int32_t: x != y, int64_t: x != y, float: x != y, double: x != y, bool: x != y, \ - default: $cmp(x, y) != 0)) -#define min(x, y) ({ $var($min_lhs, x); $var($min_rhs, y); $le($min_lhs, $min_rhs) ? $min_lhs : $min_rhs; }) -#define max(x, y) ({ $var($min_lhs, x); $var($min_rhs, y); $ge($min_lhs, $min_rhs) ? $min_lhs : $min_rhs; }) +#define $cmp(x, y, info) (_Generic(x, int8_t: (x>0)-(y>0), int16_t: (x>0)-(y>0), int32_t: (x>0)-(y>0), int64_t: (x>0)-(y>0), bool: (x>0)-(y>0), \ + CORD: CORD_cmp((CORD)x, (CORD)y), char*: strcmp((char*)x, (char*)y), default: generic_compare($stack(x), $stack(y), info))) +#define min(c_type, x, y, info) ({ c_type $lhs = x, $rhs = y; generic_compare(&$lhs, &$rhs, info) <= 0 ? $lhs : $rhs; }) +#define max(c_type, x, y, info) ({ c_type $lhs = x, $rhs = y; generic_compare(&$lhs, &$rhs, info) >= 0 ? $lhs : $rhs; }) // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/typecheck.c b/typecheck.c index 165e74dd..6528b4b3 100644 --- a/typecheck.c +++ b/typecheck.c @@ -549,8 +549,8 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *iter_t = get_type(env, reduction->iter); type_t *value_t = iteration_value_type(iter_t); env_t *scope = fresh_scope(env); - set_binding(scope, "$lhs", new(binding_t, .type=value_t)); - set_binding(scope, "$rhs", new(binding_t, .type=value_t)); + set_binding(scope, "$reduction", new(binding_t, .type=value_t, .code="$reduction")); + set_binding(scope, "$iter_value", new(binding_t, .type=value_t, .code="$iter_value")); type_t *t = get_type(scope, reduction->combination); if (!reduction->fallback) return t; |
