Refactor table methods to take table structs where possible

This commit is contained in:
Bruce Hill 2024-03-08 14:33:54 -05:00
parent 55eacb8a04
commit 8427037bb9
7 changed files with 74 additions and 68 deletions

View File

@ -35,11 +35,11 @@
#endif
// Helper accessors for type functions/values:
#define HASH_KEY(t, k) (generic_hash((k), type->TableInfo.key) % ((t)->bucket_info->count))
#define HASH_KEY(t, k) (generic_hash((k), type->TableInfo.key) % ((t).bucket_info->count))
#define EQUAL_KEYS(x, y) (generic_equal((x), (y), type->TableInfo.key))
#define END_OF_CHAIN UINT32_MAX
#define GET_ENTRY(t, i) ((t)->entries.data + (t)->entries.stride*(i))
#define GET_ENTRY(t, i) ((t).entries.data + (t).entries.stride*(i))
#define ENTRIES_TYPE(type) (&(TypeInfo){.size=sizeof(array_t), .align=__alignof__(array_t), .tag=ArrayInfo, .ArrayInfo.item=(&(TypeInfo){.size=entry_size(type), .align=entry_align(type), .tag=OpaqueInfo})})
const TypeInfo MemoryPointer = {
@ -117,15 +117,15 @@ public void Table_mark_copy_on_write(table_t *t)
}
// Return address of value or NULL
public void *Table_get_raw(const table_t *t, const void *key, const TypeInfo *type)
public void *Table_get_raw(table_t t, const void *key, const TypeInfo *type)
{
assert(type->tag == TableInfo);
if (!t || !key || !t->bucket_info) return NULL;
if (!key || !t.bucket_info) return NULL;
uint32_t hash = HASH_KEY(t, key);
hshow(t);
hshow(&t);
hdebug("Getting value with initial probe at %u\n", hash);
bucket_t *buckets = t->bucket_info->buckets;
bucket_t *buckets = t.bucket_info->buckets;
for (uint32_t i = hash; buckets[i].occupied; i = buckets[i].next_bucket) {
hdebug("Checking against key in bucket %u\n", i);
void *entry = GET_ENTRY(t, buckets[i].index);
@ -139,14 +139,14 @@ public void *Table_get_raw(const table_t *t, const void *key, const TypeInfo *ty
return NULL;
}
public void *Table_get(const table_t *t, const void *key, const TypeInfo *type)
public void *Table_get(table_t t, const void *key, const TypeInfo *type)
{
assert(type->tag == TableInfo);
for (const table_t *iter = t; iter; iter = iter->fallback) {
void *ret = Table_get_raw(iter, key, type);
for (const table_t *iter = &t; iter; iter = iter->fallback) {
void *ret = Table_get_raw(*iter, key, type);
if (ret) return ret;
}
for (const table_t *iter = t; iter; iter = iter->fallback) {
for (const table_t *iter = &t; iter; iter = iter->fallback) {
if (iter->default_value) return iter->default_value;
}
return NULL;
@ -158,7 +158,7 @@ static void Table_set_bucket(table_t *t, const void *entry, int32_t index, const
hshow(t);
const void *key = entry;
bucket_t *buckets = t->bucket_info->buckets;
uint32_t hash = HASH_KEY(t, key);
uint32_t hash = HASH_KEY(*t, key);
hdebug("Hash value (mod %u) = %u\n", t->bucket_info->count, hash);
bucket_t *bucket = &buckets[hash];
if (!bucket->occupied) {
@ -178,7 +178,7 @@ static void Table_set_bucket(table_t *t, const void *entry, int32_t index, const
--t->bucket_info->last_free;
}
uint32_t collided_hash = HASH_KEY(t, GET_ENTRY(t, bucket->index));
uint32_t collided_hash = HASH_KEY(*t, GET_ENTRY(*t, bucket->index));
if (collided_hash != hash) { // Collided with a mid-chain entry
hdebug("Hit a mid-chain entry at bucket %u (chain starting at %u)\n", hash, collided_hash);
// Find chain predecessor
@ -216,9 +216,9 @@ static void hashmap_resize_buckets(table_t *t, uint32_t new_capacity, const Type
t->bucket_info->count = new_capacity;
t->bucket_info->last_free = new_capacity-1;
// Rehash:
for (int64_t i = 0; i < Table_length(t); i++) {
for (int64_t i = 0; i < Table_length(*t); i++) {
hdebug("Rehashing %u\n", i);
Table_set_bucket(t, GET_ENTRY(t, i), i, type);
Table_set_bucket(t, GET_ENTRY(*t, i), i, type);
}
hshow(t);
@ -238,7 +238,7 @@ public void *Table_reserve(table_t *t, const void *key, const void *value, const
hashmap_resize_buckets(t, 4, type);
} else {
// Check if we are clobbering a value:
void *value_home = Table_get_raw(t, key, type);
void *value_home = Table_get_raw(*t, key, type);
if (value_home) { // Update existing slot
// Ensure that `value_home` is still inside t->entries, even if COW occurs
ptrdiff_t offset = value_home - t->entries.data;
@ -261,7 +261,7 @@ public void *Table_reserve(table_t *t, const void *key, const void *value, const
if (!value && value_size > 0) {
for (table_t *iter = t->fallback; iter; iter = iter->fallback) {
value = Table_get_raw(iter, key, type);
value = Table_get_raw(*iter, key, type);
if (value) break;
}
for (table_t *iter = t; !value && iter; iter = iter->fallback) {
@ -280,7 +280,7 @@ public void *Table_reserve(table_t *t, const void *key, const void *value, const
Array__insert(&t->entries, buf, 0, ENTRIES_TYPE(type));
int64_t entry_index = t->entries.length-1;
void *entry = GET_ENTRY(t, entry_index);
void *entry = GET_ENTRY(*t, entry_index);
Table_set_bucket(t, entry, entry_index, type);
return entry + value_offset(type);
}
@ -294,7 +294,7 @@ public void Table_set(table_t *t, const void *key, const void *value, const Type
public void Table_remove(table_t *t, const void *key, const TypeInfo *type)
{
assert(type->tag == TableInfo);
if (!t || Table_length(t) == 0) return;
if (!t || Table_length(*t) == 0) return;
// TODO: this work doesn't need to be done if the key is already missing
maybe_copy_on_write(t, type);
@ -303,7 +303,7 @@ public void Table_remove(table_t *t, const void *key, const TypeInfo *type)
if (!key) {
hdebug("Popping random key\n");
uint32_t index = arc4random_uniform(t->entries.length);
key = GET_ENTRY(t, index);
key = GET_ENTRY(*t, index);
}
// Steps: look up the bucket for the removed key
@ -321,11 +321,11 @@ public void Table_remove(table_t *t, const void *key, const TypeInfo *type)
// zero out bucket
// maybe update lastfree_index1 to removed bucket's index
uint32_t hash = HASH_KEY(t, key);
uint32_t hash = HASH_KEY(*t, key);
hdebug("Removing key with hash %u\n", hash);
bucket_t *bucket, *prev = NULL;
for (uint32_t i = hash; t->bucket_info->buckets[i].occupied; i = t->bucket_info->buckets[i].next_bucket) {
if (EQUAL_KEYS(GET_ENTRY(t, t->bucket_info->buckets[i].index), key)) {
if (EQUAL_KEYS(GET_ENTRY(*t, t->bucket_info->buckets[i].index), key)) {
bucket = &t->bucket_info->buckets[i];
hdebug("Found key to delete in bucket %u\n", i);
goto found_it;
@ -348,7 +348,7 @@ public void Table_remove(table_t *t, const void *key, const TypeInfo *type)
hdebug("Removing key/value from the middle of the entries array\n");
// Find the bucket that points to the last entry's index:
uint32_t i = HASH_KEY(t, GET_ENTRY(t, last_entry));
uint32_t i = HASH_KEY(*t, GET_ENTRY(*t, last_entry));
while (t->bucket_info->buckets[i].index != last_entry)
i = t->bucket_info->buckets[i].next_bucket;
// Update the bucket to point to the last entry's new home (the space
@ -357,11 +357,11 @@ public void Table_remove(table_t *t, const void *key, const TypeInfo *type)
// Clobber the entry being removed (in the middle of the array) with
// the last entry:
memcpy(GET_ENTRY(t, bucket->index), GET_ENTRY(t, last_entry), entry_size(type));
memcpy(GET_ENTRY(*t, bucket->index), GET_ENTRY(*t, last_entry), entry_size(type));
}
// Last entry is being removed, so clear it out to be safe:
memset(GET_ENTRY(t, last_entry), 0, entry_size(type));
memset(GET_ENTRY(*t, last_entry), 0, entry_size(type));
Array__remove(&t->entries, t->entries.length, 1, ENTRIES_TYPE(type));
@ -386,7 +386,7 @@ public void Table_remove(table_t *t, const void *key, const TypeInfo *type)
hshow(t);
}
public void *Table_entry(const table_t *t, int64_t n)
public void *Table_entry(table_t t, int64_t n)
{
if (n < 1 || n > Table_length(t))
return NULL;
@ -401,7 +401,7 @@ public void Table_clear(table_t *t)
public bool Table_equal(const table_t *x, const table_t *y, const TypeInfo *type)
{
assert(type->tag == TableInfo);
if (Table_length(x) != Table_length(y))
if (Table_length(*x) != Table_length(*y))
return false;
if ((x->default_value != NULL) != (y->default_value != NULL))
@ -411,10 +411,10 @@ public bool Table_equal(const table_t *x, const table_t *y, const TypeInfo *type
return false;
const TypeInfo *value_type = type->TableInfo.value;
for (int64_t i = 0, length = Table_length(x); i < length; i++) {
void *x_key = GET_ENTRY(x, i);
for (int64_t i = 0, length = Table_length(*x); i < length; i++) {
void *x_key = GET_ENTRY(*x, i);
void *x_value = x_key + value_offset(type);
void *y_value = Table_get_raw(y, x_key, type);
void *y_value = Table_get_raw(*y, x_key, type);
if (!y_value)
return false;
if (!generic_equal(x_value, y_value, value_type))
@ -481,8 +481,8 @@ public uint32_t Table_hash(const table_t *t, const TypeInfo *type)
int64_t val_off = value_offset(type);
uint32_t key_hashes = 0, value_hashes = 0, fallback_hash = 0, default_hash = 0;
for (int64_t i = 0, length = Table_length(t); i < length; i++) {
void *entry = GET_ENTRY(t, i);
for (int64_t i = 0, length = Table_length(*t); i < length; i++) {
void *entry = GET_ENTRY(*t, i);
key_hashes ^= generic_hash(entry, table.key);
value_hashes ^= generic_hash(entry + val_off, table.value);
}
@ -494,7 +494,7 @@ public uint32_t Table_hash(const table_t *t, const TypeInfo *type)
default_hash = generic_hash(t->default_value, table.value);
struct { int64_t len; uint32_t k, v, f, d; } components = {
Table_length(t),
Table_length(*t),
key_hashes,
value_hashes,
fallback_hash,
@ -515,10 +515,10 @@ public CORD Table_as_text(const table_t *t, bool colorize, const TypeInfo *type)
int64_t val_off = value_offset(type);
CORD c = "{";
for (int64_t i = 0, length = Table_length(t); i < length; i++) {
for (int64_t i = 0, length = Table_length(*t); i < length; i++) {
if (i > 0)
c = CORD_cat(c, ", ");
void *entry = GET_ENTRY(t, i);
void *entry = GET_ENTRY(*t, i);
c = CORD_cat(c, generic_as_text(entry, colorize, table.key));
c = CORD_cat(c, "=>");
c = CORD_cat(c, generic_as_text(entry + val_off, colorize, table.value));
@ -546,13 +546,13 @@ public table_t Table_from_entries(array_t entries, const TypeInfo *type)
return t;
}
void *Table_str_get(const table_t *t, const char *key)
void *Table_str_get(table_t t, const char *key)
{
void **ret = Table_get(t, &key, &StrToVoidStarTable);
return ret ? *ret : NULL;
}
void *Table_str_get_raw(const table_t *t, const char *key)
void *Table_str_get_raw(table_t t, const char *key)
{
void **ret = Table_get_raw(t, &key, &StrToVoidStarTable);
return ret ? *ret : NULL;
@ -573,7 +573,7 @@ void Table_str_remove(table_t *t, const char *key)
return Table_remove(t, &key, &StrToVoidStarTable);
}
void *Table_str_entry(const table_t *t, int64_t n)
void *Table_str_entry(table_t t, int64_t n)
{
return Table_entry(t, n);
}

View File

@ -18,7 +18,7 @@
$table.default_value = def; \
$table; })
#define $Table_get(table_expr, key_t, val_t, key_expr, info_expr, filename, start, end) ({ \
const table_t *$t = table_expr; key_t $k = key_expr; const TypeInfo* $info = info_expr; \
const table_t $t = table_expr; key_t $k = key_expr; const TypeInfo* $info = info_expr; \
const val_t *$v = Table_get($t, &$k, $info); \
if (__builtin_expect($v == NULL, 0)) \
fail_source(filename, start, end, "The key %r is not in this table\n", generic_as_text(&$k, USE_COLOR, $info->TableInfo.key)); \
@ -38,9 +38,9 @@
}
table_t Table_from_entries(array_t entries, const TypeInfo *type);
void *Table_get(const table_t *t, const void *key, const TypeInfo *type);
void *Table_get_raw(const table_t *t, const void *key, const TypeInfo *type);
void *Table_entry(const table_t *t, int64_t n);
void *Table_get(table_t t, const void *key, const TypeInfo *type);
void *Table_get_raw(table_t t, const void *key, const TypeInfo *type);
void *Table_entry(table_t t, int64_t n);
void *Table_reserve(table_t *t, const void *key, const void *value, const TypeInfo *type);
void Table_set(table_t *t, const void *key, const void *value, const TypeInfo *type);
void Table_remove(table_t *t, const void *key, const TypeInfo *type);
@ -51,14 +51,14 @@ bool Table_equal(const table_t *x, const table_t *y, const TypeInfo *type);
uint32_t Table_hash(const table_t *t, const TypeInfo *type);
CORD Table_as_text(const table_t *t, bool colorize, const TypeInfo *type);
void *Table_str_entry(const table_t *t, int64_t n);
void *Table_str_get(const table_t *t, const char *key);
void *Table_str_get_raw(const table_t *t, const char *key);
void *Table_str_entry(table_t t, int64_t n);
void *Table_str_get(table_t t, const char *key);
void *Table_str_get_raw(table_t t, const char *key);
void Table_str_set(table_t *t, const char *key, const void *value);
void *Table_str_reserve(table_t *t, const char *key, const void *value);
void Table_str_remove(table_t *t, const char *key);
#define Table_length(t) ((t)->entries.length)
#define Table_length(t) ((t).entries.length)
extern const TypeInfo StrToVoidStarTable;

View File

@ -174,7 +174,7 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg
for (arg_ast_t *call_arg = call_args; call_arg; call_arg = call_arg->next) {
if (call_arg->name) continue;
const char *pseudoname = heap_strf("%ld", i++);
if (!Table_str_get(&used_args, pseudoname)) {
if (!Table_str_get(used_args, pseudoname)) {
type_t *actual_t = get_type(env, call_arg->value);
if (!can_promote(actual_t, spec_arg->type))
code_err(call_arg->value, "This argument is supposed to be a %T, but this value is a %T", spec_arg->type, actual_t);
@ -199,11 +199,11 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg
int64_t i = 1;
for (arg_ast_t *call_arg = call_args; call_arg; call_arg = call_arg->next) {
if (call_arg->name) {
if (!Table_str_get(&used_args, call_arg->name))
if (!Table_str_get(used_args, call_arg->name))
code_err(call_arg->value, "There is no argument with the name '%s'", call_arg->name);
} else {
const char *pseudoname = heap_strf("%ld", i++);
if (!Table_str_get(&used_args, pseudoname))
if (!Table_str_get(used_args, pseudoname))
code_err(call_arg->value, "This is one argument too many!");
}
}
@ -748,9 +748,15 @@ CORD compile(env_t *env, ast_t *ast)
compile_type_info(env, self_value_t), ")");
} else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case TableType: {
goto fncall;
}
// case TableType: {
// if (streq(call->name, "get")) {
// type_t *item_t = Match(self_value_t, ArrayType)->item_type;
// CORD self = compile_to_pointer_depth(env, call->self, 1, false);
// arg_t *arg_spec = new(arg_t, .name="item", .type=Type(PointerType, .pointed=item_t, .is_stack=true, .is_readonly=true),
// .next=new(arg_t, .name="at", .type=Type(IntType, .bits=64), .default_val=FakeAST(Int, .i=0, .bits=64)));
// return CORD_all("Table_get(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
// compile_type_info(env, self_value_t), ")");
// }
default: goto fncall;
}
}
@ -1094,9 +1100,9 @@ CORD compile(env_t *env, ast_t *ast)
switch (value_t->tag) {
case TypeInfoType: {
auto info = Match(value_t, TypeInfoType);
table_t *namespace = Table_str_get(env->type_namespaces, info->name);
table_t *namespace = Table_str_get(*env->type_namespaces, info->name);
if (!namespace) code_err(f->fielded, "I couldn't find a namespace for this type");
binding_t *b = Table_str_get(namespace, f->field);
binding_t *b = Table_str_get(*namespace, f->field);
if (!b) code_err(ast, "I couldn't find the field '%s' on this type", f->field);
if (!b->code) code_err(ast, "I couldn't figure out how to compile this field");
return b->code;
@ -1189,7 +1195,7 @@ CORD compile(env_t *env, ast_t *ast)
type_t *value_t = Match(container_t, TableType)->value_type;
if (!can_promote(index_t, key_t))
code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t);
CORD table = compile_to_pointer_depth(env, indexing->indexed, 1, false);
CORD table = compile_to_pointer_depth(env, indexing->indexed, 0, false);
CORD key = compile(env, indexing->index);
file_t *f = indexing->index->file;
return CORD_all("$Table_get(", table, ", ", compile_type(key_t), ", ", compile_type(value_t), ", ",

View File

@ -49,7 +49,7 @@ static CORD compile_compare_method(env_t *env, ast_t *ast)
"if (diff) return diff;\n"
"switch (x->$tag) {\n");
for (tag_ast_t *tag = def->tags; tag; tag = tag->next) {
type_t *tag_type = Table_str_get(env->types, heap_strf("%s$%s", def->name, tag->name));
type_t *tag_type = Table_str_get(*env->types, heap_strf("%s$%s", def->name, tag->name));
cmp_func = CORD_all(cmp_func, "\tcase $tag$", def->name, "$", tag->name, ": "
"return generic_compare(&x->", tag->name, ", &y->", tag->name, ", ", compile_type_info(env, tag_type), ");\n");
}
@ -66,7 +66,7 @@ static CORD compile_equals_method(env_t *env, ast_t *ast)
"if (x->$tag != y->$tag) return no;\n"
"switch (x->$tag) {\n");
for (tag_ast_t *tag = def->tags; tag; tag = tag->next) {
type_t *tag_type = Table_str_get(env->types, heap_strf("%s$%s", def->name, tag->name));
type_t *tag_type = Table_str_get(*env->types, heap_strf("%s$%s", def->name, tag->name));
eq_func = CORD_all(eq_func, "\tcase $tag$", def->name, "$", tag->name, ": "
"return generic_equal(&x->", tag->name, ", &y->", tag->name, ", ", compile_type_info(env, tag_type), ");\n");
}
@ -82,7 +82,7 @@ static CORD compile_hash_method(env_t *env, ast_t *ast)
"uint32_t hashes[2] = {(uint32_t)obj->$tag};\n"
"switch (obj->$tag) {\n");
for (tag_ast_t *tag = def->tags; tag; tag = tag->next) {
type_t *tag_type = Table_str_get(env->types, heap_strf("%s$%s", def->name, tag->name));
type_t *tag_type = Table_str_get(*env->types, heap_strf("%s$%s", def->name, tag->name));
hash_func = CORD_all(hash_func, "\tcase $tag$", def->name, "$", tag->name, ": "
"hashes[1] = generic_hash(&obj->", tag->name, ", ", compile_type_info(env, tag_type), ");\n"
"break;\n");
@ -135,7 +135,7 @@ void compile_enum_def(env_t *env, ast_t *ast)
enum_def = CORD_cat(enum_def, "};\n};\n");
env->code->typecode = CORD_cat(env->code->typecode, enum_def);
type_t *t = Table_str_get(env->types, def->name);
type_t *t = Table_str_get(*env->types, def->name);
CORD typeinfo = CORD_asprintf("public const TypeInfo %s = {%zu, %zu, {.tag=CustomInfo, .CustomInfo={",
def->name, type_size(t), type_align(t));

View File

@ -196,7 +196,7 @@ env_t *namespace_env(env_t *env, const char *namespace_name)
{
env_t *ns_env = new(env_t);
*ns_env = *env;
ns_env->locals = Table_str_get(env->type_namespaces, namespace_name);
ns_env->locals = Table_str_get(*env->type_namespaces, namespace_name);
if (!ns_env->locals) {
ns_env->locals = new(table_t, .fallback=env->globals);
Table_str_set(env->type_namespaces, namespace_name, ns_env->locals);
@ -207,7 +207,7 @@ env_t *namespace_env(env_t *env, const char *namespace_name)
binding_t *get_binding(env_t *env, const char *name)
{
return Table_str_get(env->locals, name);
return Table_str_get(*env->locals, name);
}
binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name)
@ -224,11 +224,11 @@ binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name)
errx(1, "Table methods not implemented");
}
case BoolType: case IntType: case NumType: case TextType: {
table_t *ns = Table_str_get(env->type_namespaces, CORD_to_const_char_star(type_to_cord(cls_type)));
table_t *ns = Table_str_get(*env->type_namespaces, CORD_to_const_char_star(type_to_cord(cls_type)));
if (!ns) {
code_err(self, "No namespace found for this type!");
}
return Table_str_get(ns, name);
return Table_str_get(*ns, name);
}
case TypeInfoType: case StructType: case EnumType: {
const char *type_name;
@ -239,9 +239,9 @@ binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name)
default: errx(1, "Unreachable");
}
table_t *namespace = Table_str_get(env->type_namespaces, type_name);
table_t *namespace = Table_str_get(*env->type_namespaces, type_name);
if (!namespace) return NULL;
return Table_str_get(namespace, name);
return Table_str_get(*namespace, name);
}
default: break;
}

View File

@ -143,7 +143,7 @@ void compile_struct_def(env_t *env, ast_t *ast)
// Typeinfo:
CORD_appendf(&env->code->typedefs, "extern const TypeInfo %s;\n", def->name);
type_t *t = Table_str_get(env->types, def->name);
type_t *t = Table_str_get(*env->types, def->name);
CORD typeinfo = CORD_asprintf("public const TypeInfo %s = {%zu, %zu, {.tag=CustomInfo, .CustomInfo={",
def->name, type_size(t), type_align(t));

View File

@ -19,7 +19,7 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast)
switch (ast->tag) {
case VarTypeAST: {
const char *name = Match(ast, VarTypeAST)->name;
type_t *t = Table_str_get(env->types, name);
type_t *t = Table_str_get(*env->types, name);
if (t) return t;
code_err(ast, "I don't know a type with the name '%s'", name);
}
@ -340,9 +340,9 @@ type_t *get_type(env_t *env, ast_t *ast)
type_t *fielded_t = get_type(env, access->fielded);
if (fielded_t->tag == TypeInfoType) {
auto info = Match(fielded_t, TypeInfoType);
table_t *namespace = Table_str_get(env->type_namespaces, info->name);
table_t *namespace = Table_str_get(*env->type_namespaces, info->name);
if (!namespace) code_err(access->fielded, "I couldn't find a namespace for this type");
binding_t *b = Table_str_get(namespace, access->field);
binding_t *b = Table_str_get(*namespace, access->field);
if (!b) code_err(ast, "I couldn't find the field '%s' on this type", access->field);
return b->type;
}