about | summary | refs | log | tree | commit | diff
path: root/stdlib/arrays.c
diff options
context:
space:
mode:
author Bruce Hill <bruce@bruce-hill.com> 2024-10-02 14:42:51 -0400
committer Bruce Hill <bruce@bruce-hill.com> 2024-10-02 14:42:51 -0400
commit 54e336e30f2112dc2cb0c1b85dd153667680e550 (patch)
tree 6d933b3479c08e7d2f0ec00a40badcfbdb28002e /stdlib/arrays.c
parent c8c137639c99793a7a2136f2bdc8de903bf4b5ec (diff)
Update array:sample() to use optional weights and do more error checking
Diffstat (limited to 'stdlib/arrays.c')
-rw-r--r-- stdlib/arrays.c 93
1 files changed, 53 insertions, 40 deletions
diff --git a/stdlib/arrays.c b/stdlib/arrays.c
index 7c4ae94e..274983b7 100644
--- a/stdlib/arrays.c
+++ b/stdlib/arrays.c
@@ -296,14 +296,31 @@ public Table_t Array$counts(Array_t arr, const TypeInfo_t *type)
public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t padded_item_size)
{
int64_t n = Int_to_Int64(int_n, false);
- if (arr.length == 0 || n <= 0)
+ if (n < 0)
+ fail("Cannot select a negative number of values");
+
+ if (n == 0)
return (Array_t){};
+ if (arr.length == 0)
+ fail("There are no elements in this array!");
+
Array_t selected = {
.data=arr.atomic ? GC_MALLOC_ATOMIC((size_t)(n * padded_item_size)) : GC_MALLOC((size_t)(n * padded_item_size)),
.length=n,
.stride=padded_item_size, .atomic=arr.atomic};
+ if (weights.length < 0) {
+ for (int64_t i = 0; i < n; i++) {
+ int64_t index = arc4random_uniform(arr.length);
+ memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size);
+ }
+ return selected;
+ }
+
+ if (weights.length != arr.length)
+ fail("Array has %ld elements, but there are %ld weights given", arr.length, weights.length);
+
double total = 0.0;
for (int64_t i = 0; i < weights.length && i < arr.length; i++) {
double weight = *(double*)(weights.data + weights.stride*i);
@@ -320,54 +337,50 @@ public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t p
if (isinf(total))
fail("Sample weights have overflowed to infinity");
- if (total == 0.0) {
- for (int64_t i = 0; i < n; i++) {
- int64_t index = arc4random_uniform(arr.length);
- memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size);
- }
- } else {
- double inverse_average = (double)arr.length / total;
+ if (total == 0.0)
+ fail("None of the given weights are nonzero");
- struct {
- int64_t alias;
- double odds;
- } aliases[arr.length] = {};
+ double inverse_average = (double)arr.length / total;
- for (int64_t i = 0; i < arr.length; i++) {
- double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
- aliases[i].odds = weight * inverse_average;
- aliases[i].alias = -1;
- }
+ struct {
+ int64_t alias;
+ double odds;
+ } aliases[arr.length] = {};
- int64_t small = 0;
- for (int64_t big = 0; big < arr.length; big++) {
- while (aliases[big].odds >= 1.0) {
- while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
- ++small;
+ for (int64_t i = 0; i < arr.length; i++) {
+ double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
+ aliases[i].odds = weight * inverse_average;
+ aliases[i].alias = -1;
+ }
- if (small >= arr.length) {
- aliases[big].odds = 1.0;
- aliases[big].alias = big;
- break;
- }
+ int64_t small = 0;
+ for (int64_t big = 0; big < arr.length; big++) {
+ while (aliases[big].odds >= 1.0) {
+ while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
+ ++small;
- aliases[small].alias = big;
- aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
+ if (small >= arr.length) {
+ aliases[big].odds = 1.0;
+ aliases[big].alias = big;
+ break;
}
- if (big < small) small = big;
+
+ aliases[small].alias = big;
+ aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
}
+ if (big < small) small = big;
+ }
- for (int64_t i = small; i < arr.length; i++)
- if (aliases[i].alias == -1)
- aliases[i].alias = i;
+ for (int64_t i = small; i < arr.length; i++)
+ if (aliases[i].alias == -1)
+ aliases[i].alias = i;
- for (int64_t i = 0; i < n; i++) {
- double r = drand48() * arr.length;
- int64_t index = (int64_t)r;
- if ((r - (double)index) > aliases[index].odds)
- index = aliases[index].alias;
- memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size);
- }
+ for (int64_t i = 0; i < n; i++) {
+ double r = drand48() * arr.length;
+ int64_t index = (int64_t)r;
+ if ((r - (double)index) > aliases[index].odds)
+ index = aliases[index].alias;
+ memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size);
}
return selected;
}