From 54e336e30f2112dc2cb0c1b85dd153667680e550 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Wed, 2 Oct 2024 14:42:51 -0400 Subject: [PATCH] Update array:sample() to use optional weights and do more error checking --- compile.c | 5 +-- docs/arrays.md | 16 ++++++--- stdlib/arrays.c | 95 ++++++++++++++++++++++++++++--------------------- 3 files changed, 68 insertions(+), 48 deletions(-) diff --git a/compile.c b/compile.c index 5077d49..89a480e 100644 --- a/compile.c +++ b/compile.c @@ -2620,8 +2620,9 @@ CORD compile(env_t *env, ast_t *ast) } else if (streq(call->name, "sample")) { CORD self = compile_to_pointer_depth(env, call->self, 0, false); arg_t *arg_spec = new(arg_t, .name="count", .type=INT_TYPE, - .next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)), - .default_val=FakeAST(Array, .item_type=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num")))); + .next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)), + .default_val=FakeAST(Null, .type=new(type_ast_t, .tag=ArrayTypeAST, + .__data.ArrayTypeAST.item=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num"))))); return CORD_all("Array$sample(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", padded_item_size, ")"); } else if (streq(call->name, "shuffle")) { diff --git a/docs/arrays.md b/docs/arrays.md index bb5af7c..edde43a 100644 --- a/docs/arrays.md +++ b/docs/arrays.md @@ -731,7 +731,7 @@ probabilities. **Usage:** ```markdown -sample(arr: [T], count: Int, weights: [Num]) -> [T] +sample(arr: [T], count: Int, weights: [Num]? = ![Num]) -> [T] ``` **Parameters:** @@ -740,10 +740,16 @@ sample(arr: [T], count: Int, weights: [Num]) -> [T] - `count`: The number of elements to sample. - `weights`: The probability weights for each element in the array. These values do not need to add up to any particular number, they are relative - weights. If no weights are provided, the default is equal probabilities. - Negative, infinite, or NaN weights will cause a runtime error. If the number of - weights given is less than the length of the array, elements from the rest of - the array are considered to have zero weight. + weights. If no weights are given, elements will be sampled with uniform + probability. + +**Errors:** +Errors will be raised if any of the following conditions occurs: +- The given array has no elements and `count >= 1` +- `count < 0` (negative count) +- The number of weights provided doesn't match the length of the array. +- Any weight in the weights array is negative, infinite, or `NaN` +- The sum of the given weights is zero (zero probability for every element). **Returns:** A list of sampled elements from the array. diff --git a/stdlib/arrays.c b/stdlib/arrays.c index 7c4ae94..274983b 100644 --- a/stdlib/arrays.c +++ b/stdlib/arrays.c @@ -296,14 +296,31 @@ public Table_t Array$counts(Array_t arr, const TypeInfo_t *type) public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t padded_item_size) { int64_t n = Int_to_Int64(int_n, false); - if (arr.length == 0 || n <= 0) + if (n < 0) + fail("Cannot select a negative number of values"); + + if (n == 0) return (Array_t){}; + if (arr.length == 0) + fail("There are no elements in this array!"); + Array_t selected = { .data=arr.atomic ? GC_MALLOC_ATOMIC((size_t)(n * padded_item_size)) : GC_MALLOC((size_t)(n * padded_item_size)), .length=n, .stride=padded_item_size, .atomic=arr.atomic}; + if (weights.length < 0) { + for (int64_t i = 0; i < n; i++) { + int64_t index = arc4random_uniform(arr.length); + memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size); + } + return selected; + } + + if (weights.length != arr.length) + fail("Array has %ld elements, but there are %ld weights given", arr.length, weights.length); + double total = 0.0; for (int64_t i = 0; i < weights.length && i < arr.length; i++) { double weight = *(double*)(weights.data + weights.stride*i); @@ -320,54 +337,50 @@ public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t p if (isinf(total)) fail("Sample weights have overflowed to infinity"); - if (total == 0.0) { - for (int64_t i = 0; i < n; i++) { - int64_t index = arc4random_uniform(arr.length); - memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size); - } - } else { - double inverse_average = (double)arr.length / total; + if (total == 0.0) + fail("None of the given weights are nonzero"); - struct { - int64_t alias; - double odds; - } aliases[arr.length] = {}; + double inverse_average = (double)arr.length / total; - for (int64_t i = 0; i < arr.length; i++) { - double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i); - aliases[i].odds = weight * inverse_average; - aliases[i].alias = -1; - } + struct { + int64_t alias; + double odds; + } aliases[arr.length] = {}; - int64_t small = 0; - for (int64_t big = 0; big < arr.length; big++) { - while (aliases[big].odds >= 1.0) { - while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1)) - ++small; + for (int64_t i = 0; i < arr.length; i++) { + double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i); + aliases[i].odds = weight * inverse_average; + aliases[i].alias = -1; + } - if (small >= arr.length) { - aliases[big].odds = 1.0; - aliases[big].alias = big; - break; - } + int64_t small = 0; + for (int64_t big = 0; big < arr.length; big++) { + while (aliases[big].odds >= 1.0) { + while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1)) + ++small; - aliases[small].alias = big; - aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0; + if (small >= arr.length) { + aliases[big].odds = 1.0; + aliases[big].alias = big; + break; } - if (big < small) small = big; - } - for (int64_t i = small; i < arr.length; i++) - if (aliases[i].alias == -1) - aliases[i].alias = i; - - for (int64_t i = 0; i < n; i++) { - double r = drand48() * arr.length; - int64_t index = (int64_t)r; - if ((r - (double)index) > aliases[index].odds) - index = aliases[index].alias; - memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size); + aliases[small].alias = big; + aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0; } + if (big < small) small = big; + } + + for (int64_t i = small; i < arr.length; i++) + if (aliases[i].alias == -1) + aliases[i].alias = i; + + for (int64_t i = 0; i < n; i++) { + double r = drand48() * arr.length; + int64_t index = (int64_t)r; + if ((r - (double)index) > aliases[index].odds) + index = aliases[index].alias; + memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size); } return selected; }