Add heapify(), heap_push(), and heap_pop()

This commit is contained in:
Bruce Hill 2024-04-19 13:29:04 -04:00
parent 072bd523b9
commit 3b0dce04a0
5 changed files with 247 additions and 0 deletions

View File

@ -453,5 +453,194 @@ public uint32_t Array$hash(const array_t *arr, const TypeInfo *type)
return hash;
}
}
/*
def _siftdown(heap, startpos, pos):
newitem = heap[pos]
# Follow the path to the root, moving parents down until finding a place
# newitem fits.
while pos > startpos:
parentpos = (pos - 1) >> 1
parent = heap[parentpos]
if newitem < parent:
heap[pos] = parent
pos = parentpos
continue
break
heap[pos] = newitem
*/
static int siftdown(array_t *heap, int64_t startpos, int64_t pos, closure_t comparison, const TypeInfo *type)
{
assert(pos > 0 && pos < heap->length);
int64_t item_size = get_item_size(type);
char newitem[item_size];
memcpy(newitem, heap->data + heap->stride*pos, item_size);
while (pos > startpos) {
int64_t parentpos = (pos - 1) >> 1;
typedef int32_t (*cmp_fn_t)(void*, void*, void*);
int32_t cmp = ((cmp_fn_t)comparison.fn)(newitem, heap->data + heap->stride*parentpos, comparison.userdata);
if (cmp >= 0)
break;
memcpy(newitem, heap->data + heap->stride*pos, item_size);
// swap pos/parentpos:
memcpy(heap->data + heap->stride*pos, heap->data + heap->stride*parentpos, item_size);
memcpy(heap->data + heap->stride*parentpos, newitem, item_size);
pos = parentpos;
}
return 0;
}
static int siftup(array_t *heap, int64_t pos, closure_t comparison, const TypeInfo *type)
{
int64_t endpos = heap->length;
int64_t startpos = pos;
assert(pos < endpos);
int64_t item_size = get_item_size(type);
/* Bubble up the smaller child until hitting a leaf. */
int64_t limit = endpos >> 1; /* smallest pos that has no child */
while (pos < limit) {
/* Set childpos to index of smaller child. */
int64_t childpos = 2*pos + 1; /* leftmost child position */
if (childpos + 1 < endpos) {
typedef int32_t (*cmp_fn_t)(void*, void*, void*);
int32_t cmp = ((cmp_fn_t)comparison.fn)(
heap->data + heap->stride*childpos,
heap->data + heap->stride*(childpos + 1),
comparison.userdata);
// if (cmp < 0)
// return -1;
childpos += (cmp >= 0);
}
if (heap->data_refcount >= 3) {
Array$compact(heap, type);
heap->data_refcount = 1;
}
/* Move the smaller child up. */
char buf[item_size];
memcpy(buf, heap->data + heap->stride*pos, item_size);
memcpy(heap->data + heap->stride*pos, heap->data + heap->stride*childpos, item_size);
memcpy(heap->data + heap->stride*childpos, buf, item_size);
pos = childpos;
}
/* Bubble it up to its final resting place (by sifting its parents down). */
return siftdown(heap, startpos, pos, comparison, type);
}
public void Array$heap_push(array_t *heap, const void *item, closure_t comparison, const TypeInfo *type)
{
Array$insert(heap, item, 0, type);
if (heap->data_refcount > 0)
Array$compact(heap, type);
if (heap->length > 1)
siftdown(heap, 0, heap->length-1, comparison, type);
}
public void Array$heap_pop(array_t *heap, void *out, closure_t comparison, const TypeInfo *type)
{
if (heap->length == 0)
fail("Attempt to pop from an empty array");
int64_t item_size = get_item_size(type);
memcpy(out, heap->data, item_size);
if (heap->length > 1) {
if (heap->data_refcount > 0)
Array$compact(heap, type);
memcpy(heap->data, heap->data + heap->stride*(heap->length-1), item_size);
--heap->length;
if (heap->length > 1)
siftup(heap, 0, comparison, type);
} else {
*heap = (array_t){};
}
}
static int64_t
keep_top_bit(int64_t n)
{
int i = 0;
for (; n > 1; n >>= 1)
++i;
return n << i;
}
public void Array$heapify(array_t *heap, closure_t comparison, const TypeInfo *type)
{
if (heap->data_refcount > 0)
Array$compact(heap, type);
ARRAY_INCREF(*heap);
/* For heaps likely to be bigger than L1 cache, we use the cache
friendly heapify function. For smaller heaps that fit entirely
in cache, we prefer the simpler algorithm with less branching.
*/
if (heap->length <= 2500) {
/* Transform bottom-up. The largest index there's any point to
looking at is the largest with a child index in-range, so must
have 2*i + 1 < n, or i < (n-1)/2. If n is even = 2*j, this is
(2*j-1)/2 = j-1/2 so j-1 is the largest, which is n//2 - 1. If
n is odd = 2*j+1, this is (2*j+1-1)/2 = j so j-1 is the largest,
and that's again n//2-1.
*/
int64_t i, n = heap->length;
for (i = (n >> 1) - 1 ; i >= 0 ; i--)
if (siftup(heap, i, comparison, type))
goto cleanup;
} else {
/* Cache friendly version of heapify()
-----------------------------------
Build-up a heap in O(n) time by performing siftup() operations
on nodes whose children are already heaps.
The simplest way is to sift the nodes in reverse order from
n//2-1 to 0 inclusive. The downside is that children may be
out of cache by the time their parent is reached.
A better way is to not wait for the children to go out of cache.
Once a sibling pair of child nodes have been sifted, immediately
sift their parent node (while the children are still in cache).
Both ways build child heaps before their parents, so both ways
do the exact same number of comparisons and produce exactly
the same heap. The only difference is that the traversal
order is optimized for cache efficiency.
*/
int64_t m = heap->length >> 1; /* index of first childless node */
int64_t leftmost = keep_top_bit(m + 1) - 1; /* leftmost node in row of m */
int64_t mhalf = m >> 1; /* parent of first childless node */
int64_t i;
for (i = leftmost - 1 ; i >= mhalf ; i--) {
int64_t j = i;
while (1) {
if (siftup(heap, j, comparison, type))
goto cleanup;
if (!(j & 1))
break;
j >>= 1;
}
}
for (i = m - 1 ; i >= leftmost ; i--) {
int64_t j = i;
while (1) {
if (siftup(heap, j, comparison, type))
goto cleanup;
if (!(j & 1))
break;
j >>= 1;
}
}
}
cleanup:
ARRAY_DECREF(*heap);
}
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0

