aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compile.c18
-rw-r--r--stdlib/arrays.c35
-rw-r--r--stdlib/arrays.h8
-rw-r--r--stdlib/integers.c12
-rw-r--r--stdlib/integers.h1
-rw-r--r--test/arrays.tm25
6 files changed, 74 insertions, 25 deletions
diff --git a/compile.c b/compile.c
index 1f67156a..4c7c1bd5 100644
--- a/compile.c
+++ b/compile.c
@@ -2688,6 +2688,12 @@ CORD compile(env_t *env, ast_t *ast)
case ArrayType: {
type_t *item_t = Match(self_value_t, ArrayType)->item_type;
CORD padded_item_size = CORD_asprintf("%ld", type_size(item_t));
+
+ type_t *rng_fn = Type(ClosureType, .fn=Type(FunctionType, .args=NULL, .ret=Type(IntType, .bits=TYPE_IBITS64)));
+ ast_t *default_rng = FakeAST(InlineCCode,
+ .code=CORD_all("((Closure_t){.fn=Int64$full_random})"),
+ .type=rng_fn);
+
if (streq(call->name, "insert")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
arg_t *arg_spec = new(arg_t, .name="item", .type=item_t,
@@ -2714,8 +2720,8 @@ CORD compile(env_t *env, ast_t *ast)
compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "random")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
- (void)compile_arguments(env, ast, NULL, call->args);
- return CORD_all("Array$random_value(", self, ", ", compile_type(item_t), ")");
+ arg_t *arg_spec = new(arg_t, .name="rng", .type=rng_fn, .default_val=default_rng);
+ return CORD_all("Array$random_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type(item_t), ")");
} else if (streq(call->name, "has")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
arg_t *arg_spec = new(arg_t, .name="item", .type=item_t);
@@ -2731,12 +2737,12 @@ CORD compile(env_t *env, ast_t *ast)
padded_item_size, ")");
} else if (streq(call->name, "shuffle")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
- (void)compile_arguments(env, ast, NULL, call->args);
- return CORD_all("Array$shuffle(", self, ", ", padded_item_size, ")");
+ arg_t *arg_spec = new(arg_t, .name="rng", .type=rng_fn, .default_val=default_rng);
+ return CORD_all("Array$shuffle(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", padded_item_size, ")");
} else if (streq(call->name, "shuffled")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
- (void)compile_arguments(env, ast, NULL, call->args);
- return CORD_all("Array$shuffled(", self, ", ", padded_item_size, ")");
+ arg_t *arg_spec = new(arg_t, .name="rng", .type=rng_fn, .default_val=default_rng);
+ return CORD_all("Array$shuffled(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", padded_item_size, ")");
} else if (streq(call->name, "sort") || streq(call->name, "sorted")) {
CORD self = streq(call->name, "sort")
? compile_to_pointer_depth(env, call->self, 1, false)
diff --git a/stdlib/arrays.c b/stdlib/arrays.c
index 6b61f5b5..552fb4cb 100644
--- a/stdlib/arrays.c
+++ b/stdlib/arrays.c
@@ -249,36 +249,51 @@ public Array_t Array$sorted(Array_t arr, Closure_t comparison, int64_t padded_it
return arr;
}
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wstack-protector"
-public void Array$shuffle(Array_t *arr, int64_t padded_item_size)
+static uint64_t random_range(Closure_t rng, uint64_t upper_bound)
+{
+ if (upper_bound < 2)
+ return 0;
+
+ // This approach is taken from arc4random_uniform()
+ uint64_t min = -upper_bound % upper_bound;
+ uint64_t r;
+ for (;;) {
+ r = ((uint64_t(*)(void*))rng.fn)(rng.userdata);
+ if (r >= min)
+ break;
+ }
+
+ return r % upper_bound;
+}
+
+public void Array$shuffle(Array_t *arr, Closure_t rng, int64_t padded_item_size)
{
if (arr->data_refcount != 0 || (int64_t)arr->stride != padded_item_size)
Array$compact(arr, padded_item_size);
char tmp[padded_item_size];
for (int64_t i = arr->length-1; i > 1; i--) {
- int64_t j = arc4random_uniform(i+1);
+ int64_t j = (int64_t)random_range(rng, (uint64_t)(i+1));
memcpy(tmp, arr->data + i*padded_item_size, (size_t)padded_item_size);
memcpy((void*)arr->data + i*padded_item_size, arr->data + j*padded_item_size, (size_t)padded_item_size);
memcpy((void*)arr->data + j*padded_item_size, tmp, (size_t)padded_item_size);
}
}
-#pragma GCC diagnostic pop
-public Array_t Array$shuffled(Array_t arr, int64_t padded_item_size)
+public Array_t Array$shuffled(Array_t arr, Closure_t rng, int64_t padded_item_size)
{
Array$compact(&arr, padded_item_size);
- Array$shuffle(&arr, padded_item_size);
+ Array$shuffle(&arr, rng, padded_item_size);
return arr;
}
-public void *Array$random(Array_t arr)
+public void *Array$random(Array_t arr, Closure_t rng)
{
if (arr.length == 0)
return NULL; // fail("Cannot get a random item from an empty array!");
- int64_t index = arc4random_uniform(arr.length);
- return arr.data + arr.stride*index;
+
+ uint64_t index = random_range(rng, (uint64_t)arr.length);
+ return arr.data + arr.stride*(int64_t)index;
}
public Table_t Array$counts(Array_t arr, const TypeInfo_t *type)
diff --git a/stdlib/arrays.h b/stdlib/arrays.h
index 03f00d49..5d452d30 100644
--- a/stdlib/arrays.h
+++ b/stdlib/arrays.h
@@ -70,10 +70,10 @@ Int_t Array$find(Array_t arr, void *item, const TypeInfo_t *type);
Int_t Array$first(Array_t arr, Closure_t predicate);
void Array$sort(Array_t *arr, Closure_t comparison, int64_t padded_item_size);
Array_t Array$sorted(Array_t arr, Closure_t comparison, int64_t padded_item_size);
-void Array$shuffle(Array_t *arr, int64_t padded_item_size);
-Array_t Array$shuffled(Array_t arr, int64_t padded_item_size);
-void *Array$random(Array_t arr);
-#define Array$random_value(arr, t) ({ Array_t _arr = arr; if (_arr.length == 0) fail("Cannot get a random value from an empty array!"); *(t*)Array$random(_arr); })
+void Array$shuffle(Array_t *arr, Closure_t rng, int64_t padded_item_size);
+Array_t Array$shuffled(Array_t arr, Closure_t rng, int64_t padded_item_size);
+void *Array$random(Array_t arr, Closure_t rng);
+#define Array$random_value(arr, rng, t) ({ Array_t _arr = arr; if (_arr.length == 0) fail("Cannot get a random value from an empty array!"); *(t*)Array$random(_arr, rng); })
Array_t Array$sample(Array_t arr, Int_t n, Array_t weights, int64_t padded_item_size);
Table_t Array$counts(Array_t arr, const TypeInfo_t *type);
void Array$clear(Array_t *array);
diff --git a/stdlib/integers.c b/stdlib/integers.c
index f604aa53..8d305daf 100644
--- a/stdlib/integers.c
+++ b/stdlib/integers.c
@@ -440,14 +440,16 @@ public const TypeInfo_t Int$info = {
} \
return bit_array; \
} \
+ public c_type KindOfInt ## $full_random(void) { \
+ c_type r; \
+ arc4random_buf(&r, sizeof(r)); \
+ return r; \
+ } \
public c_type KindOfInt ## $random(c_type min, c_type max) { \
if (min > max) fail("Random minimum value (%ld) is larger than the maximum value (%ld)", min, max); \
if (min == max) return min; \
- if (min == min_val && max == max_val) { \
- c_type r; \
- arc4random_buf(&r, sizeof(r)); \
- return r; \
- } \
+ if (min == min_val && max == max_val) \
+ return KindOfInt ## $full_random(); \
uint64_t range = (uint64_t)max - (uint64_t)min + 1; \
uint64_t min_r = -range % range; \
uint64_t r; \
diff --git a/stdlib/integers.h b/stdlib/integers.h
index e7b5b0e1..04699162 100644
--- a/stdlib/integers.h
+++ b/stdlib/integers.h
@@ -35,6 +35,7 @@
Text_t type_name ## $octal(c_type i, Int_t digits, bool prefix); \
Array_t type_name ## $bits(c_type x); \
c_type type_name ## $random(c_type min, c_type max); \
+ c_type type_name ## $full_random(void); \
to_attr Range_t type_name ## $to(c_type from, c_type to); \
PUREFUNC Optional ## type_name ## _t type_name ## $from_text(Text_t text); \
MACROLIKE PUREFUNC c_type type_name ## $clamped(c_type x, c_type min, c_type max) { \
diff --git a/test/arrays.tm b/test/arrays.tm
index e1c31cf6..25d45582 100644
--- a/test/arrays.tm
+++ b/test/arrays.tm
@@ -173,3 +173,28 @@ func main():
>> [4, 5, 6]:first(func(i:&Int): i:is_prime())
= 2?
+ test_seeded_rng()
+
+# Inspired by: https://nullprogram.com/blog/2017/09/21/
+struct RNGxoroshiro128(s0=Int64.random(),s1=Int64.random()):
+ func int_fn(rng:@RNGxoroshiro128 -> func(->Int64)):
+ return func(-> Int64):
+ s0 := rng.s0
+ s1 := rng.s1
+ result := s0:wrapping_plus(s1)
+ s1 xor= s0
+ rng.s0 = (s0 <<< 55) or (s0 >>> 9) xor s1 xor (s1 <<< 14)
+ rng.s1 = (s1 <<< 36) or (s1 >>> 28)
+ return result
+
+func test_seeded_rng():
+ !! Seeded RNG:
+ rng_state := RNGxoroshiro128(Int64.random(), Int64.random())
+ rng1 := @rng_state:int_fn()
+ rng2 := @rng_state:int_fn()
+
+ nums := [i*10 for i in 20]
+ >> nums:random(rng1) == nums:random(rng2)
+ = yes
+ >> nums:shuffled(rng1) == nums:shuffled(rng2)
+ = yes