aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-11-04 15:24:10 -0500
committerBruce Hill <bruce@bruce-hill.com>2024-11-04 15:24:10 -0500
commit2fa26e6af3ec1599396d9260ef44b0d035b1f686 (patch)
treef35124d2259af4183b222ef0f84a89f9286e9ea2
parentda5bd87c135749b11c866aaf341c6c2c7c2ab9b2 (diff)
Be much more permissive about using integer literals for fixed-size ints
or nums or bytes
-rw-r--r--compile.c87
-rw-r--r--examples/base64/base64.tm83
-rw-r--r--typecheck.c7
-rw-r--r--types.c2
4 files changed, 96 insertions, 83 deletions
diff --git a/compile.c b/compile.c
index 072b4dca..5f274335 100644
--- a/compile.c
+++ b/compile.c
@@ -1590,6 +1590,8 @@ CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_depth, bool
static CORD compile_to_type(env_t *env, ast_t *ast, type_t *t)
{
+ if (ast->tag == Int && is_numeric_type(t))
+ return compile_int_to_type(env, ast, t);
CORD code = compile(env, ast);
type_t *actual = get_type(env, ast);
if (!promote(env, &code, actual, t))
@@ -1636,29 +1638,37 @@ CORD compile_int_to_type(env_t *env, ast_t *ast, type_t *target)
if (mpz_cmp_si(i, UINT8_MAX) <= 0 && mpz_cmp_si(i, 0) >= 0)
return CORD_asprintf("(Byte_t)(%s)", Match(ast, Int)->str);
code_err(ast, "This integer cannot fit in a byte");
+ } else if (target->tag == NumType) {
+ if (Match(target, NumType)->bits == TYPE_NBITS64) {
+ return CORD_asprintf("N64(%s)", Match(ast, Int)->str);
+ } else {
+ return CORD_asprintf("N32(%s)", Match(ast, Int)->str);
+ }
+ } else if (target->tag == IntType) {
+ int64_t target_bits = (int64_t)Match(target, IntType)->bits;
+ switch (target_bits) {
+ case TYPE_IBITS64:
+ if (mpz_cmp_si(i, INT64_MAX) <= 0 && mpz_cmp_si(i, INT64_MIN) >= 0)
+ return CORD_asprintf("I64(%s)", Match(ast, Int)->str);
+ break;
+ case TYPE_IBITS32:
+ if (mpz_cmp_si(i, INT32_MAX) <= 0 && mpz_cmp_si(i, INT32_MIN) >= 0)
+ return CORD_asprintf("I32(%s)", Match(ast, Int)->str);
+ break;
+ case TYPE_IBITS16:
+ if (mpz_cmp_si(i, INT16_MAX) <= 0 && mpz_cmp_si(i, INT16_MIN) >= 0)
+ return CORD_asprintf("I16(%s)", Match(ast, Int)->str);
+ break;
+ case TYPE_IBITS8:
+ if (mpz_cmp_si(i, INT8_MAX) <= 0 && mpz_cmp_si(i, INT8_MIN) >= 0)
+ return CORD_asprintf("I8(%s)", Match(ast, Int)->str);
+ break;
+ default: break;
+ }
+ code_err(ast, "This integer cannot fit in a %d-bit value", target_bits);
+ } else {
+ code_err(ast, "I don't know how to compile this to a %T", target);
}
-
- int64_t target_bits = (int64_t)Match(target, IntType)->bits;
- switch (target_bits) {
- case TYPE_IBITS64:
- if (mpz_cmp_si(i, INT64_MAX) <= 0 && mpz_cmp_si(i, INT64_MIN) >= 0)
- return CORD_asprintf("I64(%s)", Match(ast, Int)->str);
- break;
- case TYPE_IBITS32:
- if (mpz_cmp_si(i, INT32_MAX) <= 0 && mpz_cmp_si(i, INT32_MIN) >= 0)
- return CORD_asprintf("I32(%s)", Match(ast, Int)->str);
- break;
- case TYPE_IBITS16:
- if (mpz_cmp_si(i, INT16_MAX) <= 0 && mpz_cmp_si(i, INT16_MIN) >= 0)
- return CORD_asprintf("I16(%s)", Match(ast, Int)->str);
- break;
- case TYPE_IBITS8:
- if (mpz_cmp_si(i, INT8_MAX) <= 0 && mpz_cmp_si(i, INT8_MIN) >= 0)
- return CORD_asprintf("I8(%s)", Match(ast, Int)->str);
- break;
- default: break;
- }
- code_err(ast, "This integer cannot fit in a %d-bit value", target_bits);
}
CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args)
@@ -2079,26 +2089,21 @@ CORD compile(env_t *env, ast_t *ast)
}
}
- CORD lhs = compile(env, binop->lhs);
-
- // Special case for bit shifting by an integer literal:
- if (binop->op == BINOP_LSHIFT || binop->op == BINOP_RSHIFT || binop->op == BINOP_ULSHIFT || binop->op == BINOP_URSHIFT) {
- if ((lhs_t->tag == IntType || lhs_t->tag == ByteType) && rhs_t->tag == BigIntType && binop->rhs->tag == Int) {
- CORD shift_amount = compile_int_to_type(env, binop->rhs, lhs_t);
- if (binop->op == BINOP_LSHIFT)
- return CORD_all("(", lhs, " << ", shift_amount, ")");
- else if (binop->op == BINOP_ULSHIFT)
- return CORD_all("(", compile_type(lhs_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " << ", shift_amount, ")");
- else if (binop->op == BINOP_RSHIFT)
- return CORD_all("(", lhs, " >> ", shift_amount, ")");
- else if (binop->op == BINOP_URSHIFT)
- return CORD_all("(", compile_type(lhs_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " >> ", shift_amount, ")");
- }
+ CORD lhs, rhs;
+ if (lhs_t->tag == BigIntType && rhs_t->tag != BigIntType && is_numeric_type(rhs_t) && binop->lhs->tag == Int) {
+ lhs = compile_int_to_type(env, binop->lhs, rhs_t);
+ lhs_t = rhs_t;
+ rhs = compile(env, binop->rhs);
+ } else if (rhs_t->tag == BigIntType && lhs_t->tag != BigIntType && is_numeric_type(lhs_t) && binop->rhs->tag == Int) {
+ lhs = compile(env, binop->lhs);
+ rhs = compile_int_to_type(env, binop->rhs, lhs_t);
+ rhs_t = lhs_t;
+ } else {
+ lhs = compile(env, binop->lhs);
+ rhs = compile(env, binop->rhs);
}
- CORD rhs = compile(env, binop->rhs);
type_t *operand_t;
-
if (promote(env, &rhs, rhs_t, lhs_t))
operand_t = lhs_t;
else if (promote(env, &lhs, lhs_t, rhs_t))
@@ -3032,7 +3037,7 @@ CORD compile(env_t *env, ast_t *ast)
type_t *actual = get_type(env, call->args->value);
arg_t *args = new(arg_t, .name="i", .type=actual); // No truncation argument
CORD arg_code = compile_arguments(env, ast, args, call->args);
- if (is_numeric_type(actual) || actual->tag == ByteType) {
+ if (is_numeric_type(actual)) {
return CORD_all(type_to_cord(actual), "_to_", type_to_cord(t), "(", arg_code, ")");
} else if (actual->tag == BoolType) {
if (t->tag == NumType) {
@@ -3045,7 +3050,7 @@ CORD compile(env_t *env, ast_t *ast)
}
} else if (t->tag == IntType || t->tag == ByteType) {
type_t *actual = get_type(env, call->args->value);
- if (is_numeric_type(actual) || actual->tag == ByteType) {
+ if (is_numeric_type(actual)) {
arg_t *args = new(arg_t, .name="i", .type=actual, .next=new(arg_t, .name="truncate", .type=Type(BoolType),
.default_val=FakeAST(Bool, false)));
CORD arg_code = compile_arguments(env, ast, args, call->args);
diff --git a/examples/base64/base64.tm b/examples/base64/base64.tm
index e765d86a..f22883cd 100644
--- a/examples/base64/base64.tm
+++ b/examples/base64/base64.tm
@@ -5,22 +5,23 @@ _enc := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/":utf8_
_EQUAL_BYTE := 0x3D[B]
_dec := [
- 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B],
- 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B],
- 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B],
- 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B],
- 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B],
- 255[B], 255[B], 255[B], 62[B], 255[B], 255[B], 255[B], 63[B],
- 52[B], 53[B], 54[B], 55[B], 56[B], 57[B], 58[B], 59[B],
- 60[B], 61[B], 255[B], 255[B], 255[B], 255[B], 255[B], 255[B],
- 255[B], 0[B], 1[B], 2[B], 3[B], 4[B], 5[B], 6[B],
- 7[B], 8[B], 9[B], 10[B], 11[B], 12[B], 13[B], 14[B],
- 15[B], 16[B], 17[B], 18[B], 19[B], 20[B], 21[B], 22[B],
- 23[B], 24[B], 25[B], 255[B], 255[B], 255[B], 255[B], 255[B],
- 255[B], 26[B], 27[B], 28[B], 29[B], 30[B], 31[B], 32[B],
- 33[B], 34[B], 35[B], 36[B], 37[B], 38[B], 39[B], 40[B],
- 41[B], 42[B], 43[B], 44[B], 45[B], 46[B], 47[B], 48[B],
- 49[B], 50[B], 51[B], 255[B], 255[B], 255[B], 255[B], 255[B],
+ :Byte
+ 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 62, 255, 255, 255, 63,
+ 52, 53, 54, 55, 56, 57, 58, 59,
+ 60, 61, 255, 255, 255, 255, 255, 255,
+ 255, 0, 1, 2, 3, 4, 5, 6,
+ 7, 8, 9, 10, 11, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, 21, 22,
+ 23, 24, 25, 255, 255, 255, 255, 255,
+ 255, 26, 27, 28, 29, 30, 31, 32,
+ 33, 34, 35, 36, 37, 38, 39, 40,
+ 41, 42, 43, 44, 45, 46, 47, 48,
+ 49, 50, 51, 255, 255, 255, 255, 255,
]
lang Base64:
@@ -31,32 +32,32 @@ lang Base64:
output := [0[B] for _ in bytes.length * 4 / 3 + 4]
src := 1[64]
dest := 1[64]
- while src + 2[64] <= bytes.length:
+ while src + 2 <= bytes.length:
chunk24 := (
- (Int32(bytes[src]) <<< 16) or (Int32(bytes[src+1[64]]) <<< 8) or Int32(bytes[src+2[64]])
+ (Int32(bytes[src]) <<< 16) or (Int32(bytes[src+1]) <<< 8) or Int32(bytes[src+2])
)
src += 3
- output[dest] = _enc[1[32] + ((chunk24 >>> 18) and 0b111111[32])]
- output[dest+1[64]] = _enc[1[32] + ((chunk24 >>> 12) and 0b111111[32])]
- output[dest+2[64]] = _enc[1[32] + ((chunk24 >>> 6) and 0b111111[32])]
- output[dest+3[64]] = _enc[1[32] + (chunk24 and 0b111111[32])]
+ output[dest] = _enc[1 + ((chunk24 >>> 18) and 0b111111)]
+ output[dest+1] = _enc[1 + ((chunk24 >>> 12) and 0b111111)]
+ output[dest+2] = _enc[1 + ((chunk24 >>> 6) and 0b111111)]
+ output[dest+3] = _enc[1 + (chunk24 and 0b111111)]
dest += 4
- if src + 1[64] == bytes.length:
+ if src + 1 == bytes.length:
chunk16 := (
- (Int32(bytes[src]) <<< 8) or Int32(bytes[src+1[64]])
+ (Int32(bytes[src]) <<< 8) or Int32(bytes[src+1])
)
- output[dest] = _enc[1[32] + ((chunk16 >>> 10) and 0b111111[32])]
- output[dest+1[64]] = _enc[1[32] + ((chunk16 >>> 4) and 0b111111[32])]
- output[dest+2[64]] = _enc[1[32] + ((chunk16 <<< 2)and 0b111111[32])]
- output[dest+3[64]] = _EQUAL_BYTE
+ output[dest] = _enc[1 + ((chunk16 >>> 10) and 0b11111)]
+ output[dest+1] = _enc[1 + ((chunk16 >>> 4) and 0b111111)]
+ output[dest+2] = _enc[1 + ((chunk16 <<< 2)and 0b111111)]
+ output[dest+3] = _EQUAL_BYTE
else if src == bytes.length:
chunk8 := Int32(bytes[src])
- output[dest] = _enc[1[32] + ((chunk8 >>> 2) and 0b111111[32])]
- output[dest+1[64]] = _enc[1[32] + ((chunk8 <<< 4) and 0b111111[32])]
- output[dest+2[64]] = _EQUAL_BYTE
- output[dest+3[64]] = _EQUAL_BYTE
+ output[dest] = _enc[1 + ((chunk8 >>> 2) and 0b111111)]
+ output[dest+1] = _enc[1 + ((chunk8 <<< 4) and 0b111111)]
+ output[dest+2] = _EQUAL_BYTE
+ output[dest+3] = _EQUAL_BYTE
return Base64.without_escaping(Text.from_bytes(output) or return !Base64)
@@ -68,21 +69,21 @@ lang Base64:
output := [0[B] for _ in bytes.length/4 * 3]
src := 1[64]
dest := 1[64]
- while src + 3[64] <= bytes.length:
+ while src + 3 <= bytes.length:
chunk24 := (
- (Int32(_dec[1[B]+bytes[src]]) <<< 18) or
- (Int32(_dec[1[B]+bytes[src+1[64]]]) <<< 12) or
- (Int32(_dec[1[B]+bytes[src+2[64]]]) <<< 6) or
- Int32(_dec[1[B]+bytes[src+3[64]]])
+ (Int32(_dec[1+bytes[src]]) <<< 18) or
+ (Int32(_dec[1+bytes[src+1]]) <<< 12) or
+ (Int32(_dec[1+bytes[src+2]]) <<< 6) or
+ Int32(_dec[1+bytes[src+3]])
)
src += 4
- output[dest] = Byte((chunk24 >>> 16) and 0xFF[32])
- output[dest+1[64]] = Byte((chunk24 >>> 8) and 0xFF[32])
- output[dest+2[64]] = Byte(chunk24 and 0xFF[32])
+ output[dest] = Byte((chunk24 >>> 16) and 0xFF)
+ output[dest+1] = Byte((chunk24 >>> 8) and 0xFF)
+ output[dest+2] = Byte(chunk24 and 0xFF)
dest += 3
- while output[-1] == 0xFF[B]:
+ while output[-1] == 0xFF:
output = output:to(-2)
return output
diff --git a/typecheck.c b/typecheck.c
index 44068f5a..38b6f7fb 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -900,6 +900,13 @@ type_t *get_type(env_t *env, ast_t *ast)
type_t *lhs_t = get_type(env, binop->lhs),
*rhs_t = get_type(env, binop->rhs);
+ if (lhs_t->tag == BigIntType && rhs_t->tag != BigIntType && is_numeric_type(rhs_t) && binop->lhs->tag == Int) {
+ lhs_t = rhs_t;
+ } else if (rhs_t->tag == BigIntType && lhs_t->tag != BigIntType && is_numeric_type(lhs_t) && binop->rhs->tag == Int) {
+
+ rhs_t = lhs_t;
+ }
+
#define binding_works(name, self, lhs_t, rhs_t, ret_t) \
({ binding_t *b = get_namespace_binding(env, self, name); \
(b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \
diff --git a/types.c b/types.c
index ddb3f076..8fcff830 100644
--- a/types.c
+++ b/types.c
@@ -394,7 +394,7 @@ PUREFUNC bool is_int_type(type_t *t)
PUREFUNC bool is_numeric_type(type_t *t)
{
- return t->tag == IntType || t->tag == BigIntType || t->tag == NumType;
+ return t->tag == IntType || t->tag == BigIntType || t->tag == NumType || t->tag == ByteType;
}
PUREFUNC size_t unpadded_struct_size(type_t *t)