diff --git a/stdlib/datatypes.h b/stdlib/datatypes.h index 5dddaa0..31a9963 100644 --- a/stdlib/datatypes.h +++ b/stdlib/datatypes.h @@ -53,6 +53,7 @@ typedef struct { typedef struct table_s { Array_t entries; + uint64_t hash; bucket_info_t *bucket_info; struct table_s *fallback; } Table_t; diff --git a/stdlib/tables.c b/stdlib/tables.c index df42ef6..dea7a45 100644 --- a/stdlib/tables.c +++ b/stdlib/tables.c @@ -228,6 +228,8 @@ public void *Table$reserve(Table_t *t, const void *key, const void *value, const if (!t || !key) return NULL; hshow(t); + t->hash = 0; + int64_t key_size = type->TableInfo.key->size, value_size = type->TableInfo.value->size; if (!t->bucket_info || t->bucket_info->count == 0) { @@ -334,6 +336,8 @@ public void Table$remove(Table_t *t, const void *key, const TypeInfo_t *type) found_it:; assert(bucket->occupied); + t->hash = 0; + // Always remove the last entry. If we need to remove some other entry, // swap the other entry into the last position and then remove the last // entry. This disturbs the ordering of the table, but keeps removal O(1) @@ -405,14 +409,27 @@ PUREFUNC public bool Table$equal(const void *vx, const void *vy, const TypeInfo_ if (vx == vy) return true; Table_t *x = (Table_t*)vx, *y = (Table_t*)vy; + if (x->hash && y->hash && x->hash != y->hash) + return false; + assert(type->tag == TableInfo); - if (Table$length(*x) != Table$length(*y)) + if (x->entries.length != y->entries.length) return false; if ((x->fallback != NULL) != (y->fallback != NULL)) return false; - return (Table$compare(x, y, type) == 0); + const TypeInfo_t *value_type = type->TableInfo.value; + size_t offset = value_offset(type); + for (int64_t i = 0; i < x->entries.length; i++) { + void *x_key = x->entries.data + i*x->entries.stride; + void *y_value = Table$get_raw(*y, x_key, type); + if (!y_value) return false; + void *x_value = x_key + offset; + if (!generic_equal(y_value, x_value, value_type)) + return false; + } + return true; } PUREFUNC public int32_t Table$compare(const void *vx, const void *vy, const TypeInfo_t *type) @@ -422,20 +439,90 @@ PUREFUNC public int32_t Table$compare(const void *vx, const void *vy, const Type Table_t *x = (Table_t*)vx, *y = (Table_t*)vy; assert(type->tag == TableInfo); auto table = type->TableInfo; - if (x->entries.length == 0) - return 0; - else if (x->entries.length != y->entries.length) - return (x->entries.length > y->entries.length) - (x->entries.length < y->entries.length); + // Sort empty tables before non-empty tables: + if (x->entries.length == 0 || y->entries.length == 0) + return ((x->entries.length > 0) - (y->entries.length > 0)); + + // Table comparison rules: + // - If two tables have different keys, then compare as if comparing a + // sorted array of the keys of the two tables: + // `x.keys:sorted() <> y.keys:sorted()` + // - Otherwise, compare as if comparing arrays of values for the sorted key + // arrays: + // `[x[k] for k in x.keys:sorted()] <> [y[k] for k in y.keys:sorted()]` + // + // We can do this in _linear_ time if we find the smallest `k` such that + // `x[k] != y[k]`, as well as the largest key in `x` and `y`. + + void *mismatched_key = NULL, *max_x_key = NULL; for (int64_t i = 0; i < x->entries.length; i++) { - void *x_key = x->entries.data + x->entries.stride * i; - void *y_key = y->entries.data + y->entries.stride * i; - int32_t diff = generic_compare(x_key, y_key, table.key); - if (diff != 0) return diff; - void *x_value = x_key + value_offset(type); - void *y_value = y_key + value_offset(type); - diff = generic_compare(x_value, y_value, table.value); - if (diff != 0) return diff; + void *key = x->entries.data + x->entries.stride * i; + if (max_x_key == NULL || generic_compare(key, max_x_key, table.key) > 0) + max_x_key = key; + + void *x_value = key + value_offset(type); + void *y_value = Table$get_raw(*y, key, type); + + if (!y_value || (table.value->size > 0 && !generic_equal(x_value, y_value, table.value))) { + if (mismatched_key == NULL || generic_compare(key, mismatched_key, table.key) < 0) + mismatched_key = key; + } + } + + // If the keys are not all equal, we gotta check to see if there exists a + // `y[k]` such that `k` is smaller than all keys that `x` has and `y` doesn't: + void *max_y_key = NULL; + for (int64_t i = 0; i < y->entries.length; i++) { + void *key = y->entries.data + y->entries.stride * i; + if (max_y_key == NULL || generic_compare(key, max_y_key, table.key) > 0) + max_y_key = key; + + void *y_value = key + value_offset(type); + void *x_value = Table$get_raw(*x, key, type); + if (!x_value || !generic_equal(x_value, y_value, table.value)) { + if (mismatched_key == NULL || generic_compare(key, mismatched_key, table.key) < 0) + mismatched_key = key; + } + } + + if (mismatched_key) { + void *x_value = Table$get_raw(*x, mismatched_key, type); + void *y_value = Table$get_raw(*y, mismatched_key, type); + if (x_value && y_value) { + return generic_compare(x_value, y_value, table.value); + } else if (y_value) { + // The smallest mismatched key is in Y, but not X. + // In this case, we should judge if the key is smaller than *any* + // key in X or if it's bigger than *every* key in X. + // Example 1: + // x={10, 20, 30} > y={10, 20, 25, 30} + // The smallest mismatched key is `25`, and we know that `x` is + // larger than `y` because `30 > 25`. + // Example 2: + // x={10, 20, 30} > y={10, 20, 30, 999} + // The smallest mismatched key is `999`, and we know that `x` is + // smaller than `y` because `30 < 999`. + return max_x_key ? generic_compare(max_x_key, mismatched_key, table.key) : -1; + } else { + assert(x_value); + // The smallest mismatched key is in X, but not Y. The same logic + // above applies, but reversed. + return max_y_key ? -generic_compare(max_y_key, mismatched_key, table.key) : 1; + } + } + + assert(x->entries.length == y->entries.length); + + // Assuming keys are the same, compare values: + if (table.value->size > 0) { + for (int64_t i = 0; i < x->entries.length; i++) { + void *key = x->entries.data + x->entries.stride * i; + void *x_value = key + value_offset(type); + void *y_value = Table$get_raw(*y, key, type); + int32_t diff = generic_compare(x_value, y_value, table.value); + if (diff != 0) return diff; + } } if (!x->fallback != !y->fallback) { @@ -451,16 +538,39 @@ PUREFUNC public uint64_t Table$hash(const void *obj, const TypeInfo_t *type) { assert(type->tag == TableInfo); Table_t *t = (Table_t*)obj; + if (t->hash != 0) + return t->hash; + // Table hashes are computed as: - // hash(hash(t.keys), hash(t.values), hash(t.fallback), hash(t.default)) + // hash(t.length, (xor: t.keys), (xor: t.values), t.fallback) // Where fallback and default hash to zero if absent auto table = type->TableInfo; - uint64_t components[] = { - Array$hash(&t->entries, Array$info(table.key)), - Array$hash(&t->entries + value_offset(type), Array$info(table.value)), - t->fallback ? Table$hash(t->fallback, type) : 0, + uint64_t keys_hash = 0, values_hash = 0; + size_t offset = value_offset(type); + if (table.value->size > 0) { + for (int64_t i = 0; i < t->entries.length; i++) { + keys_hash ^= generic_hash(t->entries.data + i*t->entries.stride, table.key); + values_hash ^= generic_hash(t->entries.data + i*t->entries.stride + offset, table.value); + } + } else { + for (int64_t i = 0; i < t->entries.length; i++) + keys_hash ^= generic_hash(t->entries.data + i*t->entries.stride, table.key); + } + + struct { + int64_t length; + uint64_t keys_hash, values_hash; + Table_t *fallback; + } components = { + t->entries.length, + keys_hash, + values_hash, + t->fallback, }; - return siphash24((void*)&components, sizeof(components)); + t->hash = siphash24((void*)&components, sizeof(components)); + if unlikely (t->hash == 0) + t->hash = 1234567; + return t->hash; } public Text_t Table$as_text(const void *obj, bool colorize, const TypeInfo_t *type) diff --git a/test/tables.tm b/test/tables.tm index e67c712..6abbbe3 100644 --- a/test/tables.tm +++ b/test/tables.tm @@ -88,3 +88,18 @@ func main(): >> t4 = &{"one":999, "two":222} + do: + >> {1:1, 2:2} == {2:2, 1:1} + = yes + >> {1:1, 2:2} == {1:1, 2:999} + = no + + >> {1:1, 2:2} <> {2:2, 1:1} + = 0 + >> [{:Int:Int}, {0:0}, {99:99}, {1:1, 2:2, 3:3}, {1:1, 99:99, 3:3}, {1:1, 2:-99, 3:3}, {1:1, 99:-99, 3:4}]:sorted() + = [{}, {0:0}, {1:1, 2:-99, 3:3}, {1:1, 2:2, 3:3}, {1:1, 99:99, 3:3}, {1:1, 99:-99, 3:4}, {99:99}] + + >> [{:Int}, {1}, {2}, {99}, {0, 3}, {1, 2}, {99}]:sorted() + = [{}, {0, 3}, {1}, {1, 2}, {2}, {99}, {99}] + +