aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-04-19 13:29:04 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-04-19 13:29:04 -0400
commit3b0dce04a08d864626f841383727fbdab339ec83 (patch)
treec466e7104ff75c76394fd4796cb0752b4f2c6a13
parent072bd523b97aacaf8639dd89a49f0c1a16d1d405 (diff)
Add heapify(), heap_push(), and heap_pop()
-rw-r--r--builtins/array.c189
-rw-r--r--builtins/array.h5
-rw-r--r--compile.c31
-rw-r--r--test/arrays.tm19
-rw-r--r--typecheck.c3
5 files changed, 247 insertions, 0 deletions
diff --git a/builtins/array.c b/builtins/array.c
index 5d8d5848..0c629610 100644
--- a/builtins/array.c
+++ b/builtins/array.c
@@ -453,5 +453,194 @@ public uint32_t Array$hash(const array_t *arr, const TypeInfo *type)
return hash;
}
}
+/*
+def _siftdown(heap, startpos, pos):
+ newitem = heap[pos]
+ # Follow the path to the root, moving parents down until finding a place
+ # newitem fits.
+ while pos > startpos:
+ parentpos = (pos - 1) >> 1
+ parent = heap[parentpos]
+ if newitem < parent:
+ heap[pos] = parent
+ pos = parentpos
+ continue
+ break
+ heap[pos] = newitem
+ */
+
+static int siftdown(array_t *heap, int64_t startpos, int64_t pos, closure_t comparison, const TypeInfo *type)
+{
+ assert(pos > 0 && pos < heap->length);
+ int64_t item_size = get_item_size(type);
+ char newitem[item_size];
+ memcpy(newitem, heap->data + heap->stride*pos, item_size);
+ while (pos > startpos) {
+ int64_t parentpos = (pos - 1) >> 1;
+ typedef int32_t (*cmp_fn_t)(void*, void*, void*);
+ int32_t cmp = ((cmp_fn_t)comparison.fn)(newitem, heap->data + heap->stride*parentpos, comparison.userdata);
+ if (cmp >= 0)
+ break;
+
+ memcpy(newitem, heap->data + heap->stride*pos, item_size);
+ // swap pos/parentpos:
+ memcpy(heap->data + heap->stride*pos, heap->data + heap->stride*parentpos, item_size);
+ memcpy(heap->data + heap->stride*parentpos, newitem, item_size);
+
+ pos = parentpos;
+ }
+ return 0;
+}
+
+static int siftup(array_t *heap, int64_t pos, closure_t comparison, const TypeInfo *type)
+{
+ int64_t endpos = heap->length;
+ int64_t startpos = pos;
+ assert(pos < endpos);
+
+ int64_t item_size = get_item_size(type);
+ /* Bubble up the smaller child until hitting a leaf. */
+ int64_t limit = endpos >> 1; /* smallest pos that has no child */
+ while (pos < limit) {
+ /* Set childpos to index of smaller child. */
+ int64_t childpos = 2*pos + 1; /* leftmost child position */
+ if (childpos + 1 < endpos) {
+ typedef int32_t (*cmp_fn_t)(void*, void*, void*);
+ int32_t cmp = ((cmp_fn_t)comparison.fn)(
+ heap->data + heap->stride*childpos,
+ heap->data + heap->stride*(childpos + 1),
+ comparison.userdata);
+ // if (cmp < 0)
+ // return -1;
+ childpos += (cmp >= 0);
+ }
+
+ if (heap->data_refcount >= 3) {
+ Array$compact(heap, type);
+ heap->data_refcount = 1;
+ }
+
+ /* Move the smaller child up. */
+ char buf[item_size];
+ memcpy(buf, heap->data + heap->stride*pos, item_size);
+ memcpy(heap->data + heap->stride*pos, heap->data + heap->stride*childpos, item_size);
+ memcpy(heap->data + heap->stride*childpos, buf, item_size);
+
+ pos = childpos;
+ }
+ /* Bubble it up to its final resting place (by sifting its parents down). */
+ return siftdown(heap, startpos, pos, comparison, type);
+}
+
+public void Array$heap_push(array_t *heap, const void *item, closure_t comparison, const TypeInfo *type)
+{
+ Array$insert(heap, item, 0, type);
+
+ if (heap->data_refcount > 0)
+ Array$compact(heap, type);
+
+ if (heap->length > 1)
+ siftdown(heap, 0, heap->length-1, comparison, type);
+}
+
+public void Array$heap_pop(array_t *heap, void *out, closure_t comparison, const TypeInfo *type)
+{
+ if (heap->length == 0)
+ fail("Attempt to pop from an empty array");
+
+ int64_t item_size = get_item_size(type);
+ memcpy(out, heap->data, item_size);
+ if (heap->length > 1) {
+ if (heap->data_refcount > 0)
+ Array$compact(heap, type);
+ memcpy(heap->data, heap->data + heap->stride*(heap->length-1), item_size);
+ --heap->length;
+ if (heap->length > 1)
+ siftup(heap, 0, comparison, type);
+ } else {
+ *heap = (array_t){};
+ }
+}
+
+static int64_t
+keep_top_bit(int64_t n)
+{
+ int i = 0;
+ for (; n > 1; n >>= 1)
+ ++i;
+ return n << i;
+}
+
+public void Array$heapify(array_t *heap, closure_t comparison, const TypeInfo *type)
+{
+ if (heap->data_refcount > 0)
+ Array$compact(heap, type);
+
+ ARRAY_INCREF(*heap);
+ /* For heaps likely to be bigger than L1 cache, we use the cache
+ friendly heapify function. For smaller heaps that fit entirely
+ in cache, we prefer the simpler algorithm with less branching.
+ */
+ if (heap->length <= 2500) {
+ /* Transform bottom-up. The largest index there's any point to
+ looking at is the largest with a child index in-range, so must
+ have 2*i + 1 < n, or i < (n-1)/2. If n is even = 2*j, this is
+ (2*j-1)/2 = j-1/2 so j-1 is the largest, which is n//2 - 1. If
+ n is odd = 2*j+1, this is (2*j+1-1)/2 = j so j-1 is the largest,
+ and that's again n//2-1.
+ */
+ int64_t i, n = heap->length;
+ for (i = (n >> 1) - 1 ; i >= 0 ; i--)
+ if (siftup(heap, i, comparison, type))
+ goto cleanup;
+ } else {
+ /* Cache friendly version of heapify()
+ -----------------------------------
+
+ Build-up a heap in O(n) time by performing siftup() operations
+ on nodes whose children are already heaps.
+
+ The simplest way is to sift the nodes in reverse order from
+ n//2-1 to 0 inclusive. The downside is that children may be
+ out of cache by the time their parent is reached.
+
+ A better way is to not wait for the children to go out of cache.
+ Once a sibling pair of child nodes have been sifted, immediately
+ sift their parent node (while the children are still in cache).
+
+ Both ways build child heaps before their parents, so both ways
+ do the exact same number of comparisons and produce exactly
+ the same heap. The only difference is that the traversal
+ order is optimized for cache efficiency.
+ */
+ int64_t m = heap->length >> 1; /* index of first childless node */
+ int64_t leftmost = keep_top_bit(m + 1) - 1; /* leftmost node in row of m */
+ int64_t mhalf = m >> 1; /* parent of first childless node */
+ int64_t i;
+ for (i = leftmost - 1 ; i >= mhalf ; i--) {
+ int64_t j = i;
+ while (1) {
+ if (siftup(heap, j, comparison, type))
+ goto cleanup;
+ if (!(j & 1))
+ break;
+ j >>= 1;
+ }
+ }
+
+ for (i = m - 1 ; i >= leftmost ; i--) {
+ int64_t j = i;
+ while (1) {
+ if (siftup(heap, j, comparison, type))
+ goto cleanup;
+ if (!(j & 1))
+ break;
+ j >>= 1;
+ }
+ }
+ }
+ cleanup:
+ ARRAY_DECREF(*heap);
+}
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0
diff --git a/builtins/array.h b/builtins/array.h
index 711b89c4..a2c100d5 100644
--- a/builtins/array.h
+++ b/builtins/array.h
@@ -71,5 +71,10 @@ uint32_t Array$hash(const array_t *arr, const TypeInfo *type);
int32_t Array$compare(const array_t *x, const array_t *y, const TypeInfo *type);
bool Array$equal(const array_t *x, const array_t *y, const TypeInfo *type);
CORD Array$as_text(const array_t *arr, bool colorize, const TypeInfo *type);
+void Array$heapify(array_t *heap, closure_t comparison, const TypeInfo *type);
+void Array$heap_push(array_t *heap, const void *item, closure_t comparison, const TypeInfo *type);
+#define Array$heap_push_value(heap, _value, comparison, typeinfo) ({ __typeof(_value) value = _value; Array$heap_push(heap, &value, comparison, typeinfo); })
+void Array$heap_pop(array_t *heap, void *out, closure_t comparison, const TypeInfo *type);
+#define Array$heap_pop_value(heap, comparison, typeinfo, type) ({ type value; Array$heap_pop(heap, &value, comparison, typeinfo); value; })
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0
diff --git a/compile.c b/compile.c
index 593b208d..de553edc 100644
--- a/compile.c
+++ b/compile.c
@@ -1515,6 +1515,37 @@ CORD compile(env_t *env, ast_t *ast)
comparison = CORD_all("(closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "}");
}
return CORD_all("Array$", call->name, "(", self, ", ", comparison, ", ", compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "heapify")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 1, false);
+ CORD comparison;
+ if (call->args) {
+ type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true);
+ type_t *fn_t = Type(FunctionType, .args=new(arg_t, .name="x", .type=item_ptr, .next=new(arg_t, .name="y", .type=item_ptr)),
+ .ret=Type(IntType, .bits=32));
+ arg_t *arg_spec = new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t));
+ comparison = compile_arguments(env, ast, arg_spec, call->args);
+ } else {
+ comparison = CORD_all("((closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})");
+ }
+ return CORD_all("Array$heapify(", self, ", ", comparison, ", ", compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "heap_push")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 1, false);
+ type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true);
+ type_t *fn_t = Type(FunctionType, .args=new(arg_t, .name="x", .type=item_ptr, .next=new(arg_t, .name="y", .type=item_ptr)),
+ .ret=Type(IntType, .bits=32));
+ ast_t *default_cmp = FakeAST(InlineCCode, CORD_all("((closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})"));
+ arg_t *arg_spec = new(arg_t, .name="item", .type=item_t, .next=new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t), .default_val=default_cmp));
+ CORD arg_code = compile_arguments(env, ast, arg_spec, call->args);
+ return CORD_all("Array$heap_push_value(", self, ", ", arg_code, ", ", compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "heap_pop")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 1, false);
+ type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true);
+ type_t *fn_t = Type(FunctionType, .args=new(arg_t, .name="x", .type=item_ptr, .next=new(arg_t, .name="y", .type=item_ptr)),
+ .ret=Type(IntType, .bits=32));
+ ast_t *default_cmp = FakeAST(InlineCCode, CORD_all("((closure_t){.fn=generic_compare, .userdata=(void*)", compile_type_info(env, item_t), "})"));
+ arg_t *arg_spec = new(arg_t, .name="by", .type=Type(ClosureType, .fn=fn_t), .default_val=default_cmp);
+ CORD arg_code = compile_arguments(env, ast, arg_spec, call->args);
+ return CORD_all("Array$heap_pop_value(", self, ", ", arg_code, ", ", compile_type_info(env, self_value_t), ", ", compile_type(env, item_t), ")");
} else if (streq(call->name, "clear")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
(void)compile_arguments(env, ast, NULL, call->args);
diff --git a/test/arrays.tm b/test/arrays.tm
index c8e34af4..be767fab 100644
--- a/test/arrays.tm
+++ b/test/arrays.tm
@@ -95,3 +95,22 @@ func main()
= [30, 10, -20]
>> ["A", "B", "C"]:sample(10, [1.0, 0.5, 0.0])
+
+ if yes
+ >> heap := [Int.random(max=50) for _ in 10]
+ >> heap:heapify()
+ >> heap
+ sorted := [:Int]
+ while #heap > 0
+ sorted:insert(heap:heap_pop())
+ >> sorted == sorted:sorted()
+ = yes
+ for _ in 10
+ heap:heap_push(Int.random(max=50))
+ >> heap
+ sorted = [:Int]
+ while #heap > 0
+ sorted:insert(heap:heap_pop())
+ >> sorted == sorted:sorted()
+ = yes
+
diff --git a/typecheck.c b/typecheck.c
index efc105af..bd03588a 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -514,6 +514,9 @@ type_t *get_type(env_t *env, ast_t *ast)
else if (streq(call->name, "clear")) return Type(VoidType);
else if (streq(call->name, "slice")) return self_value_t;
else if (streq(call->name, "reversed")) return self_value_t;
+ else if (streq(call->name, "heapify")) return Type(VoidType);
+ else if (streq(call->name, "heap_push")) return Type(VoidType);
+ else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type;
else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case TableType: {