Update array:sample() to use optional weights and do more error checking

This commit is contained in:
Bruce Hill 2024-10-02 14:42:51 -04:00
parent c8c137639c
commit 54e336e30f
3 changed files with 68 additions and 48 deletions

View File

@ -2620,8 +2620,9 @@ CORD compile(env_t *env, ast_t *ast)
} else if (streq(call->name, "sample")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
arg_t *arg_spec = new(arg_t, .name="count", .type=INT_TYPE,
.next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)),
.default_val=FakeAST(Array, .item_type=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num"))));
.next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)),
.default_val=FakeAST(Null, .type=new(type_ast_t, .tag=ArrayTypeAST,
.__data.ArrayTypeAST.item=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num")))));
return CORD_all("Array$sample(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
padded_item_size, ")");
} else if (streq(call->name, "shuffle")) {

View File

@ -731,7 +731,7 @@ probabilities.
**Usage:**
```markdown
sample(arr: [T], count: Int, weights: [Num]) -> [T]
sample(arr: [T], count: Int, weights: [Num]? = ![Num]) -> [T]
```
**Parameters:**
@ -740,10 +740,16 @@ sample(arr: [T], count: Int, weights: [Num]) -> [T]
- `count`: The number of elements to sample.
- `weights`: The probability weights for each element in the array. These
values do not need to add up to any particular number, they are relative
weights. If no weights are provided, the default is equal probabilities.
Negative, infinite, or NaN weights will cause a runtime error. If the number of
weights given is less than the length of the array, elements from the rest of
the array are considered to have zero weight.
weights. If no weights are given, elements will be sampled with uniform
probability.
**Errors:**
Errors will be raised if any of the following conditions occurs:
- The given array has no elements and `count >= 1`
- `count < 0` (negative count)
- The number of weights provided doesn't match the length of the array.
- Any weight in the weights array is negative, infinite, or `NaN`
- The sum of the given weights is zero (zero probability for every element).
**Returns:**
A list of sampled elements from the array.

View File

@ -296,14 +296,31 @@ public Table_t Array$counts(Array_t arr, const TypeInfo_t *type)
public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t padded_item_size)
{
int64_t n = Int_to_Int64(int_n, false);
if (arr.length == 0 || n <= 0)
if (n < 0)
fail("Cannot select a negative number of values");
if (n == 0)
return (Array_t){};
if (arr.length == 0)
fail("There are no elements in this array!");
Array_t selected = {
.data=arr.atomic ? GC_MALLOC_ATOMIC((size_t)(n * padded_item_size)) : GC_MALLOC((size_t)(n * padded_item_size)),
.length=n,
.stride=padded_item_size, .atomic=arr.atomic};
if (weights.length < 0) {
for (int64_t i = 0; i < n; i++) {
int64_t index = arc4random_uniform(arr.length);
memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size);
}
return selected;
}
if (weights.length != arr.length)
fail("Array has %ld elements, but there are %ld weights given", arr.length, weights.length);
double total = 0.0;
for (int64_t i = 0; i < weights.length && i < arr.length; i++) {
double weight = *(double*)(weights.data + weights.stride*i);
@ -320,54 +337,50 @@ public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t p
if (isinf(total))
fail("Sample weights have overflowed to infinity");
if (total == 0.0) {
for (int64_t i = 0; i < n; i++) {
int64_t index = arc4random_uniform(arr.length);
memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size);
}
} else {
double inverse_average = (double)arr.length / total;
if (total == 0.0)
fail("None of the given weights are nonzero");
struct {
int64_t alias;
double odds;
} aliases[arr.length] = {};
double inverse_average = (double)arr.length / total;
for (int64_t i = 0; i < arr.length; i++) {
double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
aliases[i].odds = weight * inverse_average;
aliases[i].alias = -1;
}
struct {
int64_t alias;
double odds;
} aliases[arr.length] = {};
int64_t small = 0;
for (int64_t big = 0; big < arr.length; big++) {
while (aliases[big].odds >= 1.0) {
while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
++small;
for (int64_t i = 0; i < arr.length; i++) {
double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
aliases[i].odds = weight * inverse_average;
aliases[i].alias = -1;
}
if (small >= arr.length) {
aliases[big].odds = 1.0;
aliases[big].alias = big;
break;
}
int64_t small = 0;
for (int64_t big = 0; big < arr.length; big++) {
while (aliases[big].odds >= 1.0) {
while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
++small;
aliases[small].alias = big;
aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
if (small >= arr.length) {
aliases[big].odds = 1.0;
aliases[big].alias = big;
break;
}
if (big < small) small = big;
}
for (int64_t i = small; i < arr.length; i++)
if (aliases[i].alias == -1)
aliases[i].alias = i;
for (int64_t i = 0; i < n; i++) {
double r = drand48() * arr.length;
int64_t index = (int64_t)r;
if ((r - (double)index) > aliases[index].odds)
index = aliases[index].alias;
memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size);
aliases[small].alias = big;
aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
}
if (big < small) small = big;
}
for (int64_t i = small; i < arr.length; i++)
if (aliases[i].alias == -1)
aliases[i].alias = i;
for (int64_t i = 0; i < n; i++) {
double r = drand48() * arr.length;
int64_t index = (int64_t)r;
if ((r - (double)index) > aliases[index].odds)
index = aliases[index].alias;
memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size);
}
return selected;
}