diff options
Diffstat (limited to 'stdlib')
| -rw-r--r-- | stdlib/arrays.c | 93 |
1 file changed, 53 insertions, 40 deletions
diff --git a/stdlib/arrays.c b/stdlib/arrays.c index 7c4ae94e..274983b7 100644 --- a/stdlib/arrays.c +++ b/stdlib/arrays.c @@ -296,14 +296,31 @@ public Table_t Array$counts(Array_t arr, const TypeInfo_t *type) public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t padded_item_size) { int64_t n = Int_to_Int64(int_n, false); - if (arr.length == 0 || n <= 0) + if (n < 0) + fail("Cannot select a negative number of values"); + + if (n == 0) return (Array_t){}; + if (arr.length == 0) + fail("There are no elements in this array!"); + Array_t selected = { .data=arr.atomic ? GC_MALLOC_ATOMIC((size_t)(n * padded_item_size)) : GC_MALLOC((size_t)(n * padded_item_size)), .length=n, .stride=padded_item_size, .atomic=arr.atomic}; + if (weights.length < 0) { + for (int64_t i = 0; i < n; i++) { + int64_t index = arc4random_uniform(arr.length); + memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size); + } + return selected; + } + + if (weights.length != arr.length) + fail("Array has %ld elements, but there are %ld weights given", arr.length, weights.length); + double total = 0.0; for (int64_t i = 0; i < weights.length && i < arr.length; i++) { double weight = *(double*)(weights.data + weights.stride*i); @@ -320,54 +337,50 @@ public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t p if (isinf(total)) fail("Sample weights have overflowed to infinity"); - if (total == 0.0) { - for (int64_t i = 0; i < n; i++) { - int64_t index = arc4random_uniform(arr.length); - memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size); - } - } else { - double inverse_average = (double)arr.length / total; + if (total == 0.0) + fail("None of the given weights are nonzero"); - struct { - int64_t alias; - double odds; - } aliases[arr.length] = {}; + double inverse_average = (double)arr.length / total; - for (int64_t i = 0; i < arr.length; i++) { - double weight = i >= 
weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i); - aliases[i].odds = weight * inverse_average; - aliases[i].alias = -1; - } + struct { + int64_t alias; + double odds; + } aliases[arr.length] = {}; - int64_t small = 0; - for (int64_t big = 0; big < arr.length; big++) { - while (aliases[big].odds >= 1.0) { - while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1)) - ++small; + for (int64_t i = 0; i < arr.length; i++) { + double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i); + aliases[i].odds = weight * inverse_average; + aliases[i].alias = -1; + } - if (small >= arr.length) { - aliases[big].odds = 1.0; - aliases[big].alias = big; - break; - } + int64_t small = 0; + for (int64_t big = 0; big < arr.length; big++) { + while (aliases[big].odds >= 1.0) { + while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1)) + ++small; - aliases[small].alias = big; - aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0; + if (small >= arr.length) { + aliases[big].odds = 1.0; + aliases[big].alias = big; + break; } - if (big < small) small = big; + + aliases[small].alias = big; + aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0; } + if (big < small) small = big; + } - for (int64_t i = small; i < arr.length; i++) - if (aliases[i].alias == -1) - aliases[i].alias = i; + for (int64_t i = small; i < arr.length; i++) + if (aliases[i].alias == -1) + aliases[i].alias = i; - for (int64_t i = 0; i < n; i++) { - double r = drand48() * arr.length; - int64_t index = (int64_t)r; - if ((r - (double)index) > aliases[index].odds) - index = aliases[index].alias; - memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size); - } + for (int64_t i = 0; i < n; i++) { + double r = drand48() * arr.length; + int64_t index = (int64_t)r; + if ((r - (double)index) > aliases[index].odds) + index = aliases[index].alias; + 
memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size); } return selected; } |
