diff options
| -rw-r--r-- | compile.c | 18 | ||||
| -rw-r--r-- | stdlib/arrays.c | 35 | ||||
| -rw-r--r-- | stdlib/arrays.h | 8 | ||||
| -rw-r--r-- | stdlib/integers.c | 12 | ||||
| -rw-r--r-- | stdlib/integers.h | 1 | ||||
| -rw-r--r-- | test/arrays.tm | 25 |
6 files changed, 74 insertions, 25 deletions
@@ -2688,6 +2688,12 @@ CORD compile(env_t *env, ast_t *ast) case ArrayType: { type_t *item_t = Match(self_value_t, ArrayType)->item_type; CORD padded_item_size = CORD_asprintf("%ld", type_size(item_t)); + + type_t *rng_fn = Type(ClosureType, .fn=Type(FunctionType, .args=NULL, .ret=Type(IntType, .bits=TYPE_IBITS64))); + ast_t *default_rng = FakeAST(InlineCCode, + .code=CORD_all("((Closure_t){.fn=Int64$full_random})"), + .type=rng_fn); + if (streq(call->name, "insert")) { CORD self = compile_to_pointer_depth(env, call->self, 1, false); arg_t *arg_spec = new(arg_t, .name="item", .type=item_t, @@ -2714,8 +2720,8 @@ CORD compile(env_t *env, ast_t *ast) compile_type_info(env, self_value_t), ")"); } else if (streq(call->name, "random")) { CORD self = compile_to_pointer_depth(env, call->self, 0, false); - (void)compile_arguments(env, ast, NULL, call->args); - return CORD_all("Array$random_value(", self, ", ", compile_type(item_t), ")"); + arg_t *arg_spec = new(arg_t, .name="rng", .type=rng_fn, .default_val=default_rng); + return CORD_all("Array$random_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type(item_t), ")"); } else if (streq(call->name, "has")) { CORD self = compile_to_pointer_depth(env, call->self, 0, false); arg_t *arg_spec = new(arg_t, .name="item", .type=item_t); @@ -2731,12 +2737,12 @@ CORD compile(env_t *env, ast_t *ast) padded_item_size, ")"); } else if (streq(call->name, "shuffle")) { CORD self = compile_to_pointer_depth(env, call->self, 1, false); - (void)compile_arguments(env, ast, NULL, call->args); - return CORD_all("Array$shuffle(", self, ", ", padded_item_size, ")"); + arg_t *arg_spec = new(arg_t, .name="rng", .type=rng_fn, .default_val=default_rng); + return CORD_all("Array$shuffle(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", padded_item_size, ")"); } else if (streq(call->name, "shuffled")) { CORD self = compile_to_pointer_depth(env, call->self, 0, false); - (void)compile_arguments(env, ast, NULL, call->args); - return CORD_all("Array$shuffled(", self, ", ", padded_item_size, ")"); + arg_t *arg_spec = new(arg_t, .name="rng", .type=rng_fn, .default_val=default_rng); + return CORD_all("Array$shuffled(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", padded_item_size, ")"); } else if (streq(call->name, "sort") || streq(call->name, "sorted")) { CORD self = streq(call->name, "sort") ? compile_to_pointer_depth(env, call->self, 1, false) diff --git a/stdlib/arrays.c b/stdlib/arrays.c index 6b61f5b5..552fb4cb 100644 --- a/stdlib/arrays.c +++ b/stdlib/arrays.c @@ -249,36 +249,51 @@ public Array_t Array$sorted(Array_t arr, Closure_t comparison, int64_t padded_it return arr; } -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstack-protector" -public void Array$shuffle(Array_t *arr, int64_t padded_item_size) +static uint64_t random_range(Closure_t rng, uint64_t upper_bound) +{ + if (upper_bound < 2) + return 0; + + // This approach is taken from arc4random_uniform() + uint64_t min = -upper_bound % upper_bound; + uint64_t r; + for (;;) { + r = ((uint64_t(*)(void*))rng.fn)(rng.userdata); + if (r >= min) + break; + } + + return r % upper_bound; +} + +public void Array$shuffle(Array_t *arr, Closure_t rng, int64_t padded_item_size) { if (arr->data_refcount != 0 || (int64_t)arr->stride != padded_item_size) Array$compact(arr, padded_item_size); char tmp[padded_item_size]; for (int64_t i = arr->length-1; i > 1; i--) { - int64_t j = arc4random_uniform(i+1); + int64_t j = (int64_t)random_range(rng, (uint64_t)(i+1)); memcpy(tmp, arr->data + i*padded_item_size, (size_t)padded_item_size); memcpy((void*)arr->data + i*padded_item_size, arr->data + j*padded_item_size, (size_t)padded_item_size); memcpy((void*)arr->data + j*padded_item_size, tmp, (size_t)padded_item_size); } } -#pragma GCC diagnostic pop -public Array_t Array$shuffled(Array_t arr, int64_t padded_item_size) +public Array_t Array$shuffled(Array_t arr, Closure_t rng, int64_t padded_item_size) { Array$compact(&arr, padded_item_size); - Array$shuffle(&arr, padded_item_size); + Array$shuffle(&arr, rng, padded_item_size); return arr; } -public void *Array$random(Array_t arr) +public void *Array$random(Array_t arr, Closure_t rng) { if (arr.length == 0) return NULL; // fail("Cannot get a random item from an empty array!"); - int64_t index = arc4random_uniform(arr.length); - return arr.data + arr.stride*index; + + uint64_t index = random_range(rng, (uint64_t)arr.length); + return arr.data + arr.stride*(int64_t)index; } public Table_t Array$counts(Array_t arr, const TypeInfo_t *type) diff --git a/stdlib/arrays.h b/stdlib/arrays.h index 03f00d49..5d452d30 100644 --- a/stdlib/arrays.h +++ b/stdlib/arrays.h @@ -70,10 +70,10 @@ Int_t Array$find(Array_t arr, void *item, const TypeInfo_t *type); Int_t Array$first(Array_t arr, Closure_t predicate); void Array$sort(Array_t *arr, Closure_t comparison, int64_t padded_item_size); Array_t Array$sorted(Array_t arr, Closure_t comparison, int64_t padded_item_size); -void Array$shuffle(Array_t *arr, int64_t padded_item_size); -Array_t Array$shuffled(Array_t arr, int64_t padded_item_size); -void *Array$random(Array_t arr); -#define Array$random_value(arr, t) ({ Array_t _arr = arr; if (_arr.length == 0) fail("Cannot get a random value from an empty array!"); *(t*)Array$random(_arr); }) +void Array$shuffle(Array_t *arr, Closure_t rng, int64_t padded_item_size); +Array_t Array$shuffled(Array_t arr, Closure_t rng, int64_t padded_item_size); +void *Array$random(Array_t arr, Closure_t rng); +#define Array$random_value(arr, rng, t) ({ Array_t _arr = arr; if (_arr.length == 0) fail("Cannot get a random value from an empty array!"); *(t*)Array$random(_arr, rng); }) Array_t Array$sample(Array_t arr, Int_t n, Array_t weights, int64_t padded_item_size); Table_t Array$counts(Array_t arr, const TypeInfo_t *type); void Array$clear(Array_t *array); diff --git a/stdlib/integers.c b/stdlib/integers.c index f604aa53..8d305daf 100644 --- a/stdlib/integers.c +++ b/stdlib/integers.c @@ -440,14 +440,16 @@ public const TypeInfo_t Int$info = { } \ return bit_array; \ } \ + public c_type KindOfInt ## $full_random(void) { \ + c_type r; \ + arc4random_buf(&r, sizeof(r)); \ + return r; \ + } \ public c_type KindOfInt ## $random(c_type min, c_type max) { \ if (min > max) fail("Random minimum value (%ld) is larger than the maximum value (%ld)", min, max); \ if (min == max) return min; \ - if (min == min_val && max == max_val) { \ - c_type r; \ - arc4random_buf(&r, sizeof(r)); \ - return r; \ - } \ + if (min == min_val && max == max_val) \ + return KindOfInt ## $full_random(); \ uint64_t range = (uint64_t)max - (uint64_t)min + 1; \ uint64_t min_r = -range % range; \ uint64_t r; \ diff --git a/stdlib/integers.h b/stdlib/integers.h index e7b5b0e1..04699162 100644 --- a/stdlib/integers.h +++ b/stdlib/integers.h @@ -35,6 +35,7 @@ Text_t type_name ## $octal(c_type i, Int_t digits, bool prefix); \ Array_t type_name ## $bits(c_type x); \ c_type type_name ## $random(c_type min, c_type max); \ + c_type type_name ## $full_random(void); \ to_attr Range_t type_name ## $to(c_type from, c_type to); \ PUREFUNC Optional ## type_name ## _t type_name ## $from_text(Text_t text); \ MACROLIKE PUREFUNC c_type type_name ## $clamped(c_type x, c_type min, c_type max) { \ diff --git a/test/arrays.tm b/test/arrays.tm index e1c31cf6..25d45582 100644 --- a/test/arrays.tm +++ b/test/arrays.tm @@ -173,3 +173,28 @@ func main(): >> [4, 5, 6]:first(func(i:&Int): i:is_prime()) = 2? + test_seeded_rng() + +# Inspired by: https://nullprogram.com/blog/2017/09/21/ +struct RNGxoroshiro128(s0=Int64.random(),s1=Int64.random()): + func int_fn(rng:@RNGxoroshiro128 -> func(->Int64)): + return func(-> Int64): + s0 := rng.s0 + s1 := rng.s1 + result := s0:wrapping_plus(s1) + s1 xor= s0 + rng.s0 = (s0 <<< 55) or (s0 >>> 9) xor s1 xor (s1 <<< 14) + rng.s1 = (s1 <<< 36) or (s1 >>> 28) + return result + +func test_seeded_rng(): + !! Seeded RNG: + rng_state := RNGxoroshiro128(Int64.random(), Int64.random()) + rng1 := @rng_state:int_fn() + rng2 := @rng_state:int_fn() + + nums := [i*10 for i in 20] + >> nums:random(rng1) == nums:random(rng2) + = yes + >> nums:shuffled(rng1) == nums:shuffled(rng2) + = yes |
