diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-04-02 23:28:59 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-04-02 23:28:59 -0400 |
| commit | fae2b2caa0ec311821a7a7ef1f98500cfd25bb9e (patch) | |
| tree | 9dcc467733b19705b51e7fe7b07a98800f1efe6d /builtins/array.c | |
| parent | 95100469b6c7f301bb14bcda5dbc16b93c9ce0dc (diff) | |
Add array:sample()
Diffstat (limited to 'builtins/array.c')
| -rw-r--r-- | builtins/array.c | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/builtins/array.c b/builtins/array.c index 4eea25fb..333a2c50 100644 --- a/builtins/array.c +++ b/builtins/array.c @@ -4,6 +4,7 @@ #include <err.h> #include <gc.h> #include <gc/cord.h> +#include <math.h> #include <stdbool.h> #include <stdint.h> #include <stdlib.h> @@ -181,6 +182,85 @@ public void *Array$random(array_t arr) return arr.data + arr.stride*index; } +public array_t Array$sample(array_t arr, int64_t n, array_t weights, const TypeInfo *type) +{ + if (arr.length == 0 || n <= 0) + return (array_t){}; + + int64_t item_size = get_item_size(type); + array_t selected = { + .data=arr.atomic ? GC_MALLOC_ATOMIC(n * item_size) : GC_MALLOC(n * item_size), + .length=n, + .stride=item_size, .atomic=arr.atomic}; + + double total = 0.0; + for (int64_t i = 0; i < weights.length && i < arr.length; i++) { + double weight = *(double*)(weights.data + weights.stride*i); + if (isinf(weight)) + fail("Infinite weight!"); + else if (isnan(weight)) + fail("NaN weight!"); + else if (weight < 0.0) + fail("Negative weight!"); + else + total += weight; + } + + if (isinf(total)) + fail("Sample weights have overflowed to infinity"); + + if (total == 0.0) { + for (int64_t i = 0; i < n; i++) { + uint32_t index = arc4random_uniform(arr.length); + memcpy(selected.data + i*item_size, arr.data + arr.stride*index, item_size); + } + } else { + double inverse_average = (double)arr.length / total; + + struct { + int64_t alias; + double odds; + } aliases[arr.length] = {}; + + for (int64_t i = 0; i < arr.length; i++) { + double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i); + aliases[i].odds = weight * inverse_average; + aliases[i].alias = -1; + } + + int64_t small = 0; + for (int64_t big = 0; big < arr.length; big++) { + while (aliases[big].odds >= 1.0) { + while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1)) + ++small; + + if (small >= arr.length) { + aliases[big].odds = 1.0; + aliases[big].alias = big; + break; + } + + aliases[small].alias = big; + aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0; + } + if (big < small) small = big; + } + + for (int64_t i = small; i < arr.length; i++) + if (aliases[i].alias == -1) + aliases[i].alias = i; + + for (int64_t i = 0; i < n; i++) { + double r = drand48() * arr.length; + int64_t index = (int64_t)r; + if ((r - (double)index) > aliases[index].odds) + index = aliases[index].alias; + memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, item_size); + } + } + return selected; +} + public array_t Array$slice(array_t *array, int64_t first, int64_t length, int64_t stride, const TypeInfo *type) { if (stride > MAX_STRIDE || stride < MIN_STRIDE) |
