aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-10-02 14:42:51 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-10-02 14:42:51 -0400
commit54e336e30f2112dc2cb0c1b85dd153667680e550 (patch)
tree6d933b3479c08e7d2f0ec00a40badcfbdb28002e
parentc8c137639c99793a7a2136f2bdc8de903bf4b5ec (diff)
Update array:sample() to use optional weights and do more error checking
-rw-r--r--compile.c5
-rw-r--r--docs/arrays.md16
-rw-r--r--stdlib/arrays.c93
3 files changed, 67 insertions, 47 deletions
diff --git a/compile.c b/compile.c
index 5077d499..89a480e7 100644
--- a/compile.c
+++ b/compile.c
@@ -2620,8 +2620,9 @@ CORD compile(env_t *env, ast_t *ast)
} else if (streq(call->name, "sample")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
arg_t *arg_spec = new(arg_t, .name="count", .type=INT_TYPE,
- .next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)),
- .default_val=FakeAST(Array, .item_type=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num"))));
+ .next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)),
+ .default_val=FakeAST(Null, .type=new(type_ast_t, .tag=ArrayTypeAST,
+ .__data.ArrayTypeAST.item=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num")))));
return CORD_all("Array$sample(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
padded_item_size, ")");
} else if (streq(call->name, "shuffle")) {
diff --git a/docs/arrays.md b/docs/arrays.md
index bb5af7c1..edde43af 100644
--- a/docs/arrays.md
+++ b/docs/arrays.md
@@ -731,7 +731,7 @@ probabilities.
**Usage:**
```markdown
-sample(arr: [T], count: Int, weights: [Num]) -> [T]
+sample(arr: [T], count: Int, weights: [Num]? = ![Num]) -> [T]
```
**Parameters:**
@@ -740,10 +740,16 @@ sample(arr: [T], count: Int, weights: [Num]) -> [T]
- `count`: The number of elements to sample.
- `weights`: The probability weights for each element in the array. These
values do not need to add up to any particular number, they are relative
- weights. If no weights are provided, the default is equal probabilities.
- Negative, infinite, or NaN weights will cause a runtime error. If the number of
- weights given is less than the length of the array, elements from the rest of
- the array are considered to have zero weight.
+ weights. If no weights are given, elements will be sampled with uniform
+ probability.
+
+**Errors:**
+Errors will be raised if any of the following conditions occurs:
+- The given array has no elements and `count >= 1`.
+- `count < 0` (negative count).
+- The number of weights provided doesn't match the length of the array.
+- Any weight in the weights array is negative, infinite, or `NaN`.
+- The sum of the given weights is zero (zero probability for every element).
**Returns:**
A list of sampled elements from the array.
diff --git a/stdlib/arrays.c b/stdlib/arrays.c
index 7c4ae94e..274983b7 100644
--- a/stdlib/arrays.c
+++ b/stdlib/arrays.c
@@ -296,14 +296,31 @@ public Table_t Array$counts(Array_t arr, const TypeInfo_t *type)
public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t padded_item_size)
{
int64_t n = Int_to_Int64(int_n, false);
- if (arr.length == 0 || n <= 0)
+ if (n < 0)
+ fail("Cannot select a negative number of values");
+
+ if (n == 0)
return (Array_t){};
+ if (arr.length == 0)
+ fail("There are no elements in this array!");
+
Array_t selected = {
.data=arr.atomic ? GC_MALLOC_ATOMIC((size_t)(n * padded_item_size)) : GC_MALLOC((size_t)(n * padded_item_size)),
.length=n,
.stride=padded_item_size, .atomic=arr.atomic};
+ if (weights.length < 0) {
+ for (int64_t i = 0; i < n; i++) {
+ int64_t index = arc4random_uniform(arr.length);
+ memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size);
+ }
+ return selected;
+ }
+
+ if (weights.length != arr.length)
+ fail("Array has %ld elements, but there are %ld weights given", arr.length, weights.length);
+
double total = 0.0;
for (int64_t i = 0; i < weights.length && i < arr.length; i++) {
double weight = *(double*)(weights.data + weights.stride*i);
@@ -320,54 +337,50 @@ public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t p
if (isinf(total))
fail("Sample weights have overflowed to infinity");
- if (total == 0.0) {
- for (int64_t i = 0; i < n; i++) {
- int64_t index = arc4random_uniform(arr.length);
- memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size);
- }
- } else {
- double inverse_average = (double)arr.length / total;
+ if (total == 0.0)
+ fail("None of the given weights are nonzero");
- struct {
- int64_t alias;
- double odds;
- } aliases[arr.length] = {};
+ double inverse_average = (double)arr.length / total;
- for (int64_t i = 0; i < arr.length; i++) {
- double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
- aliases[i].odds = weight * inverse_average;
- aliases[i].alias = -1;
- }
+ struct {
+ int64_t alias;
+ double odds;
+ } aliases[arr.length] = {};
- int64_t small = 0;
- for (int64_t big = 0; big < arr.length; big++) {
- while (aliases[big].odds >= 1.0) {
- while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
- ++small;
+ for (int64_t i = 0; i < arr.length; i++) {
+ double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
+ aliases[i].odds = weight * inverse_average;
+ aliases[i].alias = -1;
+ }
- if (small >= arr.length) {
- aliases[big].odds = 1.0;
- aliases[big].alias = big;
- break;
- }
+ int64_t small = 0;
+ for (int64_t big = 0; big < arr.length; big++) {
+ while (aliases[big].odds >= 1.0) {
+ while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
+ ++small;
- aliases[small].alias = big;
- aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
+ if (small >= arr.length) {
+ aliases[big].odds = 1.0;
+ aliases[big].alias = big;
+ break;
}
- if (big < small) small = big;
+
+ aliases[small].alias = big;
+ aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
}
+ if (big < small) small = big;
+ }
- for (int64_t i = small; i < arr.length; i++)
- if (aliases[i].alias == -1)
- aliases[i].alias = i;
+ for (int64_t i = small; i < arr.length; i++)
+ if (aliases[i].alias == -1)
+ aliases[i].alias = i;
- for (int64_t i = 0; i < n; i++) {
- double r = drand48() * arr.length;
- int64_t index = (int64_t)r;
- if ((r - (double)index) > aliases[index].odds)
- index = aliases[index].alias;
- memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size);
- }
+ for (int64_t i = 0; i < n; i++) {
+ double r = drand48() * arr.length;
+ int64_t index = (int64_t)r;
+ if ((r - (double)index) > aliases[index].odds)
+ index = aliases[index].alias;
+ memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size);
}
return selected;
}