diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-02-04 21:13:50 -0500 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-02-04 21:13:50 -0500 |
| commit | adde91636f04ae7544dba1ca5c6c1a40c074edb9 (patch) | |
| tree | dfeb8c0c16fda6a87ef30b048b070ee4cb175a78 /builtins/table.c | |
| parent | b08a0d3e2bf45bae11c982dd24d0292d6436b993 (diff) | |
Builtins
Diffstat (limited to 'builtins/table.c')
| -rw-r--r-- | builtins/table.c | 558 |
1 file changed, 558 insertions, 0 deletions
diff --git a/builtins/table.c b/builtins/table.c new file mode 100644 index 00000000..7ba33b1c --- /dev/null +++ b/builtins/table.c @@ -0,0 +1,558 @@ + +// table.c - C Hash table implementation for SSS +// Copyright 2023 Bruce Hill +// Provided under the MIT license with the Commons Clause +// See included LICENSE for details. + +// Hash table (aka Dictionary) Implementation +// Hash keys and values are stored *by value* +// The hash insertion/lookup implementation is based on Lua's tables, +// which use a chained scatter with Brent's variation. + +#include <assert.h> +#include <gc.h> +#include <stdalign.h> +#include <stdarg.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/param.h> + +#include "../SipHash/halfsiphash.h" +#include "../util.h" +#include "array.h" +#include "table.h" +#include "types.h" + +// #define DEBUG_TABLES + +#ifdef DEBUG_TABLES +#define hdebug(fmt, ...) printf("\x1b[2m" fmt "\x1b[m" __VA_OPT__(,) __VA_ARGS__) +#else +#define hdebug(...) 
(void)0 +#endif + +// Helper accessors for type functions/values: +#define HASH_KEY(t, k) (generic_hash((k), type->TableInfo.key) % ((t)->bucket_info->count)) +#define EQUAL_KEYS(x, y) (generic_equal((x), (y), type->TableInfo.key)) +#define ENTRY_SIZE (type->TableInfo.entry_size) +#define VALUE_OFFSET (type->TableInfo.value_offset) +#define END_OF_CHAIN UINT32_MAX + +#define GET_ENTRY(t, i) ((t)->entries.data + (t)->entries.stride*(i)) + +extern const void *SSS_HASH_VECTOR; + +extern CORD CString_cord(const char **s, bool colorize, const TypeInfo *type); +extern uint32_t CString_hash(const char **s, const TypeInfo *type); +extern uint32_t CString_compare(const char **x, const char **y, const TypeInfo *type); +static TypeInfo CString_typeinfo = { + .name="CString", + .size=sizeof(char*), + .align=alignof(char*), + .tag=CustomInfo, + .CustomInfo={ + .cord=(void*)CString_cord, + .hash=(void*)CString_hash, + .compare=(void*)CString_compare, + }, +}; + +TypeInfo MemoryPointer_typeinfo = { + .name="@Memory", + .size=sizeof(void*), + .align=alignof(void*), + .tag=PointerInfo, + .PointerInfo={ + .sigil="@", + .pointed=NULL, + }, +}; + +TypeInfo CStringToVoidStarTable_type = { + .name="{CString=>@Memory}", + .size=sizeof(table_t), + .align=alignof(table_t), + .tag=TableInfo, + .TableInfo={.key=&CString_typeinfo,.value=&MemoryPointer_typeinfo, + .entry_size=16, .value_offset=8}, +}; + +static inline void hshow(const table_t *t) +{ + hdebug("{"); + for (uint32_t i = 0; t->bucket_info && i < t->bucket_info->count; i++) { + if (i > 0) hdebug(" "); + if (t->bucket_info->buckets[i].occupied) + hdebug("[%d]=%d(%d)", i, t->bucket_info->buckets[i].index, t->bucket_info->buckets[i].next_bucket); + else + hdebug("[%d]=_", i); + } + hdebug("}\n"); +} + +static void maybe_copy_on_write(table_t *t, const TypeInfo *type) +{ + if (t->entries.copy_on_write) { + Array_compact(&t->entries, type->TableInfo.entry_size); + } + + if (t->bucket_info && t->bucket_info->copy_on_write) { + int64_t 
size = sizeof(bucket_info_t) + t->bucket_info->count*sizeof(bucket_t); + t->bucket_info = memcpy(GC_MALLOC(size), t->bucket_info, size); + t->bucket_info->copy_on_write = 0; + } +} + +public void Table_mark_copy_on_write(table_t *t) +{ + t->entries.copy_on_write = 1; + if (t->bucket_info) t->bucket_info->copy_on_write = 1; +} + +// Return address of value or NULL +public void *Table_get_raw(const table_t *t, const void *key, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + if (!t || !key || !t->bucket_info) return NULL; + + uint32_t hash = HASH_KEY(t, key); + hshow(t); + hdebug("Getting value with initial probe at %u\n", hash); + bucket_t *buckets = t->bucket_info->buckets; + for (uint32_t i = hash; buckets[i].occupied; i = buckets[i].next_bucket) { + hdebug("Checking against key in bucket %u\n", i); + void *entry = GET_ENTRY(t, buckets[i].index); + if (EQUAL_KEYS(entry, key)) { + hdebug("Found key!\n"); + return entry + VALUE_OFFSET; + } + if (buckets[i].next_bucket == END_OF_CHAIN) + break; + } + return NULL; +} + +public void *Table_get(const table_t *t, const void *key, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + for (const table_t *iter = t; iter; iter = iter->fallback) { + void *ret = Table_get_raw(iter, key, type); + if (ret) return ret; + } + for (const table_t *iter = t; iter; iter = iter->fallback) { + if (iter->default_value) return iter->default_value; + } + return NULL; +} + +static void Table_set_bucket(table_t *t, const void *entry, int32_t index, const TypeInfo *type) +{ + assert(t->bucket_info); + hshow(t); + const void *key = entry; + bucket_t *buckets = t->bucket_info->buckets; + uint32_t hash = HASH_KEY(t, key); + hdebug("Hash value (mod %u) = %u\n", t->bucket_info->count, hash); + bucket_t *bucket = &buckets[hash]; + if (!bucket->occupied) { + hdebug("Got an empty space\n"); + // Empty space: + bucket->occupied = 1; + bucket->index = index; + bucket->next_bucket = END_OF_CHAIN; + hshow(t); + return; + } + + 
hdebug("Collision detected in bucket %u (entry %u)\n", hash, bucket->index); + + while (buckets[t->bucket_info->last_free].occupied) { + assert(t->bucket_info->last_free > 0); + --t->bucket_info->last_free; + } + + uint32_t collided_hash = HASH_KEY(t, GET_ENTRY(t, bucket->index)); + if (collided_hash != hash) { // Collided with a mid-chain entry + hdebug("Hit a mid-chain entry at bucket %u (chain starting at %u)\n", hash, collided_hash); + // Find chain predecessor + uint32_t predecessor = collided_hash; + while (buckets[predecessor].next_bucket != hash) + predecessor = buckets[predecessor].next_bucket; + + // Move mid-chain entry to free space and update predecessor + buckets[predecessor].next_bucket = t->bucket_info->last_free; + buckets[t->bucket_info->last_free] = *bucket; + } else { // Collided with the start of a chain + hdebug("Hit start of a chain\n"); + uint32_t end_of_chain = hash; + while (buckets[end_of_chain].next_bucket != END_OF_CHAIN) + end_of_chain = buckets[end_of_chain].next_bucket; + hdebug("Appending to chain\n"); + // Chain now ends on the free space: + buckets[end_of_chain].next_bucket = t->bucket_info->last_free; + bucket = &buckets[t->bucket_info->last_free]; + } + + bucket->occupied = 1; + bucket->index = index; + bucket->next_bucket = END_OF_CHAIN; + hshow(t); +} + +static void hashmap_resize_buckets(table_t *t, uint32_t new_capacity, const TypeInfo *type) +{ + hdebug("About to resize from %u to %u\n", t->bucket_info ? 
t->bucket_info->count : 0, new_capacity); + hshow(t); + int64_t alloc_size = sizeof(bucket_info_t) + (int64_t)(new_capacity)*sizeof(bucket_t); + t->bucket_info = GC_MALLOC_ATOMIC(alloc_size); + memset(t->bucket_info->buckets, 0, (int64_t)new_capacity * sizeof(bucket_t)); + t->bucket_info->count = new_capacity; + t->bucket_info->last_free = new_capacity-1; + // Rehash: + for (int64_t i = 0; i < Table_length(t); i++) { + hdebug("Rehashing %u\n", i); + Table_set_bucket(t, GET_ENTRY(t, i), i, type); + } + + hshow(t); + hdebug("Finished resizing\n"); +} + +// Return address of value +public void *Table_reserve(table_t *t, const void *key, const void *value, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + if (!t || !key) return NULL; + hshow(t); + + int64_t key_size = type->TableInfo.key->size, + value_size = type->TableInfo.value->size; + if (!t->bucket_info || t->bucket_info->count == 0) { + hashmap_resize_buckets(t, 4, type); + } else { + // Check if we are clobbering a value: + void *value_home = Table_get_raw(t, key, type); + if (value_home) { // Update existing slot + // Ensure that `value_home` is still inside t->entries, even if COW occurs + ptrdiff_t offset = value_home - t->entries.data; + maybe_copy_on_write(t, type); + value_home = t->entries.data + offset; + + if (value && value_size > 0) + memcpy(value_home, value, value_size); + + return value_home; + } + } + // Otherwise add a new entry: + + // Resize buckets if necessary + if (t->entries.length >= (int64_t)t->bucket_info->count) { + uint32_t newsize = t->bucket_info->count + MIN(t->bucket_info->count, 64); + hashmap_resize_buckets(t, newsize, type); + } + + if (!value && value_size > 0) { + for (table_t *iter = t->fallback; iter; iter = iter->fallback) { + value = Table_get_raw(iter, key, type); + if (value) break; + } + for (table_t *iter = t; !value && iter; iter = iter->fallback) { + if (iter->default_value) value = iter->default_value; + } + } + + maybe_copy_on_write(t, type); + + char 
buf[ENTRY_SIZE] = {}; + memcpy(buf, key, key_size); + if (value && value_size > 0) + memcpy(buf + VALUE_OFFSET, value, value_size); + else + memset(buf + VALUE_OFFSET, 0, value_size); + Array_insert(&t->entries, buf, 0, ENTRY_SIZE); + + int64_t entry_index = t->entries.length-1; + void *entry = GET_ENTRY(t, entry_index); + Table_set_bucket(t, entry, entry_index, type); + return entry + VALUE_OFFSET; +} + +public void Table_set(table_t *t, const void *key, const void *value, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + (void)Table_reserve(t, key, value, type); +} + +public void Table_remove(table_t *t, const void *key, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + if (!t || Table_length(t) == 0) return; + + // TODO: this work doesn't need to be done if the key is already missing + maybe_copy_on_write(t, type); + + // If unspecified, pop a random key: + if (!key) { + hdebug("Popping random key\n"); + uint32_t index = arc4random_uniform(t->entries.length); + key = GET_ENTRY(t, index); + } + + // Steps: look up the bucket for the removed key + // If missing, then return immediately + // Swap last key/value into the removed bucket's index1 + // Zero out the last key/value and decrement the count + // Find the last key/value's bucket and update its index1 + // Look up the bucket for the removed key + // If bucket is first in chain: + // Move bucket->next to bucket's spot + // zero out bucket->next's old spot + // maybe update lastfree_index1 to second-in-chain's index + // Else: + // set prev->next = bucket->next + // zero out bucket + // maybe update lastfree_index1 to removed bucket's index + + uint32_t hash = HASH_KEY(t, key); + hdebug("Removing key with hash %u\n", hash); + bucket_t *bucket, *prev = NULL; + for (uint32_t i = hash; t->bucket_info->buckets[i].occupied; i = t->bucket_info->buckets[i].next_bucket) { + if (EQUAL_KEYS(GET_ENTRY(t, t->bucket_info->buckets[i].index), key)) { + bucket = &t->bucket_info->buckets[i]; + 
hdebug("Found key to delete in bucket %u\n", i); + goto found_it; + } + if (t->bucket_info->buckets[i].next_bucket == END_OF_CHAIN) + return; + prev = &t->bucket_info->buckets[i]; + } + return; + + found_it:; + assert(bucket->occupied); + + // Always remove the last entry. If we need to remove some other entry, + // swap the other entry into the last position and then remove the last + // entry. This disturbs the ordering of the table, but keeps removal O(1) + // instead of O(N) + int64_t last_entry = t->entries.length-1; + if (bucket->index != last_entry) { + hdebug("Removing key/value from the middle of the entries array\n"); + + // Find the bucket that points to the last entry's index: + uint32_t i = HASH_KEY(t, GET_ENTRY(t, last_entry)); + while (t->bucket_info->buckets[i].index != last_entry) + i = t->bucket_info->buckets[i].next_bucket; + // Update the bucket to point to the last entry's new home (the space + // where the removed entry currently sits): + t->bucket_info->buckets[i].index = bucket->index; + + // Clobber the entry being removed (in the middle of the array) with + // the last entry: + memcpy(GET_ENTRY(t, bucket->index), GET_ENTRY(t, last_entry), ENTRY_SIZE); + } + + // Last entry is being removed, so clear it out to be safe: + memset(GET_ENTRY(t, last_entry), 0, ENTRY_SIZE); + + Array_remove(&t->entries, t->entries.length, 1, ENTRY_SIZE); + + int64_t bucket_to_clear; + if (prev) { // Middle (or end) of a chain + hdebug("Removing from middle of a chain\n"); + bucket_to_clear = (bucket - t->bucket_info->buckets); + prev->next_bucket = bucket->next_bucket; + } else if (bucket->next_bucket != END_OF_CHAIN) { // Start of a chain + hdebug("Removing from start of a chain\n"); + bucket_to_clear = bucket->next_bucket; + *bucket = t->bucket_info->buckets[bucket_to_clear]; + } else { // Empty chain + hdebug("Removing from empty chain\n"); + bucket_to_clear = (bucket - t->bucket_info->buckets); + } + + t->bucket_info->buckets[bucket_to_clear] = 
(bucket_t){0}; + if (bucket_to_clear > t->bucket_info->last_free) + t->bucket_info->last_free = bucket_to_clear; + + hshow(t); +} + +public void *Table_entry(const table_t *t, int64_t n) +{ + if (n < 1 || n > Table_length(t)) + return NULL; + return GET_ENTRY(t, n-1); +} + +public void Table_clear(table_t *t) +{ + memset(t, 0, sizeof(table_t)); +} + +public bool Table_equal(const table_t *x, const table_t *y, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + if (Table_length(x) != Table_length(y)) + return false; + + if ((x->default_value != NULL) != (y->default_value != NULL)) + return false; + + if ((x->fallback != NULL) != (y->fallback != NULL)) + return false; + + const TypeInfo *value_type = type->TableInfo.value; + for (int64_t i = 0, length = Table_length(x); i < length; i++) { + void *x_key = GET_ENTRY(x, i); + void *x_value = x_key + VALUE_OFFSET; + void *y_value = Table_get_raw(y, x_key, type); + if (!y_value) + return false; + if (!generic_equal(x_value, y_value, value_type)) + return false; + } + + if (x->default_value && y->default_value + && !generic_equal(x->default_value, y->default_value, value_type)) + return false; + + if (x->fallback && y->fallback + && !Table_equal(x->fallback, y->fallback, type)) + return false; + + return true; +} + +public int32_t Table_compare(const table_t *x, const table_t *y, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + auto table = type->TableInfo; + struct { + const char *name; + const TypeInfo *type; + } member_data[] = {{"key", table.key}, {"value", table.value}}; + TypeInfo entry_type = { + .name="Entry", + .size=ENTRY_SIZE, + .align=MAX(table.key->align, table.value->align), + .tag=StructInfo, + .StructInfo={ + .members=(array_t){.data=member_data, .length=2, .stride=sizeof(member_data[0])}, + } + }; + array_t x_entries = x->entries, y_entries = y->entries; + Array_sort(&x_entries, &entry_type); + Array_sort(&y_entries, &entry_type); + return Array_compare(&x_entries, &y_entries, 
&entry_type); +} + +public uint32_t Table_hash(const table_t *t, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + // Table hashes are computed as: + // hash(#t, xor(hash(k) for k in t.keys), xor(hash(v) for v in t.values), hash(t.fallback), hash(t.default)) + // Where fallback and default hash to zero if absent + auto table = type->TableInfo; + int64_t value_offset = table.value_offset; + + uint32_t key_hashes = 0, value_hashes = 0, fallback_hash = 0, default_hash = 0; + for (int64_t i = 0, length = Table_length(t); i < length; i++) { + void *entry = GET_ENTRY(t, i); + key_hashes ^= generic_hash(entry, table.key); + value_hashes ^= generic_hash(entry + value_offset, table.value); + } + + if (t->fallback) + fallback_hash = Table_hash(t->fallback, type); + + if (t->default_value) + default_hash = generic_hash(t->default_value, table.value); + + struct { int64_t len; uint32_t k, v, f, d; } components = { + Table_length(t), + key_hashes, + value_hashes, + fallback_hash, + default_hash, + }; + uint32_t hash; + halfsiphash(&components, sizeof(components), SSS_HASH_VECTOR, (uint8_t*)&hash, sizeof(hash)); + return hash; +} + +public CORD Table_cord(const table_t *t, bool colorize, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + auto table = type->TableInfo; + int64_t value_offset = table.value_offset; + CORD c = "{"; + for (int64_t i = 0, length = Table_length(t); i < length; i++) { + if (i > 0) + c = CORD_cat(c, ", "); + void *entry = GET_ENTRY(t, i); + c = CORD_cat(c, generic_cord(entry, colorize, table.key)); + c = CORD_cat(c, "=>"); + c = CORD_cat(c, generic_cord(entry + value_offset, colorize, table.value)); + } + + if (t->fallback) { + c = CORD_cat(c, "; fallback="); + c = CORD_cat(c, Table_cord(t->fallback, colorize, type)); + } + + if (t->default_value) { + c = CORD_cat(c, "; default="); + c = CORD_cat(c, generic_cord(t->default_value, colorize, table.value)); + } + + c = CORD_cat(c, "}"); + return c; +} + +public table_t 
Table_from_entries(array_t entries, const TypeInfo *type) +{ + assert(type->tag == TableInfo); + table_t t = {.entries=entries}; + for (int64_t i = 0; i < Table_length(&t); i++) { + hdebug("Rehashing %u\n", i); + Table_set_bucket(&t, GET_ENTRY(&t, i), i, type); + } + return t; +} + +void *Table_str_get(const table_t *t, const char *key) +{ + void **ret = Table_get(t, &key, &CStringToVoidStarTable_type); + return ret ? *ret : NULL; +} + +void *Table_str_get_raw(const table_t *t, const char *key) +{ + void **ret = Table_get_raw(t, &key, &CStringToVoidStarTable_type); + return ret ? *ret : NULL; +} + +void *Table_str_reserve(table_t *t, const char *key, const void *value) +{ + return Table_reserve(t, &key, &value, &CStringToVoidStarTable_type); +} + +void Table_str_set(table_t *t, const char *key, const void *value) +{ + Table_set(t, &key, &value, &CStringToVoidStarTable_type); +} + +void Table_str_remove(table_t *t, const char *key) +{ + return Table_remove(t, &key, &CStringToVoidStarTable_type); +} + +void *Table_str_entry(const table_t *t, int64_t n) +{ + return Table_entry(t, n); +} + +// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1 |
