Add array:sample()
This commit is contained in:
parent
95100469b6
commit
fae2b2caa0
@ -4,6 +4,7 @@
|
||||
#include <err.h>
|
||||
#include <gc.h>
|
||||
#include <gc/cord.h>
|
||||
#include <math.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
@ -181,6 +182,85 @@ public void *Array$random(array_t arr)
|
||||
return arr.data + arr.stride*index;
|
||||
}
|
||||
|
||||
// Draw `n` items from `arr` at random, with replacement, and return them as a
// newly allocated array.
//
// Parameters:
//   arr     - the array to draw items from.
//   n       - how many items to draw.
//   weights - array whose items are read as `double` weights, matched to `arr`
//             by index. Items of `arr` beyond the end of `weights` get a weight
//             of 0.0; weights beyond the end of `arr` are ignored.
//   type    - type info passed to get_item_size() to find the item size.
//
// Returns an empty array if `arr` is empty or `n <= 0`. If the weights sum to
// zero (including when `weights` is empty), items are drawn uniformly.
// Calls fail() if a weight is infinite, NaN, or negative, or if the sum of the
// weights overflows to infinity.
public array_t Array$sample(array_t arr, int64_t n, array_t weights, const TypeInfo *type)
{
    if (arr.length == 0 || n <= 0)
        return (array_t){};

    int64_t item_size = get_item_size(type);
    array_t selected = {
        .data=arr.atomic ? GC_MALLOC_ATOMIC(n * item_size) : GC_MALLOC(n * item_size),
        .length=n,
        .stride=item_size, .atomic=arr.atomic};

    // Validate the weights that line up with an item and add them together.
    double total = 0.0;
    for (int64_t i = 0; i < weights.length && i < arr.length; i++) {
        double weight = *(double*)(weights.data + weights.stride*i);
        if (isinf(weight))
            fail("Infinite weight!");
        else if (isnan(weight))
            fail("NaN weight!");
        else if (weight < 0.0)
            fail("Negative weight!");
        else
            total += weight;
    }

    if (isinf(total))
        fail("Sample weights have overflowed to infinity");

    if (total == 0.0) {
        // No usable weights: every item is equally likely.
        for (int64_t i = 0; i < n; i++) {
            uint32_t index = arc4random_uniform(arr.length);
            memcpy(selected.data + i*item_size, arr.data + arr.stride*index, item_size);
        }
        return selected;
    }

    // Weighted case: build an alias table. Each slot keeps its own item with
    // probability `odds` and otherwise redirects to the item at `alias`.
    typedef struct {
        int64_t alias;
        double odds;
    } alias_entry_t;

    // The table is heap-allocated rather than a variable-length array on the
    // stack, so a large `arr` cannot overflow the stack. It holds no pointers,
    // so the atomic allocator is used. Every field is assigned in the loop
    // below before it is read.
    alias_entry_t *aliases = GC_MALLOC_ATOMIC(sizeof(alias_entry_t) * (size_t)arr.length);

    // Scale weights so that the average slot has odds of exactly 1.0.
    double inverse_average = (double)arr.length / total;
    for (int64_t i = 0; i < arr.length; i++) {
        double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
        aliases[i].odds = weight * inverse_average;
        aliases[i].alias = -1; // -1 marks "no alias assigned yet"
    }

    // Pair every over-full slot (`big`, odds >= 1.0) with under-full slots
    // (`small`, odds < 1.0 and no alias yet) until its surplus is used up.
    int64_t small = 0;
    for (int64_t big = 0; big < arr.length; big++) {
        while (aliases[big].odds >= 1.0) {
            while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
                ++small;

            if (small >= arr.length) {
                // No under-full slot is left to absorb the surplus, so this
                // slot always yields its own item.
                aliases[big].odds = 1.0;
                aliases[big].alias = big;
                break;
            }

            aliases[small].alias = big;
            aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
        }
        // `big` may now be under-full itself and need an alias, so make sure
        // the scan for under-full slots starts no later than `big`.
        if (big < small) small = big;
    }

    // Any slot still lacking an alias redirects to itself.
    for (int64_t i = small; i < arr.length; i++) {
        if (aliases[i].alias == -1)
            aliases[i].alias = i;
    }

    for (int64_t i = 0; i < n; i++) {
        // The integer part of `r` picks a slot; the fractional part decides
        // between the slot's own item and its alias.
        double r = drand48() * arr.length;
        int64_t index = (int64_t)r;
        // `>=` (not `>`): a slot with odds 0.0 must never keep its own item,
        // even when the fractional part is exactly 0.0.
        if ((r - (double)index) >= aliases[index].odds)
            index = aliases[index].alias;
        memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, item_size);
    }
    return selected;
}
|
||||
|
||||
public array_t Array$slice(array_t *array, int64_t first, int64_t length, int64_t stride, const TypeInfo *type)
|
||||
{
|
||||
if (stride > MAX_STRIDE || stride < MIN_STRIDE)
|
||||
|
@ -60,6 +60,7 @@ void Array$sort(array_t *arr, closure_t comparison, const TypeInfo *type);
|
||||
array_t Array$sorted(array_t arr, closure_t comparison, const TypeInfo *type);
|
||||
void Array$shuffle(array_t *arr, const TypeInfo *type);
|
||||
void *Array$random(array_t arr);
|
||||
// Return a new array of `n` items drawn at random (with replacement) from `arr`.
// `weights` is read as an array of doubles matched to `arr` by index; items with no
// corresponding weight count as 0.0, and if the weights sum to zero the draw is uniform.
// Returns an empty array when `arr` is empty or `n <= 0`.
array_t Array$sample(array_t arr, int64_t n, array_t weights, const TypeInfo *type);
|
||||
void Array$clear(array_t *array);
|
||||
void Array$compact(array_t *arr, const TypeInfo *type);
|
||||
bool Array$contains(array_t array, void *item, const TypeInfo *type);
|
||||
|
@ -1481,6 +1481,12 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
|
||||
(void)compile_arguments(env, ast, NULL, call->args);
|
||||
return CORD_all("Array$random(", self, ")");
|
||||
} else if (streq(call->name, "sample")) {
|
||||
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
|
||||
arg_t *arg_spec = new(arg_t, .name="count", .type=Type(IntType, .bits=64),
|
||||
.next=new(arg_t, .name="weights", .type=self_value_t, .default_val=FakeAST(Array)));
|
||||
return CORD_all("Array$sample(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
|
||||
compile_type_info(env, self_value_t), ")");
|
||||
} else if (streq(call->name, "shuffle")) {
|
||||
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
|
||||
(void)compile_arguments(env, ast, NULL, call->args);
|
||||
|
@ -509,6 +509,7 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
else if (streq(call->name, "shuffle")) return Type(VoidType);
|
||||
else if (streq(call->name, "random"))
|
||||
return Type(PointerType, .pointed=Match(self_value_t, ArrayType)->item_type, .is_optional=true, .is_readonly=true);
|
||||
else if (streq(call->name, "sample")) return self_value_t;
|
||||
else if (streq(call->name, "clear")) return Type(VoidType);
|
||||
else if (streq(call->name, "slice")) return self_value_t;
|
||||
else if (streq(call->name, "reversed")) return self_value_t;
|
||||
|
Loading…
Reference in New Issue
Block a user