aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--builtins/functions.c4
-rw-r--r--compile.c47
-rw-r--r--parse.c21
-rw-r--r--structs.c9
-rw-r--r--tomo.h19
-rw-r--r--typecheck.c4
6 files changed, 61 insertions, 43 deletions
diff --git a/builtins/functions.c b/builtins/functions.c
index 8277fbc2..21f02052 100644
--- a/builtins/functions.c
+++ b/builtins/functions.c
@@ -89,6 +89,10 @@ public int32_t generic_compare(const void *x, const void *y, const TypeInfo *typ
return type->CustomInfo.compare(x, y, type);
default:
compare_data:
+ {
+ int diff = memcmp((void*)x, (void*)y, type->size);
+ printf("GOT DIFF: %d\n", diff);
+ }
return (int32_t)memcmp((void*)x, (void*)y, type->size);
}
}
diff --git a/compile.c b/compile.c
index 67aa8d71..b79b77c4 100644
--- a/compile.c
+++ b/compile.c
@@ -518,11 +518,36 @@ CORD compile(env_t *env, ast_t *ast)
}
return CORD_cat(code, "\n}");
}
- case Min: {
- return CORD_asprintf("min(%r, %r)", compile(env, Match(ast, Min)->lhs), compile(env, Match(ast, Min)->rhs));
- }
- case Max: {
- return CORD_asprintf("max(%r, %r)", compile(env, Match(ast, Max)->lhs), compile(env, Match(ast, Max)->rhs));
+ case Min: case Max: {
+ type_t *t = get_type(env, ast);
+ ast_t *key = ast->tag == Min ? Match(ast, Min)->key : Match(ast, Max)->key;
+ ast_t *lhs = ast->tag == Min ? Match(ast, Min)->lhs : Match(ast, Max)->lhs;
+ ast_t *rhs = ast->tag == Min ? Match(ast, Min)->rhs : Match(ast, Max)->rhs;
+ const char *key_name = ast->tag == Min ? "_min_" : "_max_";
+ if (key == NULL) key = FakeAST(Var, key_name);
+
+ env_t *expr_env = fresh_scope(env);
+ set_binding(expr_env, key_name, new(binding_t, .type=t, .code="$ternary$lhs"));
+ CORD lhs_key = compile(expr_env, key);
+
+ set_binding(expr_env, key_name, new(binding_t, .type=t, .code="$ternary$rhs"));
+ CORD rhs_key = compile(expr_env, key);
+
+ type_t *key_t = get_type(expr_env, key);
+ CORD comparison;
+ if (key_t->tag == IntType || key_t->tag == NumType || key_t->tag == BoolType || key_t->tag == PointerType)
+ comparison = CORD_all("((", lhs_key, ")", (ast->tag == Min ? "<=" : ">="), "(", rhs_key, "))");
+ else if (key_t->tag == TextType)
+ comparison = CORD_all("CORD_cmp(", lhs_key, ", ", rhs_key, ")", (ast->tag == Min ? "<=" : ">="), "0");
+ else
+ comparison = CORD_all("generic_compare($stack(", lhs_key, "), $stack(", rhs_key, "), ", compile_type_info(env, key_t), ")",
+ (ast->tag == Min ? "<=" : ">="), "0");
+
+ return CORD_all(
+ "({\n",
+ compile_type(t), " $ternary$lhs = ", compile(env, lhs), ", $ternary$rhs = ", compile(env, rhs), ";\n",
+ comparison, " ? $ternary$lhs : $ternary$rhs;\n"
+ "})");
}
// Min, Max,
case Array: {
@@ -794,11 +819,11 @@ CORD compile(env_t *env, ast_t *ast)
type_t *t = get_type(env, ast);
CORD code = CORD_all(
"({ // Reduction:\n",
- compile_type(t), " $lhs;\n"
+ compile_type(t), " $reduction;\n"
);
env_t *scope = fresh_scope(env);
- ast_t *result = FakeAST(Var, "$lhs");
- set_binding(scope, "$lhs", new(binding_t, .type=t));
+ ast_t *result = FakeAST(Var, "$reduction");
+ set_binding(scope, "$reduction", new(binding_t, .type=t));
ast_t *empty = NULL;
if (reduction->fallback) {
type_t *fallback_type = get_type(scope, reduction->fallback);
@@ -815,14 +840,14 @@ CORD compile(env_t *env, ast_t *ast)
(long)(reduction->iter->end - reduction->iter->file->text)));
}
ast_t *i = FakeAST(Var, "$i");
- ast_t *item = FakeAST(Var, "$rhs");
+ ast_t *item = FakeAST(Var, "$iter_value");
ast_t *body = FakeAST(
If, .condition=FakeAST(BinaryOp, .lhs=i, .op=BINOP_EQ, .rhs=FakeAST(Int, .i=1, .bits=64)),
.body=FakeAST(Assign, .targets=new(ast_list_t, .ast=result), .values=new(ast_list_t, .ast=item)),
.else_body=FakeAST(Assign, .targets=new(ast_list_t, .ast=result), .values=new(ast_list_t, .ast=reduction->combination)));
ast_t *loop = FakeAST(For, .index=i, .value=item, .iter=reduction->iter, .body=body, .empty=empty);
- set_binding(scope, "$rhs", new(binding_t, .type=t));
- code = CORD_all(code, compile(scope, loop), "\n$lhs;})");
+ set_binding(scope, "$iter_value", new(binding_t, .type=t));
+ code = CORD_all(code, compile(scope, loop), "\n$reduction;})");
return code;
}
case Skip: {
diff --git a/parse.c b/parse.c
index 5d1bcb20..022c12cf 100644
--- a/parse.c
+++ b/parse.c
@@ -734,23 +734,24 @@ PARSER(parse_reduction) {
if (op == BINOP_UNKNOWN) return NULL;
ast_t *combination;
- ast_t *lhs = NewAST(ctx->file, pos, pos, Var, .name="$lhs");
- ast_t *rhs = NewAST(ctx->file, pos, pos, Var, .name="$rhs");
+ ast_t *lhs = NewAST(ctx->file, pos, pos, Var, .name="$reduction");
+ ast_t *rhs = NewAST(ctx->file, pos, pos, Var, .name="$iter_value");
if (op == BINOP_MIN || op == BINOP_MAX) {
+ ast_t *key = NewAST(ctx->file, pos, pos, Var, .name=(op == BINOP_MIN ? "_min_" : "_max_"));
for (bool progress = true; progress; ) {
ast_t *new_term;
progress = (false
- || (new_term=parse_index_suffix(ctx, rhs))
- || (new_term=parse_field_suffix(ctx, rhs))
- || (new_term=parse_fncall_suffix(ctx, rhs, NORMAL_FUNCTION))
+ || (new_term=parse_index_suffix(ctx, key))
+ || (new_term=parse_field_suffix(ctx, key))
+ || (new_term=parse_fncall_suffix(ctx, key, NORMAL_FUNCTION))
);
- if (progress) rhs = new_term;
+ if (progress) key = new_term;
}
- if (rhs->tag == Var) rhs = NULL;
- else pos = rhs->end;
+ if (key->tag == Var) key = NULL;
+ else pos = key->end;
combination = op == BINOP_MIN ?
- NewAST(ctx->file, combo_start, pos, Min, .lhs=lhs, .rhs=lhs, .key=rhs)
- : NewAST(ctx->file, combo_start, pos, Max, .lhs=lhs, .rhs=lhs, .key=rhs);
+ NewAST(ctx->file, combo_start, pos, Min, .lhs=lhs, .rhs=rhs, .key=key)
+ : NewAST(ctx->file, combo_start, pos, Max, .lhs=lhs, .rhs=rhs, .key=key);
} else {
combination = NewAST(ctx->file, combo_start, pos, BinaryOp, .op=op, .lhs=lhs, .rhs=rhs);
}
diff --git a/structs.c b/structs.c
index 0ec08e5c..42bf291b 100644
--- a/structs.c
+++ b/structs.c
@@ -147,17 +147,16 @@ void compile_struct_def(env_t *env, ast_t *ast)
CORD typeinfo = CORD_asprintf("public const TypeInfo %s = {%zu, %zu, {.tag=CustomInfo, .CustomInfo={",
def->name, type_size(t), type_align(t));
- typeinfo = CORD_all(typeinfo, ".as_text=(void*)", def->name, "$as_text, ");
- env->code->funcs = CORD_all(env->code->funcs, compile_str_method(env, ast));
+ typeinfo = CORD_all(typeinfo, ".as_text=(void*)", def->name, "$as_text, .compare=(void*)", def->name, "$compare, ");
+ env->code->funcs = CORD_all(env->code->funcs, compile_str_method(env, ast), compile_compare_method(env, ast));
if (!t || !is_plain_data(env, t)) {
env->code->funcs = CORD_all(
- env->code->funcs, compile_equals_method(env, ast), compile_compare_method(env, ast),
+ env->code->funcs, compile_equals_method(env, ast),
compile_hash_method(env, ast));
typeinfo = CORD_all(
typeinfo,
".equal=(void*)", def->name, "$equal, "
- ".hash=(void*)", def->name, "$hash, "
- ".compare=(void*)", def->name, "$compare");
+ ".hash=(void*)", def->name, "$hash");
}
typeinfo = CORD_cat(typeinfo, "}}};\n");
env->code->typeinfos = CORD_all(env->code->typeinfos, typeinfo);
diff --git a/tomo.h b/tomo.h
index 2a0ba1f8..9337d744 100644
--- a/tomo.h
+++ b/tomo.h
@@ -57,20 +57,9 @@ CORD as_cord(void *x, bool use_color, const char *fmt, ...);
#define xor(x, y) _Generic(x, bool: (bool)((x) ^ (y)), default: ((x) ^ (y)))
#define mod(x, n) ((x) % (n))
#define mod1(x, n) (((x) % (n)) + (__typeof(x))1)
-#define $cmp(x, y) (_Generic(x, CORD: CORD_cmp(x, y), char*: strcmp(x, y), const char*: strcmp(x, y), default: (x > 0) - (y > 0)))
-#define $lt(x, y) (bool)(_Generic(x, int8_t: x < y, int16_t: x < y, int32_t: x < y, int64_t: x < y, float: x < y, double: x < y, bool: x < y, \
- default: $cmp(x, y) < 0))
-#define $le(x, y) (bool)(_Generic(x, int8_t: x <= y, int16_t: x <= y, int32_t: x <= y, int64_t: x <= y, float: x <= y, double: x <= y, bool: x <= y, \
- default: $cmp(x, y) <= 0))
-#define $ge(x, y) (bool)(_Generic(x, int8_t: x >= y, int16_t: x >= y, int32_t: x >= y, int64_t: x >= y, float: x >= y, double: x >= y, bool: x >= y, \
- default: $cmp(x, y) >= 0))
-#define $gt(x, y) (bool)(_Generic(x, int8_t: x > y, int16_t: x > y, int32_t: x > y, int64_t: x > y, float: x > y, double: x > y, bool: x > y, \
- default: $cmp(x, y) > 0))
-#define $eq(x, y) (bool)(_Generic(x, int8_t: x == y, int16_t: x == y, int32_t: x == y, int64_t: x == y, float: x == y, double: x == y, bool: x == y, \
- default: $cmp(x, y) == 0))
-#define $ne(x, y) (bool)(_Generic(x, int8_t: x != y, int16_t: x != y, int32_t: x != y, int64_t: x != y, float: x != y, double: x != y, bool: x != y, \
- default: $cmp(x, y) != 0))
-#define min(x, y) ({ $var($min_lhs, x); $var($min_rhs, y); $le($min_lhs, $min_rhs) ? $min_lhs : $min_rhs; })
-#define max(x, y) ({ $var($min_lhs, x); $var($min_rhs, y); $ge($min_lhs, $min_rhs) ? $min_lhs : $min_rhs; })
+#define $cmp(x, y, info) (_Generic(x, int8_t: (x>0)-(y>0), int16_t: (x>0)-(y>0), int32_t: (x>0)-(y>0), int64_t: (x>0)-(y>0), bool: (x>0)-(y>0), \
+ CORD: CORD_cmp((CORD)x, (CORD)y), char*: strcmp((char*)x, (char*)y), default: generic_compare($stack(x), $stack(y), info)))
+#define min(c_type, x, y, info) ({ c_type $lhs = x, $rhs = y; generic_compare(&$lhs, &$rhs, info) <= 0 ? $lhs : $rhs; })
+#define max(c_type, x, y, info) ({ c_type $lhs = x, $rhs = y; generic_compare(&$lhs, &$rhs, info) >= 0 ? $lhs : $rhs; })
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0
diff --git a/typecheck.c b/typecheck.c
index 165e74dd..6528b4b3 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -549,8 +549,8 @@ type_t *get_type(env_t *env, ast_t *ast)
type_t *iter_t = get_type(env, reduction->iter);
type_t *value_t = iteration_value_type(iter_t);
env_t *scope = fresh_scope(env);
- set_binding(scope, "$lhs", new(binding_t, .type=value_t));
- set_binding(scope, "$rhs", new(binding_t, .type=value_t));
+ set_binding(scope, "$reduction", new(binding_t, .type=value_t, .code="$reduction"));
+ set_binding(scope, "$iter_value", new(binding_t, .type=value_t, .code="$iter_value"));
type_t *t = get_type(scope, reduction->combination);
if (!reduction->fallback)
return t;