View File

@ -71,5 +71,10 @@ uint32_t Array$hash(const array_t *arr, const TypeInfo *type);
int32_t Array$compare(const array_t *x, const array_t *y, const TypeInfo *type);
bool Array$equal(const array_t *x, const array_t *y, const TypeInfo *type);
CORD Array$as_text(const array_t *arr, bool colorize, const TypeInfo *type);
void Array$heapify(array_t *heap, closure_t comparison, const TypeInfo *type);
void Array$heap_push(array_t *heap, const void *item, closure_t comparison, const TypeInfo *type);
#define Array$heap_push_value(heap, _value, comparison, typeinfo) ({ __typeof(_value) value = _value; Array$heap_push(heap, &value, comparison, typeinfo); })
void Array$heap_pop(array_t *heap, void *out, closure_t comparison, const TypeInfo *type);
#define Array$heap_pop_value(heap, comparison, typeinfo, type) ({ type value; Array$heap_pop(heap, &value, comparison, typeinfo); value; })
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0

View File

@ -1515,6 +1515,37 @@ CORD compile(env_t *env, ast_t *ast)
comparison = CORD_all("(closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "}");
}
return CORD_all("Array$", call->name, "(", self, ", ", comparison, ", ", compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "heapify")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
CORD comparison;
if (call->args) {
type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true);
type_t *fn_t = Type(FunctionType, .args=new(arg_t, .name="x", .type=item_ptr, .next=new(arg_t, .name="y", .type=item_ptr)),
.ret=Type(IntType, .bits=32));
arg_t *arg_spec = new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t));
comparison = compile_arguments(env, ast, arg_spec, call->args);
} else {
comparison = CORD_all("((closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})");
}
return CORD_all("Array$heapify(", self, ", ", comparison, ", ", compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "heap_push")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true);
type_t *fn_t = Type(FunctionType, .args=new(arg_t, .name="x", .type=item_ptr, .next=new(arg_t, .name="y", .type=item_ptr)),
.ret=Type(IntType, .bits=32));
ast_t *default_cmp = FakeAST(InlineCCode, CORD_all("((closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})"));
arg_t *arg_spec = new(arg_t, .name="item", .type=item_t, .next=new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t), .default_val=default_cmp));
CORD arg_code = compile_arguments(env, ast, arg_spec, call->args);
return CORD_all("Array$heap_push_value(", self, ", ", arg_code, ", ", compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "heap_pop")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true);
type_t *fn_t = Type(FunctionType, .args=new(arg_t, .name="x", .type=item_ptr, .next=new(arg_t, .name="y", .type=item_ptr)),
.ret=Type(IntType, .bits=32));
ast_t *default_cmp = FakeAST(InlineCCode, CORD_all("((closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})"));
arg_t *arg_spec = new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t), .default_val=default_cmp);
CORD arg_code = compile_arguments(env, ast, arg_spec, call->args);
return CORD_all("Array$heap_pop_value(", self, ", ", arg_code, ", ", compile_type_info(env, self_value_t), ", ", compile_type(env, item_t), ")");
} else if (streq(call->name, "clear")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
(void)compile_arguments(env, ast, NULL, call->args);

View File

@ -95,3 +95,22 @@ func main()
= [30, 10, -20]
>> ["A", "B", "C"]:sample(10, [1.0, 0.5, 0.0])
if yes
>> heap := [Int.random(max=50) for _ in 10]
>> heap:heapify()
>> heap
sorted := [:Int]
while #heap > 0
sorted:insert(heap:heap_pop())
>> sorted == sorted:sorted()
= yes
for _ in 10
heap:heap_push(Int.random(max=50))
>> heap
sorted = [:Int]
while #heap > 0
sorted:insert(heap:heap_pop())
>> sorted == sorted:sorted()
= yes

View File

@ -514,6 +514,9 @@ type_t *get_type(env_t *env, ast_t *ast)
else if (streq(call->name, "clear")) return Type(VoidType);
else if (streq(call->name, "slice")) return self_value_t;
else if (streq(call->name, "reversed")) return self_value_t;
else if (streq(call->name, "heapify")) return Type(VoidType);
else if (streq(call->name, "heap_push")) return Type(VoidType);
else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type;
else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case TableType: {