diff --git a/builtins/array.c b/builtins/array.c
index 4eea25f..333a2c5 100644
--- a/builtins/array.c
+++ b/builtins/array.c
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -181,6 +182,108 @@ public void *Array$random(array_t arr)
     return arr.data + arr.stride*index;
 }
 
+// Draw `n` items at random from `arr`, with replacement, and return them as a
+// new array. `weights` supplies one double per item (read through its stride);
+// items past the end of `weights` get weight 0, and surplus weights are
+// ignored. When the weights sum to 0 (including an empty `weights`), every
+// item is equally likely. Fails on an infinite, NaN, or negative weight, or
+// when the sum of the weights overflows to infinity. Returns an empty array
+// when `arr` is empty or `n <= 0`.
+public array_t Array$sample(array_t arr, int64_t n, array_t weights, const TypeInfo *type)
+{
+    if (arr.length == 0 || n <= 0)
+        return (array_t){};
+
+    int64_t item_size = get_item_size(type);
+    array_t selected = {
+        .data=arr.atomic ? GC_MALLOC_ATOMIC(n * item_size) : GC_MALLOC(n * item_size),
+        .length=n,
+        .stride=item_size, .atomic=arr.atomic};
+
+    // Validate the weights and add them up
+    double total = 0.0;
+    for (int64_t i = 0; i < weights.length && i < arr.length; i++) {
+        double weight = *(double*)(weights.data + weights.stride*i);
+        if (isinf(weight))
+            fail("Infinite weight!");
+        else if (isnan(weight))
+            fail("NaN weight!");
+        else if (weight < 0.0)
+            fail("Negative weight!");
+        else
+            total += weight;
+    }
+
+    if (isinf(total))
+        fail("Sample weights have overflowed to infinity");
+
+    if (total == 0.0) {
+        // No usable weights: pick uniformly
+        for (int64_t i = 0; i < n; i++) {
+            uint32_t index = arc4random_uniform(arr.length);
+            memcpy(selected.data + i*item_size, arr.data + arr.stride*index, item_size);
+        }
+    } else {
+        // Weighted sampling via an alias table: slot i keeps item i with
+        // probability `odds` and otherwise yields item `alias`, so each draw
+        // needs only one random number and one table lookup.
+        double inverse_average = (double)arr.length / total;
+
+        // The table is allocated on the GC heap rather than as a
+        // variable-length stack array, so a large `arr` cannot overflow the
+        // stack. It holds no pointers, hence the atomic allocation; every
+        // entry is initialized by the loop below.
+        struct {
+            int64_t alias;
+            double odds;
+        } *aliases = GC_MALLOC_ATOMIC(sizeof(*aliases) * (size_t)arr.length);
+
+        for (int64_t i = 0; i < arr.length; i++) {
+            double weight = i >= weights.length ? 0.0 : *(double*)(weights.data + weights.stride*i);
+            aliases[i].odds = weight * inverse_average; // scaled so the average is 1.0
+            aliases[i].alias = -1; // no alias assigned yet
+        }
+
+        // Give each overweight item's excess to underweight, unaliased slots;
+        // `small` is the scan position for the next such slot.
+        int64_t small = 0;
+        for (int64_t big = 0; big < arr.length; big++) {
+            while (aliases[big].odds >= 1.0) {
+                while (small < arr.length && (aliases[small].odds >= 1.0 || aliases[small].alias != -1))
+                    ++small;
+
+                if (small >= arr.length) {
+                    // No slot left to fill (presumably only reachable through rounding error): keep this item outright
+                    aliases[big].odds = 1.0;
+                    aliases[big].alias = big;
+                    break;
+                }
+
+                aliases[small].alias = big;
+                aliases[big].odds = (aliases[small].odds + aliases[big].odds) - 1.0;
+            }
+            // `big` may have become underweight itself; let the scan revisit it
+            if (big < small) small = big;
+        }
+
+        // Any slot still without an alias falls back to itself
+        for (int64_t i = small; i < arr.length; i++)
+            if (aliases[i].alias == -1)
+                aliases[i].alias = i;
+
+        for (int64_t i = 0; i < n; i++) {
+            double r = drand48() * arr.length;
+            int64_t index = (int64_t)r;
+            // `>=` (not `>`) so that a slot with odds 0 is never chosen, even
+            // when the fractional part of `r` is exactly 0.
+            if ((r - (double)index) >= aliases[index].odds)
+                index = aliases[index].alias;
+            memcpy(selected.data + i*selected.stride, arr.data + index*arr.stride, item_size);
+        }
+    }
+    return selected;
+}
+
 public array_t Array$slice(array_t *array, int64_t first, int64_t length, int64_t stride, const TypeInfo *type)
 {
     if (stride > MAX_STRIDE || stride < MIN_STRIDE)
diff --git a/builtins/array.h b/builtins/array.h
index c6722e6..1875785 100644
--- a/builtins/array.h
+++ b/builtins/array.h
@@ -60,6 +60,7 @@ void Array$sort(array_t *arr, closure_t comparison, const TypeInfo *type);
 array_t Array$sorted(array_t arr, closure_t comparison, const TypeInfo *type);
 void Array$shuffle(array_t *arr, const TypeInfo *type);
 void *Array$random(array_t arr);
+array_t Array$sample(array_t arr, int64_t n, array_t weights, const TypeInfo *type);
 void Array$clear(array_t *array);
 void Array$compact(array_t *arr, const TypeInfo *type);
 bool Array$contains(array_t array, void *item, const TypeInfo *type);
diff --git a/compile.c b/compile.c
index ff6a868..08e2fe4 100644
--- a/compile.c
+++ b/compile.c
@@ -1481,6 +1481,15 @@ CORD compile(env_t *env, ast_t *ast)
                 CORD self = compile_to_pointer_depth(env, call->self, 0, false);
                 (void)compile_arguments(env, ast, NULL, call->args);
                 return CORD_all("Array$random(", self, ")");
+            } else if (streq(call->name, "sample")) {
+                CORD self = compile_to_pointer_depth(env, call->self, 0, false);
+                // NOTE(review): `weights` is declared with the array's own type, but
+                // Array$sample() reads each weight as a `double`; confirm whether this
+                // should be an array of the 64-bit float type instead.
+                arg_t *arg_spec = new(arg_t, .name="count", .type=Type(IntType, .bits=64),
+                                      .next=new(arg_t, .name="weights", .type=self_value_t, .default_val=FakeAST(Array)));
+                return CORD_all("Array$sample(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
+                                compile_type_info(env, self_value_t), ")");
             } else if (streq(call->name, "shuffle")) {
                 CORD self = compile_to_pointer_depth(env, call->self, 1, false);
                 (void)compile_arguments(env, ast, NULL, call->args);
diff --git a/typecheck.c b/typecheck.c
index 93b105b..35bdef1 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -510,5 +510,6 @@ type_t *get_type(env_t *env, ast_t *ast)
             else if (streq(call->name, "shuffle")) return Type(VoidType);
             else if (streq(call->name, "random")) return Type(PointerType, .pointed=Match(self_value_t, ArrayType)->item_type, .is_optional=true, .is_readonly=true);
+            else if (streq(call->name, "sample")) return self_value_t;
             else if (streq(call->name, "clear")) return Type(VoidType);
             else if (streq(call->name, "slice")) return self_value_t;
             else if (streq(call->name, "reversed")) return self_value_t;