Update array:sample() to use optional weights and do more error checking
This commit is contained in:
parent
c8c137639c
commit
54e336e30f
@ -2620,8 +2620,9 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
} else if (streq(call->name, "sample")) {
|
||||
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
|
||||
arg_t *arg_spec = new(arg_t, .name="count", .type=INT_TYPE,
|
||||
.next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)),
|
||||
.default_val=FakeAST(Array, .item_type=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num"))));
|
||||
.next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)),
|
||||
.default_val=FakeAST(Null, .type=new(type_ast_t, .tag=ArrayTypeAST,
|
||||
.__data.ArrayTypeAST.item=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num")))));
|
||||
return CORD_all("Array$sample(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
|
||||
padded_item_size, ")");
|
||||
} else if (streq(call->name, "shuffle")) {
|
||||
|
@ -731,7 +731,7 @@ probabilities.
|
||||
|
||||
**Usage:**
|
||||
```markdown
|
||||
sample(arr: [T], count: Int, weights: [Num]) -> [T]
|
||||
sample(arr: [T], count: Int, weights: [Num]? = ![Num]) -> [T]
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
@ -740,10 +740,16 @@ sample(arr: [T], count: Int, weights: [Num]) -> [T]
|
||||
- `count`: The number of elements to sample.
|
||||
- `weights`: The probability weights for each element in the array. These
|
||||
values do not need to add up to any particular number, they are relative
|
||||
weights. If no weights are provided, the default is equal probabilities.
|
||||
Negative, infinite, or NaN weights will cause a runtime error. If the number of
|
||||
weights given is less than the length of the array, elements from the rest of
|
||||
the array are considered to have zero weight.
|
||||
weights. If no weights are given, elements will be sampled with uniform
|
||||
probability.
|
||||
|
||||
**Errors:**
|
||||
Errors will be raised if any of the following conditions occurs:
|
||||
- The given array has no elements and `count >= 1`
|
||||
- `count < 0` (negative count)
|
||||
- The number of weights provided doesn't match the length of the array.
|
||||
- Any weight in the weights array is negative, infinite, or `NaN`
|
||||
- The sum of the given weights is zero (zero probability for every element).
|
||||
|
||||
**Returns:**
|
||||
A list of sampled elements from the array.
|
||||
|
@ -296,14 +296,31 @@ public Table_t Array$counts(Array_t arr, const TypeInfo_t *type)
|
||||
public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t padded_item_size)
|
||||
{
|
||||
int64_t n = Int_to_Int64(int_n, false);
|
||||
if (arr.length == 0 || n <= 0)
|
||||
if (n < 0)
|
||||
fail("Cannot select a negative number of values");
|
||||
|
||||
if (n == 0)
|
||||
return (Array_t){};
|
||||
|
||||
if (arr.length == 0)
|
||||
fail("There are no elements in this array!");
|
||||
|
||||
Array_t selected = {
|
||||
.data=arr.atomic ? GC_MALLOC_ATOMIC((size_t)(n * padded_item_size)) : GC_MALLOC((size_t)(n * padded_item_size)),
|
||||
.length=n,
|
||||
.stride=padded_item_size, .atomic=arr.atomic};
|
||||
|
||||
if (weights.length < 0) {
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
int64_t index = arc4random_uniform(arr.length);
|
||||
memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size);
|
||||
}
|
||||
return selected;
|
||||
}
|
||||
|
||||
if (weights.length != arr.length)
|
||||
fail("Array has %ld elements, but there are %ld weights given", arr.length, weights.length);
|
||||
|
||||
double total = 0.0;
|
||||
for (int64_t i = 0; i < weights.length && i < arr.length; i++) {
|
||||
double weight = *(double*)(weights.data + weights.stride*i);
|
||||
@ -320,54 +337,50 @@ public Array_t Array$sample(Array_t arr, Int_t int_n, Array_t weights, int64_t p
|
||||
if (isinf(total))
|
||||
fail("Sample weights have overflowed to infinity");
|
||||
|
||||
if (total == 0.0) {
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
int64_t index = arc4random_uniform(arr.length);
|
||||
memcpy(selected.data + i*padded_item_size, arr.data + arr.stride*index, (size_t)padded_item_size);
|
||||
}
|
||||
} else {
|
||||
double inverse_average = (double)arr.length / total;
|
||||
if (total == 0.0)
|
||||
fail("None of the given weights are nonzero");
|
||||
|
||||
struct {
|
||||
int64_t alias;
|
||||
double odds;
|
||||
} aliases[arr.length] = {};
|
||||
double inverse_average = (double)arr.length / total;
|
||||
|
||||
for (int64_t i = 0; i < arr.length; i++) {
|
||||
double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
|
||||
aliases[i].odds = weight * inverse_average;
|
||||
aliases[i].alias = -1;
|
||||
}
|
||||
struct {
|
||||
int64_t alias;
|
||||
double odds;
|
||||
} aliases[arr.length] = {};
|
||||
|
||||
int64_t small = 0;
|
||||
for (int64_t big = 0; big < arr.length; big++) {
|
||||
while (aliases[big].odds >= 1.0) {
|
||||
while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
|
||||
++small;
|
||||
for (int64_t i = 0; i < arr.length; i++) {
|
||||
double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
|
||||
aliases[i].odds = weight * inverse_average;
|
||||
aliases[i].alias = -1;
|
||||
}
|
||||
|
||||
if (small >= arr.length) {
|
||||
aliases[big].odds = 1.0;
|
||||
aliases[big].alias = big;
|
||||
break;
|
||||
}
|
||||
int64_t small = 0;
|
||||
for (int64_t big = 0; big < arr.length; big++) {
|
||||
while (aliases[big].odds >= 1.0) {
|
||||
while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
|
||||
++small;
|
||||
|
||||
aliases[small].alias = big;
|
||||
aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
|
||||
if (small >= arr.length) {
|
||||
aliases[big].odds = 1.0;
|
||||
aliases[big].alias = big;
|
||||
break;
|
||||
}
|
||||
if (big < small) small = big;
|
||||
}
|
||||
|
||||
for (int64_t i = small; i < arr.length; i++)
|
||||
if (aliases[i].alias == -1)
|
||||
aliases[i].alias = i;
|
||||
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
double r = drand48() * arr.length;
|
||||
int64_t index = (int64_t)r;
|
||||
if ((r - (double)index) > aliases[index].odds)
|
||||
index = aliases[index].alias;
|
||||
memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size);
|
||||
aliases[small].alias = big;
|
||||
aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
|
||||
}
|
||||
if (big < small) small = big;
|
||||
}
|
||||
|
||||
for (int64_t i = small; i < arr.length; i++)
|
||||
if (aliases[i].alias == -1)
|
||||
aliases[i].alias = i;
|
||||
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
double r = drand48() * arr.length;
|
||||
int64_t index = (int64_t)r;
|
||||
if ((r - (double)index) > aliases[index].odds)
|
||||
index = aliases[index].alias;
|
||||
memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, (size_t)padded_item_size);
|
||||
}
|
||||
return selected;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user