diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2025-10-19 14:07:53 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2025-10-19 14:07:53 -0400 |
| commit | 1713c56dbf8ffaa01e559c7564928721e363ca39 (patch) | |
| tree | b1e1886076b3e44d95e082e2d1a8eb6db1437aec /src/stdlib/bigint.c | |
| parent | 3a2077067343a20f631ec36838e197b34ff422f4 (diff) | |
Move integer implementation details into separate header/C files, backed
by template headers that use an INT_BITS macro to redefine
implementations for different int sizes.
Diffstat (limited to 'src/stdlib/bigint.c')
| -rw-r--r-- | src/stdlib/bigint.c | 548 |
1 files changed, 548 insertions, 0 deletions
diff --git a/src/stdlib/bigint.c b/src/stdlib/bigint.c new file mode 100644 index 00000000..3ed3f3cf --- /dev/null +++ b/src/stdlib/bigint.c @@ -0,0 +1,548 @@ +// Integer type infos and methods + +#include <stdio.h> // Must be before gmp.h + +#include <ctype.h> +#include <gc.h> +#include <gmp.h> +#include <stdbool.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> + +#include "datatypes.h" +#include "integers.h" +#include "optionals.h" +#include "print.h" +#include "siphash.h" +#include "text.h" +#include "types.h" + +public +int Int$print(FILE *f, Int_t i) { + if (likely(i.small & 1L)) { + return _print_int(f, (int64_t)((i.small) >> 2L)); + } else { + return gmp_fprintf(f, "%Zd", *i.big); + } +} + +static Text_t _int64_to_text(int64_t n) { + if (n == INT64_MIN) return Text("-9223372036854775808"); + + char buf[21] = {[20] = 0}; // Big enough for INT64_MIN + '\0' + char *p = &buf[19]; + bool negative = n < 0; + if (negative) n = -n; // Safe to do because we checked for INT64_MIN earlier + + do { + *(p--) = '0' + (n % 10); + n /= 10; + } while (n > 0); + + if (negative) *(p--) = '-'; + + return Text$from_strn(p + 1, (size_t)(&buf[19] - p)); +} + +public +Text_t Int$value_as_text(Int_t i) { + if (likely(i.small & 1L)) { + return _int64_to_text(i.small >> 2L); + } else { + char *str = mpz_get_str(NULL, 10, *i.big); + return Text$from_str(str); + } +} + +public +Text_t Int$as_text(const void *i, bool colorize, const TypeInfo_t *info) { + (void)info; + if (!i) return Text("Int"); + Text_t text = Int$value_as_text(*(Int_t *)i); + if (colorize) text = Text$concat(Text("\x1b[35m"), text, Text("\x1b[m")); + return text; +} + +static bool Int$is_none(const void *i, const TypeInfo_t *info) { + (void)info; + return ((Int_t *)i)->small == 0L; +} + +public +PUREFUNC int32_t Int$compare_value(const Int_t x, const Int_t y) { + if (likely(x.small & y.small & 1L)) return (x.small > y.small) - (x.small < y.small); + else if (x.small & 1) return -mpz_cmp_si(*y.big, x.small); + else if (y.small & 1) return mpz_cmp_si(*x.big, y.small); + else return x.big == y.big ? 0 : mpz_cmp(*x.big, *y.big); +} + +public +PUREFUNC int32_t Int$compare(const void *x, const void *y, const TypeInfo_t *info) { + (void)info; + return Int$compare_value(*(Int_t *)x, *(Int_t *)y); +} + +public +PUREFUNC bool Int$equal_value(const Int_t x, const Int_t y) { + if (likely((x.small | y.small) & 1L)) return x.small == y.small; + else return x.big == y.big ? 0 : (mpz_cmp(*x.big, *y.big) == 0); +} + +public +PUREFUNC bool Int$equal(const void *x, const void *y, const TypeInfo_t *info) { + (void)info; + return Int$equal_value(*(Int_t *)x, *(Int_t *)y); +} + +public +CONSTFUNC Int_t Int$clamped(Int_t x, Int_t low, Int_t high) { + return (Int$compare(&x, &low, &Int$info) <= 0) ? low : (Int$compare(&x, &high, &Int$info) >= 0 ? high : x); +} + +public +CONSTFUNC bool Int$is_between(const Int_t x, const Int_t low, const Int_t high) { + return Int$compare_value(low, x) <= 0 && Int$compare_value(x, high) <= 0; +} + +public +PUREFUNC uint64_t Int$hash(const void *vx, const TypeInfo_t *info) { + (void)info; + Int_t *x = (Int_t *)vx; + if (likely(x->small & 1L)) { + return siphash24((void *)x, sizeof(Int_t)); + } else { + char *str = mpz_get_str(NULL, 16, *x->big); + return siphash24((void *)str, strlen(str)); + } +} + +public +Text_t Int$hex(Int_t i, Int_t digits_int, bool uppercase, bool prefix) { + if (Int$is_negative(i)) return Text$concat(Text("-"), Int$hex(Int$negative(i), digits_int, uppercase, prefix)); + + if (likely(i.small & 1L)) { + uint64_t u64 = (uint64_t)(i.small >> 2); + return Text$from_str(String( + hex(u64, .no_prefix = !prefix, .digits = Int32$from_int(digits_int, false), .uppercase = uppercase))); + } else { + char *str = mpz_get_str(NULL, 16, *i.big); + if (uppercase) { + for (char *c = str; *c; c++) + *c = (char)toupper(*c); + } + int64_t digits = Int64$from_int(digits_int, false); + int64_t needed_zeroes = digits - (int64_t)strlen(str); + if (needed_zeroes <= 0) return prefix ? Text$concat(Text("0x"), Text$from_str(str)) : Text$from_str(str); + + char *zeroes = GC_MALLOC_ATOMIC((size_t)(needed_zeroes)); + memset(zeroes, '0', (size_t)(needed_zeroes)); + if (prefix) return Text$concat(Text("0x"), Text$from_str(zeroes), Text$from_str(str)); + else return Text$concat(Text$from_str(zeroes), Text$from_str(str)); + } +} + +public +Text_t Int$octal(Int_t i, Int_t digits_int, bool prefix) { + if (Int$is_negative(i)) return Text$concat(Text("-"), Int$octal(Int$negative(i), digits_int, prefix)); + + if (likely(i.small & 1L)) { + uint64_t u64 = (uint64_t)(i.small >> 2); + return Text$from_str(String(oct(u64, .no_prefix = !prefix, .digits = Int32$from_int(digits_int, false)))); + } else { + int64_t digits = Int64$from_int(digits_int, false); + char *str = mpz_get_str(NULL, 8, *i.big); + int64_t needed_zeroes = digits - (int64_t)strlen(str); + if (needed_zeroes <= 0) return prefix ? Text$concat(Text("0o"), Text$from_str(str)) : Text$from_str(str); + + char *zeroes = GC_MALLOC_ATOMIC((size_t)(needed_zeroes)); + memset(zeroes, '0', (size_t)(needed_zeroes)); + if (prefix) return Text$concat(Text("0o"), Text$from_str(zeroes), Text$from_str(str)); + else return Text$concat(Text$from_str(zeroes), Text$from_str(str)); + } +} + +public +Int_t Int$slow_plus(Int_t x, Int_t y) { + mpz_t result; + mpz_init_set_int(result, x); + if (y.small & 1L) { + if (y.small < 0L) mpz_sub_ui(result, result, (uint64_t)(-(y.small >> 2L))); + else mpz_add_ui(result, result, (uint64_t)(y.small >> 2L)); + } else { + mpz_add(result, result, *y.big); + } + return Int$from_mpz(result); +} + +public +Int_t Int$slow_minus(Int_t x, Int_t y) { + mpz_t result; + mpz_init_set_int(result, x); + if (y.small & 1L) { + if (y.small < 0L) mpz_add_ui(result, result, (uint64_t)(-(y.small >> 2L))); + else mpz_sub_ui(result, result, (uint64_t)(y.small >> 2L)); + } else { + mpz_sub(result, result, *y.big); + } + return Int$from_mpz(result); +} + +public +Int_t Int$slow_times(Int_t x, Int_t y) { + mpz_t result; + mpz_init_set_int(result, x); + if (y.small & 1L) mpz_mul_si(result, result, y.small >> 2L); + else mpz_mul(result, result, *y.big); + return Int$from_mpz(result); +} + +public +Int_t Int$slow_divided_by(Int_t dividend, Int_t divisor) { + // Euclidean division, see: + // https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/divmodnote-letter.pdf + mpz_t quotient, remainder; + mpz_init_set_int(quotient, dividend); + mpz_init_set_int(remainder, divisor); + mpz_tdiv_qr(quotient, remainder, quotient, remainder); + if (mpz_sgn(remainder) < 0) { + bool d_positive = likely(divisor.small & 1L) ? divisor.small > 0x1L : mpz_sgn(*divisor.big) > 0; + if (d_positive) mpz_sub_ui(quotient, quotient, 1); + else mpz_add_ui(quotient, quotient, 1); + } + return Int$from_mpz(quotient); +} + +public +Int_t Int$slow_modulo(Int_t x, Int_t modulus) { + mpz_t result; + mpz_init_set_int(result, x); + mpz_t divisor; + mpz_init_set_int(divisor, modulus); + mpz_mod(result, result, divisor); + return Int$from_mpz(result); +} + +public +Int_t Int$slow_modulo1(Int_t x, Int_t modulus) { + mpz_t result; + mpz_init_set_int(result, x); + mpz_sub_ui(result, result, 1); + mpz_t divisor; + mpz_init_set_int(divisor, modulus); + mpz_mod(result, result, divisor); + mpz_add_ui(result, result, 1); + return Int$from_mpz(result); +} + +public +Int_t Int$slow_left_shifted(Int_t x, Int_t y) { + mp_bitcnt_t bits = (mp_bitcnt_t)Int64$from_int(y, false); + mpz_t result; + mpz_init_set_int(result, x); + mpz_mul_2exp(result, result, bits); + return Int$from_mpz(result); +} + +public +Int_t Int$slow_right_shifted(Int_t x, Int_t y) { + mp_bitcnt_t bits = (mp_bitcnt_t)Int64$from_int(y, false); + mpz_t result; + mpz_init_set_int(result, x); + mpz_tdiv_q_2exp(result, result, bits); + return Int$from_mpz(result); +} + +public +Int_t Int$slow_bit_and(Int_t x, Int_t y) { + mpz_t result; + mpz_init_set_int(result, x); + mpz_t y_mpz; + mpz_init_set_int(y_mpz, y); + mpz_and(result, result, y_mpz); + return Int$from_mpz(result); +} + +public +Int_t Int$slow_bit_or(Int_t x, Int_t y) { + mpz_t result; + mpz_init_set_int(result, x); + mpz_t y_mpz; + mpz_init_set_int(y_mpz, y); + mpz_ior(result, result, y_mpz); + return Int$from_mpz(result); +} + +public +Int_t Int$slow_bit_xor(Int_t x, Int_t y) { + mpz_t result; + mpz_init_set_int(result, x); + mpz_t y_mpz; + mpz_init_set_int(y_mpz, y); + mpz_xor(result, result, y_mpz); + return Int$from_mpz(result); +} + +public +Int_t Int$slow_negated(Int_t x) { + mpz_t result; + mpz_init_set_int(result, x); + mpz_neg(result, result); + mpz_sub_ui(result, result, 1); + return Int$from_mpz(result); +} + +public +Int_t Int$slow_negative(Int_t x) { + if (likely(x.small & 1L)) return (Int_t){.small = 4L * -((x.small) >> 2L) + 1L}; + + mpz_t result; + mpz_init_set_int(result, x); + mpz_neg(result, result); + return Int$from_mpz(result); +} + +public +Int_t Int$abs(Int_t x) { + if (likely(x.small & 1L)) return (Int_t){.small = 4L * labs((x.small) >> 2L) + 1L}; + + mpz_t result; + mpz_init_set_int(result, x); + mpz_abs(result, result); + return Int$from_mpz(result); +} + +public +Int_t Int$power(Int_t base, Int_t exponent) { + int64_t exp = Int64$from_int(exponent, false); + if (unlikely(exp < 0)) fail("Cannot take a negative power of an integer!"); + mpz_t result; + mpz_init_set_int(result, base); + mpz_pow_ui(result, result, (uint64_t)exp); + return Int$from_mpz(result); +} + +public +Int_t Int$gcd(Int_t x, Int_t y) { + if (likely(x.small & y.small & 0x1L)) return I_small(Int32$gcd(x.small >> 2L, y.small >> 2L)); + + mpz_t result; + mpz_init(result); + if (x.small & 0x1L) mpz_gcd_ui(result, *y.big, (uint64_t)labs(x.small >> 2L)); + else if (y.small & 0x1L) mpz_gcd_ui(result, *x.big, (uint64_t)labs(y.small >> 2L)); + else mpz_gcd(result, *x.big, *y.big); + return Int$from_mpz(result); +} + +public +OptionalInt_t Int$sqrt(Int_t i) { + if (Int$compare_value(i, I(0)) < 0) return NONE_INT; + mpz_t result; + mpz_init_set_int(result, i); + mpz_sqrt(result, result); + return Int$from_mpz(result); +} + +public +bool Int$get_bit(Int_t x, Int_t bit_index) { + mpz_t i; + mpz_init_set_int(i, x); + if (Int$compare_value(bit_index, I(1)) < 0) fail("Invalid bit index (expected 1 or higher): ", bit_index); + if (Int$compare_value(bit_index, Int$from_int64(INT64_MAX)) > 0) fail("Bit index is too large! ", bit_index); + + int is_bit_set = mpz_tstbit(i, (mp_bitcnt_t)(Int64$from_int(bit_index, true) - 1)); + return (bool)is_bit_set; +} + +typedef struct { + OptionalInt_t current, last; + Int_t step; +} IntRange_t; + +static OptionalInt_t _next_int(IntRange_t *info) { + OptionalInt_t i = info->current; + if (!Int$is_none(&i, &Int$info)) { + Int_t next = Int$plus(i, info->step); + if (!Int$is_none(&info->last, &Int$info) + && Int$compare_value(next, info->last) == Int$compare_value(info->step, I(0))) + next = NONE_INT; + info->current = next; + } + return i; +} + +public +PUREFUNC Closure_t Int$to(Int_t first, Int_t last, OptionalInt_t step) { + IntRange_t *range = GC_MALLOC(sizeof(IntRange_t)); + range->current = first; + range->last = last; + range->step = Int$is_none(&step, &Int$info) ? Int$compare_value(last, first) >= 0 + ? (Int_t){.small = (1L << 2L) | 1L} + : (Int_t){.small = (-1L >> 2L) | 1L} + : step; + return (Closure_t){.fn = _next_int, .userdata = range}; +} + +public +PUREFUNC Closure_t Int$onward(Int_t first, Int_t step) { + IntRange_t *range = GC_MALLOC(sizeof(IntRange_t)); + range->current = first; + range->last = NONE_INT; + range->step = step; + return (Closure_t){.fn = _next_int, .userdata = range}; +} + +public +Int_t Int$from_str(const char *str) { + mpz_t i; + int result; + if (strncmp(str, "0x", 2) == 0) { + result = mpz_init_set_str(i, str + 2, 16); + } else if (strncmp(str, "0o", 2) == 0) { + result = mpz_init_set_str(i, str + 2, 8); + } else if (strncmp(str, "0b", 2) == 0) { + result = mpz_init_set_str(i, str + 2, 2); + } else { + result = mpz_init_set_str(i, str, 10); + } + if (result != 0) return NONE_INT; + return Int$from_mpz(i); +} + +public +OptionalInt_t Int$parse(Text_t text, Text_t *remainder) { + const char *str = Text$as_c_string(text); + mpz_t i; + int result; + if (strncmp(str, "0x", 2) == 0) { + const char *end = str + 2 + strspn(str + 2, "0123456789abcdefABCDEF"); + if (remainder) *remainder = Text$from_str(end); + else if (*end != '\0') return NONE_INT; + result = mpz_init_set_str(i, str + 2, 16); + } else if (strncmp(str, "0o", 2) == 0) { + const char *end = str + 2 + strspn(str + 2, "01234567"); + if (remainder) *remainder = Text$from_str(end); + else if (*end != '\0') return NONE_INT; + result = mpz_init_set_str(i, str + 2, 8); + } else if (strncmp(str, "0b", 2) == 0) { + const char *end = str + 2 + strspn(str + 2, "01"); + if (remainder) *remainder = Text$from_str(end); + else if (*end != '\0') return NONE_INT; + result = mpz_init_set_str(i, str + 2, 2); + } else { + const char *end = str + strspn(str, "0123456789"); + if (remainder) *remainder = Text$from_str(end); + else if (*end != '\0') return NONE_INT; + result = mpz_init_set_str(i, str, 10); + } + if (result != 0) { + if (remainder) *remainder = text; + return NONE_INT; + } + return Int$from_mpz(i); +} + +public +bool Int$is_prime(Int_t x, Int_t reps) { + mpz_t p; + mpz_init_set_int(p, x); + if (unlikely(Int$compare_value(reps, I(9999)) > 0)) + fail("Number of prime-test repetitions should not be above 9999"); + int reps_int = Int32$from_int(reps, false); + return (mpz_probab_prime_p(p, reps_int) != 0); +} + +public +Int_t Int$next_prime(Int_t x) { + mpz_t p; + mpz_init_set_int(p, x); + mpz_nextprime(p, p); + return Int$from_mpz(p); +} + +#if __GNU_MP_VERSION >= 6 +#if __GNU_MP_VERSION_MINOR >= 3 +public +OptionalInt_t Int$prev_prime(Int_t x) { + mpz_t p; + mpz_init_set_int(p, x); + if (unlikely(mpz_prevprime(p, p) == 0)) return NONE_INT; + return Int$from_mpz(p); +} +#endif +#endif + +public +Int_t Int$choose(Int_t n, Int_t k) { + if unlikely (Int$compare_value(n, I_small(0)) < 0) fail("Negative inputs are not supported for choose()"); + + mpz_t ret; + mpz_init(ret); + + int64_t k_i64 = Int64$from_int(k, false); + if unlikely (k_i64 < 0) fail("Negative inputs are not supported for choose()"); + + if likely (n.small & 1L) { + mpz_bin_uiui(ret, (unsigned long)(n.small >> 2L), (unsigned long)k_i64); + } else { + mpz_t n_mpz; + mpz_init_set_int(n_mpz, n); + mpz_bin_ui(ret, n_mpz, (unsigned long)k_i64); + } + return Int$from_mpz(ret); +} + +public +Int_t Int$factorial(Int_t n) { + mpz_t ret; + mpz_init(ret); + int64_t n_i64 = Int64$from_int(n, false); + if unlikely (n_i64 < 0) fail("Factorials are not defined for negative numbers"); + mpz_fac_ui(ret, (unsigned long)n_i64); + return Int$from_mpz(ret); +} + +static void Int$serialize(const void *obj, FILE *out, Table_t *pointers, const TypeInfo_t *info) { + (void)info; + Int_t i = *(Int_t *)obj; + if (likely(i.small & 1L)) { + fputc(0, out); + int64_t i64 = i.small >> 2L; + Int64$serialize(&i64, out, pointers, &Int64$info); + } else { + fputc(1, out); + mpz_t n; + mpz_init_set_int(n, *(Int_t *)obj); + mpz_out_raw(out, n); + } +} + +static void Int$deserialize(FILE *in, void *obj, List_t *pointers, const TypeInfo_t *info) { + (void)info; + if (fgetc(in) == 0) { + int64_t i = 0; + Int64$deserialize(in, &i, pointers, &Int64$info); + *(Int_t *)obj = (Int_t){.small = (i << 2L) | 1L}; + } else { + mpz_t n; + mpz_init(n); + mpz_inp_raw(n, in); + *(Int_t *)obj = Int$from_mpz(n); + } +} + +public +const TypeInfo_t Int$info = { + .size = sizeof(Int_t), + .align = __alignof__(Int_t), + .metamethods = + { + .compare = Int$compare, + .equal = Int$equal, + .hash = Int$hash, + .as_text = Int$as_text, + .is_none = Int$is_none, + .serialize = Int$serialize, + .deserialize = Int$deserialize, + }, +}; |
