aboutsummaryrefslogtreecommitdiff
path: root/builtins
diff options
context:
space:
mode:
author Bruce Hill <bruce@bruce-hill.com> 2024-04-02 23:28:59 -0400
committer Bruce Hill <bruce@bruce-hill.com> 2024-04-02 23:28:59 -0400
commit fae2b2caa0ec311821a7a7ef1f98500cfd25bb9e (patch)
tree 9dcc467733b19705b51e7fe7b07a98800f1efe6d /builtins
parent 95100469b6c7f301bb14bcda5dbc16b93c9ce0dc (diff)
Add array:sample()
Diffstat (limited to 'builtins')
-rw-r--r-- builtins/array.c 80
-rw-r--r-- builtins/array.h 1
2 files changed, 81 insertions, 0 deletions
diff --git a/builtins/array.c b/builtins/array.c
index 4eea25fb..333a2c50 100644
--- a/builtins/array.c
+++ b/builtins/array.c
@@ -4,6 +4,7 @@
#include <err.h>
#include <gc.h>
#include <gc/cord.h>
+#include <math.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
@@ -181,6 +182,85 @@ public void *Array$random(array_t arr)
return arr.data + arr.stride*index;
}
+// Build a new array of `n` items picked at random, with replacement, from `arr`.
+//
+// `weights` is read as an array of doubles where weights[i] is the relative
+// likelihood of picking arr[i]:
+//   - items past the end of `weights` are treated as having weight 0;
+//   - weights past the end of `arr` are ignored;
+//   - if the weights that are used sum to zero (this includes an empty
+//     `weights` array), every item is equally likely.
+// Calls fail() if a used weight is infinite, NaN or negative, or if the sum of
+// the weights overflows to infinity.
+//
+// Returns an empty array when `arr` is empty or `n <= 0`. Otherwise returns a
+// freshly allocated, tightly packed array (stride == item size) of `n` items.
+public array_t Array$sample(array_t arr, int64_t n, array_t weights, const TypeInfo *type)
+{
+    if (arr.length == 0 || n <= 0)
+        return (array_t){};
+
+    int64_t item_size = get_item_size(type);
+    array_t selected = {
+        .data=arr.atomic ? GC_MALLOC_ATOMIC(n * item_size) : GC_MALLOC(n * item_size),
+        .length=n,
+        .stride=item_size, .atomic=arr.atomic};
+
+    // Validate the weights that will actually be used and add them up.
+    double total = 0.0;
+    for (int64_t i = 0; i < weights.length && i < arr.length; i++) {
+        double weight = *(double*)(weights.data + weights.stride*i);
+        if (isinf(weight))
+            fail("Infinite weight!");
+        else if (isnan(weight))
+            fail("NaN weight!");
+        else if (weight < 0.0)
+            fail("Negative weight!");
+        else
+            total += weight;
+    }
+
+    if (isinf(total))
+        fail("Sample weights have overflowed to infinity");
+
+    if (total == 0.0) {
+        // No usable weights: fall back to a uniform pick for every slot.
+        for (int64_t i = 0; i < n; i++) {
+            uint32_t index = arc4random_uniform(arr.length);
+            memcpy(selected.data + i*item_size, arr.data + arr.stride*index, item_size);
+        }
+    } else {
+        // Weighted sampling via an alias table: each slot i keeps item i with
+        // probability `odds` and otherwise redirects to item `alias`. Once the
+        // table is built, every sample costs O(1).
+
+        // Rescale so that the average of all the odds is 1.0.
+        double inverse_average = (double)arr.length / total;
+
+        // The table lives on the GC heap rather than in a variable-length
+        // stack array, so a large `arr` cannot overflow the stack. It holds no
+        // pointers, hence the atomic allocation; every entry is fully
+        // initialized by the loop below, so it does not need to be zeroed.
+        struct alias_entry {
+            int64_t alias;
+            double odds;
+        } *aliases = GC_MALLOC_ATOMIC(sizeof(struct alias_entry) * (size_t)arr.length);
+
+        for (int64_t i = 0; i < arr.length; i++) {
+            double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
+            aliases[i].odds = weight * inverse_average;
+            aliases[i].alias = -1; // -1 means "no alias assigned yet"
+        }
+
+        // Pair each under-full slot ("small", odds < 1) with an over-full one
+        // ("big", odds >= 1): the big item donates the probability mass the
+        // small slot is missing and becomes that slot's alias.
+        int64_t small = 0;
+        for (int64_t big = 0; big < arr.length; big++) {
+            while (aliases[big].odds >= 1.0) {
+                // Find the next slot that is under-full and not yet aliased.
+                while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
+                    ++small;
+
+                if (small >= arr.length) {
+                    // Nothing left to donate to (possible only through
+                    // rounding): this slot always yields its own item.
+                    aliases[big].odds = 1.0;
+                    aliases[big].alias = big;
+                    break;
+                }
+
+                aliases[small].alias = big;
+                aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
+            }
+            // `big` may itself be under-full now; if the scan has already moved
+            // past it, rewind so that a later donor can fill it.
+            if (big < small) small = big;
+        }
+
+        // Any slot still without an alias (rounding leftovers) points at itself.
+        for (int64_t i = small; i < arr.length; i++)
+            if (aliases[i].alias == -1)
+                aliases[i].alias = i;
+
+        // NOTE(review): this path draws from drand48() while the unweighted
+        // path uses arc4random_uniform(); whether drand48 is seeded anywhere is
+        // not visible from this function -- confirm that this is intentional.
+        for (int64_t i = 0; i < n; i++) {
+            double r = drand48() * arr.length;
+            int64_t index = (int64_t)r;
+            // `>=` rather than `>`: when the fractional part is exactly 0.0, a
+            // slot with odds 0.0 (a zero-weight item) must still redirect to
+            // its alias instead of being selected.
+            if ((r - (double)index) >= aliases[index].odds)
+                index = aliases[index].alias;
+            memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, item_size);
+        }
+    }
+    return selected;
+}
+
public array_t Array$slice(array_t *array, int64_t first, int64_t length, int64_t stride, const TypeInfo *type)
{
if (stride > MAX_STRIDE || stride < MIN_STRIDE)
diff --git a/builtins/array.h b/builtins/array.h
index c6722e6d..18757850 100644
--- a/builtins/array.h
+++ b/builtins/array.h
@@ -60,6 +60,7 @@ void Array$sort(array_t *arr, closure_t comparison, const TypeInfo *type);
array_t Array$sorted(array_t arr, closure_t comparison, const TypeInfo *type);
void Array$shuffle(array_t *arr, const TypeInfo *type);
void *Array$random(array_t arr);
+array_t Array$sample(array_t arr, int64_t n, array_t weights, const TypeInfo *type); // New array of n items drawn from arr with replacement; weights are per-item doubles (uniform when they sum to zero)
void Array$clear(array_t *array);
void Array$compact(array_t *arr, const TypeInfo *type);
bool Array$contains(array_t array, void *item, const TypeInfo *type);