diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2025-04-01 16:55:24 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2025-04-01 16:55:24 -0400 |
| commit | 6de2d68a700a137bbe55668e036c62280ece8bb5 (patch) | |
| tree | eb1e3cee37cd9b2f1458b9ceff0141bfbd7a91a9 /examples/random/random.tm | |
| parent | a32c3747568562251d6c390faf325bf3ed3946e6 (diff) | |
Moved RNG out of the compiler and into a standalone library
Diffstat (limited to 'examples/random/random.tm')
| -rw-r--r-- | examples/random/random.tm | 240 |
1 files changed, 240 insertions, 0 deletions
diff --git a/examples/random/random.tm b/examples/random/random.tm new file mode 100644 index 00000000..7db02016 --- /dev/null +++ b/examples/random/random.tm @@ -0,0 +1,240 @@ +# Random Number Generator (RNG) implementation based on ChaCha + +use ./sysrandom.h +use ./chacha.h + +struct chacha_ctx(j0,j1,j2,j3,j4,j5,j6,j7,j8,j9,j10,j11,j12,j13,j14,j15:Int32; extern, secret): + func from_seed(seed=[:Byte] -> chacha_ctx): + return inline C : chacha_ctx { + chacha_ctx ctx; + uint8_t seed_bytes[KEYSZ + IVSZ] = {}; + for (int64_t i = 0; i < (int64_t)sizeof(seed_bytes); i++) + seed_bytes[i] = i < _$seed.length ? *(uint8_t*)(_$seed.data + i*_$seed.stride) : 0; + chacha_keysetup(&ctx, seed_bytes); + chacha_ivsetup(&ctx, seed_bytes + KEYSZ); + ctx; + } + +default_random := RandomNumberGenerator.new() + +func _os_random_bytes(count:Int64 -> [Byte]): + return inline C : [Byte] { + uint8_t *random_bytes = GC_MALLOC_ATOMIC(_$count); + getrandom(random_bytes, _$count, 0); + (Array_t){.length=_$count, .data=random_bytes, .stride=1, .atomic=1}; + } + +struct RandomNumberGenerator(_chacha:chacha_ctx, _random_bytes=[:Byte]; secret): + func new(seed=none:[Byte], -> @RandomNumberGenerator): + ctx := chacha_ctx.from_seed(seed or _os_random_bytes(40)) + return @RandomNumberGenerator(ctx, [:Byte]) + + func _rekey(rng:&RandomNumberGenerator): + rng._random_bytes = inline C : [Byte] { + Byte_t new_keystream[KEYSZ + IVSZ] = {}; + // Fill the buffer with the keystream + chacha_encrypt_bytes(&_$rng->_chacha, new_keystream, new_keystream, sizeof(new_keystream)); + // Immediately reinitialize for backtracking resistance + chacha_keysetup(&_$rng->_chacha, new_keystream); + chacha_ivsetup(&_$rng->_chacha, new_keystream + KEYSZ); + Array_t new_bytes = (Array_t){.data=GC_MALLOC_ATOMIC(1024), .length=1024, .stride=1, .atomic=1}; + memset(new_bytes.data, 0, new_bytes.length); + chacha_encrypt_bytes(&_$rng->_chacha, new_bytes.data, new_bytes.data, new_bytes.length); + new_bytes; + } + + func _fill_bytes(rng:&RandomNumberGenerator, dest:&Memory, needed:Int64): + inline C { + while (_$needed > 0) { + if (_$rng->_random_bytes.length == 0) + _$random$RandomNumberGenerator$_rekey(_$rng); + + assert(_$rng->_random_bytes.stride == 1); + + int64_t batch_size = MIN(_$needed, _$rng->_random_bytes.length); + uint8_t *batch_src = _$rng->_random_bytes.data; + memcpy(_$dest, batch_src, batch_size); + memset(batch_src, 0, batch_size); + _$rng->_random_bytes.data += batch_size; + _$rng->_random_bytes.length -= batch_size; + _$dest += batch_size; + _$needed -= batch_size; + } + } + + func bytes(rng:&RandomNumberGenerator, count:Int -> [Byte]): + return inline C : [Byte] { + int64_t count64 = Int64$from_int(_$count, false); + Array_t ret = {.data=GC_MALLOC_ATOMIC(count64), .stride=1, .atomic=1, .length=count64}; + _$random$RandomNumberGenerator$_fill_bytes(_$rng, ret.data, count64); + ret; + } + + func byte(rng:&RandomNumberGenerator -> Byte): + return inline C : Byte { + Byte_t b; + _$random$RandomNumberGenerator$_fill_bytes(_$rng, &b, sizeof(b)); + b; + } + + func bool(rng:&RandomNumberGenerator, probability=0.5 -> Bool): + if probability == 0.5: + return rng:byte() < 0x80 + else: + return rng:num(0., 1.) < 0.5 + + func int64(rng:&RandomNumberGenerator, min=Int64.min, max=Int64.max -> Int64): + fail("Random minimum value $min is larger than the maximum value $max") if min > max + return min if min == max + if min == Int64.min and max == Int64.max: + return inline C : Int64 { + int64_t i; + _$random$RandomNumberGenerator$_fill_bytes(_$rng, &i, sizeof(i)); + i; + } + + return inline C : Int64 { + uint64_t range = (uint64_t)_$max - (uint64_t)_$min + 1; + uint64_t min_r = -range % range; + uint64_t r; + for (;;) { + _$random$RandomNumberGenerator$_fill_bytes(_$rng, (uint8_t*)&r, sizeof(r)); + if (r >= min_r) break; + } + (int64_t)((uint64_t)_$min + (r % range)); + } + + func int32(rng:&RandomNumberGenerator, min=Int32.min, max=Int32.max -> Int32): + fail("Random minimum value $min is larger than the maximum value $max") if min > max + return min if min == max + if min == Int32.min and max == Int32.max: + return inline C : Int32 { + int32_t i; + _$random$RandomNumberGenerator$_fill_bytes(_$rng, &i, sizeof(i)); + i; + } + + return inline C : Int32 { + uint32_t range = (uint32_t)_$max - (uint32_t)_$min + 1; + uint32_t min_r = -range % range; + uint32_t r; + for (;;) { + _$random$RandomNumberGenerator$_fill_bytes(_$rng, (uint8_t*)&r, sizeof(r)); + if (r >= min_r) break; + } + (int32_t)((uint32_t)_$min + (r % range)); + } + + func int16(rng:&RandomNumberGenerator, min=Int16.min, max=Int16.max -> Int16): + fail("Random minimum value $min is larger than the maximum value $max") if min > max + return min if min == max + if min == Int16.min and max == Int16.max: + return inline C : Int16 { + int16_t i; + _$random$RandomNumberGenerator$_fill_bytes(_$rng, &i, sizeof(i)); + i; + } + + return inline C : Int16 { + uint16_t range = (uint16_t)_$max - (uint16_t)_$min + 1; + uint16_t min_r = -range % range; + uint16_t r; + for (;;) { + _$random$RandomNumberGenerator$_fill_bytes(_$rng, (uint8_t*)&r, sizeof(r)); + if (r >= min_r) break; + } + (int16_t)((uint16_t)_$min + (r % range)); + } + + func int8(rng:&RandomNumberGenerator, min=Int8.min, max=Int8.max -> Int8): + fail("Random minimum value $min is larger than the maximum value $max") if min > max + return min if min == max + if min == Int8.min and max == Int8.max: + return inline C : Int8 { + int8_t i; + _$random$RandomNumberGenerator$_fill_bytes(_$rng, &i, sizeof(i)); + i; + } + + return inline C : Int8 { + uint8_t range = (uint8_t)_$max - (uint8_t)_$min + 1; + uint8_t min_r = -range % range; + uint8_t r; + for (;;) { + _$random$RandomNumberGenerator$_fill_bytes(_$rng, (uint8_t*)&r, sizeof(r)); + if (r >= min_r) break; + } + (int8_t)((uint8_t)_$min + (r % range)); + } + + func num(rng:&RandomNumberGenerator, min=0., max=1. -> Num): + return inline C : Num { + if (_$min > _$max) fail("Random minimum value (", _$min, ") is larger than the maximum value (", _$max, ")"); + if (_$min == _$max) return _$min; + + union { + Num_t num; + uint64_t bits; + } r = {.bits=0}, one = {.num=1.0}; + _$random$RandomNumberGenerator$_fill_bytes(_$rng, (uint8_t*)&r, sizeof(r)); + + // Set r.num to 1.<random-bits> + r.bits &= ~(0xFFFULL << 52); + r.bits |= (one.bits & (0xFFFULL << 52)); + + r.num -= 1.0; + + (_$min == 0.0 && _$max == 1.0) ? r.num : ((1.0-r.num)*_$min + r.num*_$max); + } + + func num32(rng:&RandomNumberGenerator, min=Num32(0.), max=Num32(1.) -> Num32): + return Num32(rng:num(Num(min), Num(max))) + + func int(rng:&RandomNumberGenerator, min:Int, max:Int -> Int): + return inline C : Int { + if (likely(((_$min.small & _$max.small) & 1) != 0)) { + int32_t r = _$random$RandomNumberGenerator$int32(_$rng, (int32_t)(_$min.small >> 2), (int32_t)(_$max.small >> 2)); + return I_small(r); + } + + int32_t cmp = Int$compare_value(_$min, _$max); + if (cmp > 0) + fail("Random minimum value (", _$min, ") is larger than the maximum value (", _$max, ")"); + if (cmp == 0) return _$min; + + mpz_t range_size; + mpz_init_set_int(range_size, _$max); + if (_$min.small & 1) { + mpz_t min_mpz; + mpz_init_set_si(min_mpz, _$min.small >> 2); + mpz_sub(range_size, range_size, min_mpz); + } else { + mpz_sub(range_size, range_size, *_$min.big); + } + + gmp_randstate_t gmp_rng; + gmp_randinit_default(gmp_rng); + int64_t seed = _$random$RandomNumberGenerator$int64(_$rng, INT64_MIN, INT64_MAX); + gmp_randseed_ui(gmp_rng, (unsigned long)seed); + + mpz_t r; + mpz_init(r); + mpz_urandomm(r, gmp_rng, range_size); + + gmp_randclear(gmp_rng); + Int$plus(_$min, Int$from_mpz(r)); + } + + +func main(): + >> rng := RandomNumberGenerator.new() + >> rng:num() + >> rng:num() + >> rng:num() + >> rng:num(0, 100) + >> rng:byte() + >> rng:bytes(20) + # >> rng:int(1, 100) + # >> rng:int(1, 100) + # >> rng:int(1, 100) + # >> rng:int(1, 100) |
