aboutsummaryrefslogtreecommitdiff
path: root/src/stdlib/arrays.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/stdlib/arrays.c')
-rw-r--r--src/stdlib/arrays.c71
1 files changed, 54 insertions, 17 deletions
diff --git a/src/stdlib/arrays.c b/src/stdlib/arrays.c
index 6579536b..b018012f 100644
--- a/src/stdlib/arrays.c
+++ b/src/stdlib/arrays.c
@@ -10,7 +10,6 @@
#include "math.h"
#include "metamethods.h"
#include "optionals.h"
-#include "rng.h"
#include "tables.h"
#include "text.h"
#include "util.h"
@@ -271,33 +270,63 @@ public Array_t Array$sorted(Array_t arr, Closure_t comparison, int64_t padded_it
return arr;
}
-public void Array$shuffle(Array_t *arr, RNG_t rng, int64_t padded_item_size)
+#if defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) || defined(__APPLE__)
+static ssize_t getrandom(void *buf, size_t buflen, unsigned int flags) {
+ (void)flags;
+ arc4random_buf(buf, buflen);
+ return buflen;
+}
+#elif defined(__linux__)
+// Use getrandom()
+# include <sys/random.h>
+#else
+ #error "Unsupported platform for secure random number generation"
+#endif
+
+static int64_t _default_random_int64(int64_t min, int64_t max, void *userdata)
+{
+ (void)userdata;
+ if (min > max) fail("Random minimum value (", min, ") is larger than the maximum value (", max, ")");
+ if (min == max) return min;
+ uint64_t range = (uint64_t)max - (uint64_t)min + 1;
+ uint64_t min_r = -range % range;
+ uint64_t r;
+ for (;;) {
+ getrandom(&r, sizeof(r), 0);
+ if (r >= min_r) break;
+ }
+ return (int64_t)((uint64_t)min + (r % range));
+}
+
+public void Array$shuffle(Array_t *arr, OptionalClosure_t random_int64, int64_t padded_item_size)
{
if (arr->data_refcount != 0 || (int64_t)arr->stride != padded_item_size)
Array$compact(arr, padded_item_size);
+ int64_t (*rng_fn)(int64_t, int64_t, void*) = random_int64.fn ? random_int64.fn : _default_random_int64;
char tmp[padded_item_size];
for (int64_t i = arr->length-1; i > 1; i--) {
- int64_t j = RNG$int64(rng, 0, i);
+ int64_t j = rng_fn(0, i, random_int64.userdata);
memcpy(tmp, arr->data + i*padded_item_size, (size_t)padded_item_size);
memcpy((void*)arr->data + i*padded_item_size, arr->data + j*padded_item_size, (size_t)padded_item_size);
memcpy((void*)arr->data + j*padded_item_size, tmp, (size_t)padded_item_size);
}
}
-public Array_t Array$shuffled(Array_t arr, RNG_t rng, int64_t padded_item_size)
+public Array_t Array$shuffled(Array_t arr, Closure_t random_int64, int64_t padded_item_size)
{
Array$compact(&arr, padded_item_size);
- Array$shuffle(&arr, rng, padded_item_size);
+ Array$shuffle(&arr, random_int64, padded_item_size);
return arr;
}
-public void *Array$random(Array_t arr, RNG_t rng)
+public void *Array$random(Array_t arr, OptionalClosure_t random_int64)
{
if (arr.length == 0)
return NULL; // fail("Cannot get a random item from an empty array!");
- int64_t index = RNG$int64(rng, 0, arr.length-1);
+ int64_t (*rng_fn)(int64_t, int64_t, void*) = random_int64.fn;
+ int64_t index = rng_fn(0, arr.length-1, random_int64.userdata);
return arr.data + arr.stride*index;
}
@@ -314,7 +343,22 @@ public Table_t Array$counts(Array_t arr, const TypeInfo_t *type)
return counts;
}
-public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, RNG_t rng, int64_t padded_item_size)
+static double _default_random_num(void *userdata)
+{
+ (void)userdata;
+ union {
+ Num_t num;
+ uint64_t bits;
+ } r = {.bits=0}, one = {.num=1.0};
+ getrandom((uint8_t*)&r, sizeof(r), 0);
+
+ // Set r.num to 1.<random-bits>
+ r.bits &= ~(0xFFFULL << 52);
+ r.bits |= (one.bits & (0xFFFULL << 52));
+ return r.num - 1.0;
+}
+
+public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, OptionalClosure_t random_num, int64_t padded_item_size)
{
int64_t n = Int64$from_int(int_n, false);
if (n < 0)
@@ -331,14 +375,6 @@ public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, RNG_t rng
.length=n,
.stride=padded_item_size, .atomic=arr.atomic};
- if (weights.length < 0) {
- for (int64_t i = 0; i < n; i++) {
- int64_t index = RNG$int64(rng, 0, arr.length-1);
- memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size);
- }
- return selected;
- }
-
if (weights.length != arr.length)
fail("Array has ", (int64_t)arr.length, " elements, but there are ", (int64_t)weights.length, " weights given");
@@ -396,8 +432,9 @@ public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, RNG_t rng
if (aliases[i].alias == -1)
aliases[i].alias = i;
+ double (*rng_fn)(void*) = random_num.fn ? random_num.fn : _default_random_num;
for (int64_t i = 0; i < n; i++) {
- double r = RNG$num(rng, 0, arr.length);
+ double r = (double)arr.length * rng_fn(random_num.userdata);
int64_t index = (int64_t)r;
if ((r - (double)index) > aliases[index].odds)
index = aliases[index].alias;