diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-04-19 13:29:04 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-04-19 13:29:04 -0400 |
| commit | 3b0dce04a08d864626f841383727fbdab339ec83 (patch) | |
| tree | c466e7104ff75c76394fd4796cb0752b4f2c6a13 | |
| parent | 072bd523b97aacaf8639dd89a49f0c1a16d1d405 (diff) | |
Add heapify(), heap_push(), and heap_pop()
| -rw-r--r-- | builtins/array.c | 189 | ||||
| -rw-r--r-- | builtins/array.h | 5 | ||||
| -rw-r--r-- | compile.c | 31 | ||||
| -rw-r--r-- | test/arrays.tm | 19 | ||||
| -rw-r--r-- | typecheck.c | 3 |
5 files changed, 247 insertions, 0 deletions
diff --git a/builtins/array.c b/builtins/array.c index 5d8d5848..0c629610 100644 --- a/builtins/array.c +++ b/builtins/array.c @@ -453,5 +453,194 @@ public uint32_t Array$hash(const array_t *arr, const TypeInfo *type) return hash; } } +/* +def _siftdown(heap, startpos, pos): + newitem = heap[pos] + # Follow the path to the root, moving parents down until finding a place + # newitem fits. + while pos > startpos: + parentpos = (pos - 1) >> 1 + parent = heap[parentpos] + if newitem < parent: + heap[pos] = parent + pos = parentpos + continue + break + heap[pos] = newitem + */ + +static int siftdown(array_t *heap, int64_t startpos, int64_t pos, closure_t comparison, const TypeInfo *type) +{ + assert(pos > 0 && pos < heap->length); + int64_t item_size = get_item_size(type); + char newitem[item_size]; + memcpy(newitem, heap->data + heap->stride*pos, item_size); + while (pos > startpos) { + int64_t parentpos = (pos - 1) >> 1; + typedef int32_t (*cmp_fn_t)(void*, void*, void*); + int32_t cmp = ((cmp_fn_t)comparison.fn)(newitem, heap->data + heap->stride*parentpos, comparison.userdata); + if (cmp >= 0) + break; + + memcpy(newitem, heap->data + heap->stride*pos, item_size); + // swap pos/parentpos: + memcpy(heap->data + heap->stride*pos, heap->data + heap->stride*parentpos, item_size); + memcpy(heap->data + heap->stride*parentpos, newitem, item_size); + + pos = parentpos; + } + return 0; +} + +static int siftup(array_t *heap, int64_t pos, closure_t comparison, const TypeInfo *type) +{ + int64_t endpos = heap->length; + int64_t startpos = pos; + assert(pos < endpos); + + int64_t item_size = get_item_size(type); + /* Bubble up the smaller child until hitting a leaf. */ + int64_t limit = endpos >> 1; /* smallest pos that has no child */ + while (pos < limit) { + /* Set childpos to index of smaller child. */ + int64_t childpos = 2*pos + 1; /* leftmost child position */ + if (childpos + 1 < endpos) { + typedef int32_t (*cmp_fn_t)(void*, void*, void*); + int32_t cmp = ((cmp_fn_t)comparison.fn)( + heap->data + heap->stride*childpos, + heap->data + heap->stride*(childpos + 1), + comparison.userdata); + // if (cmp < 0) + // return -1; + childpos += (cmp >= 0); + } + + if (heap->data_refcount >= 3) { + Array$compact(heap, type); + heap->data_refcount = 1; + } + + /* Move the smaller child up. */ + char buf[item_size]; + memcpy(buf, heap->data + heap->stride*pos, item_size); + memcpy(heap->data + heap->stride*pos, heap->data + heap->stride*childpos, item_size); + memcpy(heap->data + heap->stride*childpos, buf, item_size); + + pos = childpos; + } + /* Bubble it up to its final resting place (by sifting its parents down). */ + return siftdown(heap, startpos, pos, comparison, type); +} + +public void Array$heap_push(array_t *heap, const void *item, closure_t comparison, const TypeInfo *type) +{ + Array$insert(heap, item, 0, type); + + if (heap->data_refcount > 0) + Array$compact(heap, type); + + if (heap->length > 1) + siftdown(heap, 0, heap->length-1, comparison, type); +} + +public void Array$heap_pop(array_t *heap, void *out, closure_t comparison, const TypeInfo *type) +{ + if (heap->length == 0) + fail("Attempt to pop from an empty array"); + + int64_t item_size = get_item_size(type); + memcpy(out, heap->data, item_size); + if (heap->length > 1) { + if (heap->data_refcount > 0) + Array$compact(heap, type); + memcpy(heap->data, heap->data + heap->stride*(heap->length-1), item_size); + --heap->length; + if (heap->length > 1) + siftup(heap, 0, comparison, type); + } else { + *heap = (array_t){}; + } +} + +static int64_t +keep_top_bit(int64_t n) +{ + int i = 0; + for (; n > 1; n >>= 1) + ++i; + return n << i; +} + +public void Array$heapify(array_t *heap, closure_t comparison, const TypeInfo *type) +{ + if (heap->data_refcount > 0) + Array$compact(heap, type); + + ARRAY_INCREF(*heap); + /* For heaps likely to be bigger than L1 cache, we use the cache + friendly heapify function. For smaller heaps that fit entirely + in cache, we prefer the simpler algorithm with less branching. + */ + if (heap->length <= 2500) { + /* Transform bottom-up. The largest index there's any point to + looking at is the largest with a child index in-range, so must + have 2*i + 1 < n, or i < (n-1)/2. If n is even = 2*j, this is + (2*j-1)/2 = j-1/2 so j-1 is the largest, which is n//2 - 1. If + n is odd = 2*j+1, this is (2*j+1-1)/2 = j so j-1 is the largest, + and that's again n//2-1. + */ + int64_t i, n = heap->length; + for (i = (n >> 1) - 1 ; i >= 0 ; i--) + if (siftup(heap, i, comparison, type)) + goto cleanup; + } else { + /* Cache friendly version of heapify() + ----------------------------------- + + Build-up a heap in O(n) time by performing siftup() operations + on nodes whose children are already heaps. + + The simplest way is to sift the nodes in reverse order from + n//2-1 to 0 inclusive. The downside is that children may be + out of cache by the time their parent is reached. + + A better way is to not wait for the children to go out of cache. + Once a sibling pair of child nodes have been sifted, immediately + sift their parent node (while the children are still in cache). + + Both ways build child heaps before their parents, so both ways + do the exact same number of comparisons and produce exactly + the same heap. The only difference is that the traversal + order is optimized for cache efficiency. + */ + int64_t m = heap->length >> 1; /* index of first childless node */ + int64_t leftmost = keep_top_bit(m + 1) - 1; /* leftmost node in row of m */ + int64_t mhalf = m >> 1; /* parent of first childless node */ + int64_t i; + for (i = leftmost - 1 ; i >= mhalf ; i--) { + int64_t j = i; + while (1) { + if (siftup(heap, j, comparison, type)) + goto cleanup; + if (!(j & 1)) + break; + j >>= 1; + } + } + + for (i = m - 1 ; i >= leftmost ; i--) { + int64_t j = i; + while (1) { + if (siftup(heap, j, comparison, type)) + goto cleanup; + if (!(j & 1)) + break; + j >>= 1; + } + } + } + cleanup: + ARRAY_DECREF(*heap); +} // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/builtins/array.h b/builtins/array.h index 711b89c4..a2c100d5 100644 --- a/builtins/array.h +++ b/builtins/array.h @@ -71,5 +71,10 @@ uint32_t Array$hash(const array_t *arr, const TypeInfo *type); int32_t Array$compare(const array_t *x, const array_t *y, const TypeInfo *type); bool Array$equal(const array_t *x, const array_t *y, const TypeInfo *type); CORD Array$as_text(const array_t *arr, bool colorize, const TypeInfo *type); +void Array$heapify(array_t *heap, closure_t comparison, const TypeInfo *type); +void Array$heap_push(array_t *heap, const void *item, closure_t comparison, const TypeInfo *type); +#define Array$heap_push_value(heap, _value, comparison, typeinfo) ({ __typeof(_value) value = _value; Array$heap_push(heap, &value, comparison, typeinfo); }) +void Array$heap_pop(array_t *heap, void *out, closure_t comparison, const TypeInfo *type); +#define Array$heap_pop_value(heap, comparison, typeinfo, type) ({ type value; Array$heap_pop(heap, &value, comparison, typeinfo); value; }) // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 @@ -1515,6 +1515,37 @@ CORD compile(env_t *env, ast_t *ast) comparison = CORD_all("(closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "}"); } return CORD_all("Array$", call->name, "(", self, ", ", comparison, ", ", compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "heapify")) { + CORD self = compile_to_pointer_depth(env, call->self, 1, false); + CORD comparison; + if (call->args) { + type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true); + type_t *fn_t = Type(FunctionType, .args=new(arg_t, .name="x", .type=item_ptr, .next=new(arg_t, .name="y", .type=item_ptr)), + .ret=Type(IntType, .bits=32)); + arg_t *arg_spec = new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t)); + comparison = compile_arguments(env, ast, arg_spec, call->args); + } else { + comparison = CORD_all("((closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})"); + } + return CORD_all("Array$heapify(", self, ", ", comparison, ", ", compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "heap_push")) { + CORD self = compile_to_pointer_depth(env, call->self, 1, false); + type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true); + type_t *fn_t = Type(FunctionType, .args=new(arg_t, .name="x", .type=item_ptr, .next=new(arg_t, .name="y", .type=item_ptr)), + .ret=Type(IntType, .bits=32)); + ast_t *default_cmp = FakeAST(InlineCCode, CORD_all("((closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})")); + arg_t *arg_spec = new(arg_t, .name="item", .type=item_t, .next=new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t), .default_val=default_cmp)); + CORD arg_code = compile_arguments(env, ast, arg_spec, call->args); + return CORD_all("Array$heap_push_value(", self, ", ", arg_code, ", ", compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "heap_pop")) { + CORD self = compile_to_pointer_depth(env, call->self, 1, false); + type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true); + type_t *fn_t = Type(FunctionType, .args=new(arg_t, .name="x", .type=item_ptr, .next=new(arg_t, .name="y", .type=item_ptr)), + .ret=Type(IntType, .bits=32)); + ast_t *default_cmp = FakeAST(InlineCCode, CORD_all("((closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})")); + arg_t *arg_spec = new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t), .default_val=default_cmp); + CORD arg_code = compile_arguments(env, ast, arg_spec, call->args); + return CORD_all("Array$heap_pop_value(", self, ", ", arg_code, ", ", compile_type_info(env, self_value_t), ", ", compile_type(env, item_t), ")"); } else if (streq(call->name, "clear")) { CORD self = compile_to_pointer_depth(env, call->self, 1, false); (void)compile_arguments(env, ast, NULL, call->args); diff --git a/test/arrays.tm b/test/arrays.tm index c8e34af4..be767fab 100644 --- a/test/arrays.tm +++ b/test/arrays.tm @@ -95,3 +95,22 @@ func main() = [30, 10, -20] >> ["A", "B", "C"]:sample(10, [1.0, 0.5, 0.0]) + + if yes + >> heap := [Int.random(max=50) for _ in 10] + >> heap:heapify() + >> heap + sorted := [:Int] + while #heap > 0 + sorted:insert(heap:heap_pop()) + >> sorted == sorted:sorted() + = yes + for _ in 10 + heap:heap_push(Int.random(max=50)) + >> heap + sorted = [:Int] + while #heap > 0 + sorted:insert(heap:heap_pop()) + >> sorted == sorted:sorted() + = yes + diff --git a/typecheck.c b/typecheck.c index efc105af..bd03588a 100644 --- a/typecheck.c +++ b/typecheck.c @@ -514,6 +514,9 @@ type_t *get_type(env_t *env, ast_t *ast) else if (streq(call->name, "clear")) return Type(VoidType); else if (streq(call->name, "slice")) return self_value_t; else if (streq(call->name, "reversed")) return self_value_t; + else if (streq(call->name, "heapify")) return Type(VoidType); + else if (streq(call->name, "heap_push")) return Type(VoidType); + else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type; else code_err(ast, "There is no '%s' method for arrays", call->name); } case TableType: { |
