diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-08-10 15:15:38 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-08-10 15:15:38 -0400 |
| commit | 8d3d5913129a8ede381462d5ad5e98f9c789e5c8 (patch) | |
| tree | 074e1fd4489710af0810e2a901106a7161467021 | |
| parent | cb6cebf12e2124503f0551bc1bf6b44f68d86746 (diff) | |
Add Sets to the language
| -rw-r--r-- | ast.c | 5 | ||||
| -rw-r--r-- | ast.h | 10 | ||||
| -rw-r--r-- | builtins/table.c | 109 | ||||
| -rw-r--r-- | builtins/table.h | 30 | ||||
| -rw-r--r-- | builtins/types.c | 4 | ||||
| -rw-r--r-- | builtins/types.h | 2 | ||||
| -rw-r--r-- | compile.c | 182 | ||||
| -rw-r--r-- | environment.c | 10 | ||||
| -rw-r--r-- | parse.c | 61 | ||||
| -rw-r--r-- | test/sets.tm | 36 | ||||
| -rw-r--r-- | test/tables.tm | 24 | ||||
| -rw-r--r-- | typecheck.c | 57 | ||||
| -rw-r--r-- | types.c | 21 | ||||
| -rw-r--r-- | types.h | 4 |
14 files changed, 534 insertions, 21 deletions
@@ -114,6 +114,9 @@ CORD ast_to_xml(ast_t *ast) T(Min, "<Min>%r%r%r</Min>", ast_to_xml(data.lhs), ast_to_xml(data.rhs), optional_tagged("key", data.key)) T(Max, "<Max>%r%r%r</Max>", ast_to_xml(data.lhs), ast_to_xml(data.rhs), optional_tagged("key", data.key)) T(Array, "<Array>%r%r</Array>", optional_tagged_type("item-type", data.type), ast_list_to_xml(data.items)) + T(Set, "<Set>%r%r</Set>", + optional_tagged_type("item-type", data.item_type), + ast_list_to_xml(data.items)) T(Table, "<Table>%r%r%r%r%r</Table>", optional_tagged_type("key-type", data.key_type), optional_tagged_type("value-type", data.value_type), ast_list_to_xml(data.entries), optional_tagged("fallback", data.fallback), @@ -169,6 +172,7 @@ CORD type_ast_to_xml(type_ast_t *t) T(PointerTypeAST, "<PointerType is_optional=\"%s\" is_stack=\"%s\" is_readonly=\"%s\">%r</PointerType>", data.is_optional ? "yes" : "no", data.is_stack ? "yes" : "no", data.is_readonly ? "yes" : "no", type_ast_to_xml(data.pointed)) T(ArrayTypeAST, "<ArrayType>%r</ArrayType>", type_ast_to_xml(data.item)) + T(SetTypeAST, "<TableType>%r</TableType>", type_ast_to_xml(data.item)) T(TableTypeAST, "<TableType>%r %r</TableType>", type_ast_to_xml(data.key), type_ast_to_xml(data.value)) T(FunctionTypeAST, "<FunctionType>%r %r</FunctionType>", arg_list_to_xml(data.args), type_ast_to_xml(data.ret)) #undef T @@ -220,6 +224,7 @@ bool type_ast_eq(type_ast_t *x, type_ast_t *y) && type_ast_eq(x_info->pointed, y_info->pointed)); } case ArrayTypeAST: return type_ast_eq(Match(x, ArrayTypeAST)->item, Match(y, ArrayTypeAST)->item); + case SetTypeAST: return type_ast_eq(Match(x, SetTypeAST)->item, Match(y, SetTypeAST)->item); case TableTypeAST: { auto tx = Match(x, TableTypeAST); auto ty = Match(y, TableTypeAST); @@ -57,6 +57,7 @@ typedef enum { VarTypeAST, PointerTypeAST, ArrayTypeAST, + SetTypeAST, TableTypeAST, FunctionTypeAST, } type_ast_e; @@ -89,6 +90,9 @@ struct type_ast_s { type_ast_t *key, *value; } TableTypeAST; struct { + type_ast_t *item; + } SetTypeAST; + struct { arg_ast_t *args; type_ast_t *ret; } FunctionTypeAST; @@ -104,7 +108,7 @@ typedef enum { BinaryOp, UpdateAssign, Length, Not, Negative, HeapAllocate, StackReference, Min, Max, - Array, Table, TableEntry, Comprehension, + Array, Set, Table, TableEntry, Comprehension, FunctionDef, Lambda, FunctionCall, MethodCall, Block, @@ -178,6 +182,10 @@ struct ast_s { ast_list_t *items; } Array; struct { + type_ast_t *item_type; + ast_list_t *items; + } Set; + struct { type_ast_t *key_type, *value_type; ast_t *fallback, *default_value; ast_list_t *entries; diff --git a/builtins/table.c b/builtins/table.c index 16b06f6f..5329ec24 100644 --- a/builtins/table.c +++ b/builtins/table.c @@ -472,8 +472,12 @@ public CORD Table$as_text(const table_t *t, bool colorize, const TypeInfo *type) assert(type->tag == TableInfo); auto table = type->TableInfo; - if (!t) - return CORD_all("{", generic_as_text(NULL, false, table.key), ":", generic_as_text(NULL, false, table.value), "}"); + if (!t) { + if (table.value != &$Void) + return CORD_all("{", generic_as_text(NULL, false, table.key), ":", generic_as_text(NULL, false, table.value), "}"); + else + return CORD_all("{", generic_as_text(NULL, false, table.key), "}"); + } int64_t val_off = value_offset(type); CORD c = "{"; @@ -482,8 +486,8 @@ public CORD Table$as_text(const table_t *t, bool colorize, const TypeInfo *type) c = CORD_cat(c, ", "); void *entry = GET_ENTRY(*t, i); c = CORD_cat(c, generic_as_text(entry, colorize, table.key)); - c = CORD_cat(c, ":"); - c = CORD_cat(c, generic_as_text(entry + val_off, colorize, table.value)); + if (table.value != &$Void) + c = CORD_all(c, ":", generic_as_text(entry + val_off, colorize, table.value)); } if (t->fallback) { @@ -522,6 +526,103 @@ public table_t Table$from_entries(array_t entries, const TypeInfo *type) return t; } +// Overlap is "set intersection" in formal terms +public table_t Table$overlap(table_t a, table_t b, const TypeInfo *type) +{ + // Return a table such that t[k]==a[k] for all k such that a:has(k), b:has(k), and a[k]==b[k] + table_t result = {}; + const size_t offset = value_offset(type); + for (int64_t i = 0; i < Table$length(a); i++) { + void *key = GET_ENTRY(a, i); + void *a_value = key + offset; + void *b_value = Table$get(b, key, type); + if (b_value && generic_equal(a_value, b_value, type->TableInfo.value)) + Table$set(&result, key, a_value, type); + } + + if (a.fallback) { + result.fallback = new(table_t); + *result.fallback = Table$overlap(*a.fallback, b, type); + } + + if (a.default_value && b.default_value && generic_equal(a.default_value, b.default_value, type->TableInfo.value)) + result.default_value = a.default_value; + + return result; +} + +// With is "set union" in formal terms +public table_t Table$with(table_t a, table_t b, const TypeInfo *type) +{ + // return a table such that t[k]==b[k] for all k such that b:has(k), and t[k]==a[k] for all k such that a:has(k) and not b:has(k) + table_t result = {}; + const size_t offset = value_offset(type); + for (int64_t i = 0; i < Table$length(a); i++) { + void *key = GET_ENTRY(a, i); + Table$set(&result, key, key + offset, type); + } + for (int64_t i = 0; i < Table$length(b); i++) { + void *key = GET_ENTRY(b, i); + Table$set(&result, key, key + offset, type); + } + + if (a.fallback && b.fallback) { + result.fallback = new(table_t); + *result.fallback = Table$with(*a.fallback, *b.fallback, type); + } else { + result.fallback = a.fallback ? a.fallback : b.fallback; + } + + // B's default value takes precedence over A's + result.default_value = b.default_value ? b.default_value : a.default_value; + + return result; +} + +// Without is "set difference" in formal terms +public table_t Table$without(table_t a, table_t b, const TypeInfo *type) +{ + // Return a table such that t[k]==a[k] for all k such that not b:has(k) or b[k] != a[k] + table_t result = {}; + const size_t offset = value_offset(type); + for (int64_t i = 0; i < Table$length(a); i++) { + void *key = GET_ENTRY(a, i); + void *a_value = key + offset; + void *b_value = Table$get(b, key, type); + if (!b_value || !generic_equal(a_value, b_value, type->TableInfo.value)) + Table$set(&result, key, a_value, type); + } + + if (a.fallback) { + result.fallback = new(table_t); + *result.fallback = Table$without(*a.fallback, b, type); + } + + if (a.default_value) { + if (!b.default_value || !generic_equal(a.default_value, b.default_value, type->TableInfo.value)) + result.default_value = a.default_value; + } + + return result; +} + +public bool Table$is_subset_of(table_t a, table_t b, bool strict, const TypeInfo *type) +{ + if (a.entries.length > b.entries.length || (strict && a.entries.length == b.entries.length)) + return false; + + for (int64_t i = 0; i < Table$length(a); i++) { + void *found = Table$get_raw(b, GET_ENTRY(a, i), type); + if (!found) return false; + } + return true; +} + +public bool Table$is_superset_of(table_t a, table_t b, bool strict, const TypeInfo *type) +{ + return Table$is_subset_of(b, a, strict, type); +} + public void *Table$str_get(table_t t, const char *key) { void **ret = Table$get(t, &key, &StrToVoidStarTable); diff --git a/builtins/table.h b/builtins/table.h index 36f6d75e..3b1c7f98 100644 --- a/builtins/table.h +++ b/builtins/table.h @@ -21,15 +21,30 @@ table.fallback = fb; \ table.default_value = def; \ table; }) -#define Table_get(table_expr, key_t, val_t, key_expr, info_expr, filename, start, end) ({ \ +#define Set(item_t, item_info, N, ...) ({ \ + item_t ents[N] = {__VA_ARGS__}; \ + table_t set = Table$from_entries((array_t){ \ + .data=memcpy(GC_MALLOC(sizeof(ents)), ents, sizeof(ents)), \ + .length=sizeof(ents)/sizeof(ents[0]), \ + .stride=(void*)&ents[1] - (void*)&ents[0], \ + }, $SetInfo(item_info)); \ + set; }) + +table_t Table$from_entries(array_t entries, const TypeInfo *type); +void *Table$get(table_t t, const void *key, const TypeInfo *type); +#define Table$get_value_or_fail(table_expr, key_t, val_t, key_expr, info_expr, filename, start, end) ({ \ const table_t t = table_expr; key_t k = key_expr; const TypeInfo* info = info_expr; \ const val_t *v = Table$get(t, &k, info); \ if (__builtin_expect(v == NULL, 0)) \ fail_source(filename, start, end, "The key %r is not in this table\n", generic_as_text(&k, no, info->TableInfo.key)); \ *v; }) - -table_t Table$from_entries(array_t entries, const TypeInfo *type); -void *Table$get(table_t t, const void *key, const TypeInfo *type); +#define Table$get_value_or_default(table_expr, key_t, val_t, key_expr, default_val, info_expr) ({ \ + const table_t t = table_expr; const key_t k = key_expr; \ + const val_t *v = Table$get(t, &k, info_expr); \ + v ? *v : default_val; }) +#define Table$has_value(table_expr, key_expr, info_expr) ({ \ + const table_t t = table_expr; __typeof(key_expr) k = key_expr; \ + (Table$get(t, &k, info_expr) != NULL); }) void *Table$get_raw(table_t t, const void *key, const TypeInfo *type); void *Table$entry(table_t t, int64_t n); void *Table$reserve(table_t *t, const void *key, const void *value, const TypeInfo *type); @@ -39,6 +54,13 @@ void Table$set(table_t *t, const void *key, const void *value, const TypeInfo *t #define Table$reserve_value(t, key_expr, type) ({ __typeof(key_expr) k = key_expr; Table$reserve(t, &k, NULL, type); }) void Table$remove(table_t *t, const void *key, const TypeInfo *type); #define Table$remove_value(t, key_expr, type) ({ __typeof(key_expr) k = key_expr; Table$remove(t, &k, type); }) + +table_t Table$overlap(table_t a, table_t b, const TypeInfo *type); +table_t Table$with(table_t a, table_t b, const TypeInfo *type); +table_t Table$without(table_t a, table_t b, const TypeInfo *type); +bool Table$is_subset_of(table_t a, table_t b, bool strict, const TypeInfo *type); +bool Table$is_superset_of(table_t a, table_t b, bool strict, const TypeInfo *type); + void Table$clear(table_t *t); table_t Table$sorted(table_t t, const TypeInfo *type); void Table$mark_copy_on_write(table_t *t); diff --git a/builtins/types.c b/builtins/types.c index 355473e9..4fb2c523 100644 --- a/builtins/types.c +++ b/builtins/types.c @@ -29,8 +29,8 @@ public const TypeInfo $TypeInfo = { .TypeInfoInfo.type_str="TypeInfo", }; -public const TypeInfo $Void = {.size=0, .align=0}; -public const TypeInfo $Abort = {.size=0, .align=0}; +public const TypeInfo $Void = {.size=0, .align=0, .tag=EmptyStruct}; +public const TypeInfo $Abort = {.size=0, .align=0, .tag=EmptyStruct}; public CORD Func$as_text(const void *fn, bool colorize, const TypeInfo *type) { diff --git a/builtins/types.h b/builtins/types.h index 864a55c2..4184e8fa 100644 --- a/builtins/types.h +++ b/builtins/types.h @@ -58,6 +58,8 @@ typedef struct TypeInfo { .tag=PointerInfo, .PointerInfo={.sigil=sigil_expr, .pointed=pointed_info, .is_optional=opt}}) #define $ArrayInfo(item_info) &((TypeInfo){.size=sizeof(array_t), .align=__alignof__(array_t), \ .tag=ArrayInfo, .ArrayInfo.item=item_info}) +#define $SetInfo(item_info) &((TypeInfo){.size=sizeof(table_t), .align=__alignof__(table_t), \ + .tag=TableInfo, .TableInfo.key=item_info, .TableInfo.value=&$Void}) #define $TableInfo(key_expr, value_expr) &((TypeInfo){.size=sizeof(table_t), .align=__alignof__(table_t), \ .tag=TableInfo, .TableInfo.key=key_expr, .TableInfo.value=value_expr}) #define $FunctionInfo(typestr) &((TypeInfo){.size=sizeof(void*), .align=__alignof__(void*), \ @@ -55,6 +55,13 @@ static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed) return true; } + // Set -> Array promotion: + if (needed->tag == ArrayType && actual->tag == SetType + && type_eq(Match(needed, ArrayType)->item_type, Match(actual, SetType)->item_type)) { + *code = CORD_all("(", *code, ").entries"); + return true; + } + return false; } @@ -126,6 +133,7 @@ CORD compile_type(type_t *t) return text->lang ? CORD_all(namespace_prefix(text->env->libname, text->env->namespace->parent), text->lang, "_t") : "Text_t"; } case ArrayType: return "array_t"; + case SetType: return "table_t"; case TableType: return "table_t"; case FunctionType: { auto fn = Match(t, FunctionType); @@ -877,7 +885,7 @@ CORD compile_statement(env_t *env, ast_t *ast) } CORD loop = CORD_all("ARRAY_INCREF(", array_code, ");\n", for_code, "{\n", - compile_type(item_t), " ", value, + compile_declaration(item_t, value), " = *(", compile_type(item_t), "*)(", array_code, ".data + (",index,"-1)*", array_code, ".stride);\n", body, "\n}"); @@ -888,6 +896,28 @@ CORD compile_statement(env_t *env, ast_t *ast) loop = CORD_all("{\narray_t ",array_code," = ", compile(env, array), ";\n", loop, "\n}"); return loop; } + case SetType: { + type_t *item_type = Match(iter_t, SetType)->item_type; + + CORD set = is_idempotent(for_->iter) ? compile(env, for_->iter) : "set"; + CORD loop = CORD_all("ARRAY_INCREF(", set, ".entries);\n" + "for (int64_t i = 0; i < ",set,".entries.length; ++i) {\n"); + + if (for_->vars) { + if (for_->vars->next) + code_err(for_->vars->next->ast, "This is too many variables for this loop"); + CORD item = compile(env, for_->vars->ast); + loop = CORD_all(loop, compile_declaration(item_type, item), " = *(", compile_type(item_type), "*)(", + set,".entries.data + i*", set, ".entries.stride);\n"); + } + loop = CORD_all(loop, body, "\n}"); + if (for_->empty) + loop = CORD_all("if (", set, ".entries.length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty)); + loop = CORD_all(loop, stop, "\nARRAY_DECREF(", set, ".entries);\n"); + if (!is_idempotent(for_->iter)) + loop = CORD_all("{\ntable_t ",set," = ", compile(env, for_->iter), ";\n", loop, "\n}"); + return loop; + } case TableType: { type_t *key_t = Match(iter_t, TableType)->key_type; type_t *value_t = Match(iter_t, TableType)->value_type; @@ -910,14 +940,14 @@ CORD compile_statement(env_t *env, ast_t *ast) } if (key) { - loop = CORD_all(loop, compile_type(key_t), " ", key, " = *(", compile_type(key_t), "*)(", + loop = CORD_all(loop, compile_declaration(key_t, key), " = *(", compile_type(key_t), "*)(", table,".entries.data + i*", table, ".entries.stride);\n"); } if (value) { size_t value_offset = type_size(key_t); if (type_align(value_t) > 1 && value_offset % type_align(value_t)) value_offset += type_align(value_t) - (value_offset % type_align(value_t)); // padding - loop = CORD_all(loop, compile_type(value_t), " ", value, " = *(", compile_type(value_t), "*)(", + loop = CORD_all(loop, compile_declaration(value_t, value), " = *(", compile_type(value_t), "*)(", table,".entries.data + i*", table, ".entries.stride + ", heap_strf("%zu", value_offset), ");\n"); } loop = CORD_all(loop, body, "\n}"); @@ -1081,6 +1111,7 @@ CORD expr_as_text(env_t *env, CORD expr, type_t *t, CORD color) return CORD_asprintf("Text$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t)); } case ArrayType: return CORD_asprintf("Array$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t)); + case SetType: return CORD_asprintf("Table$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t)); case TableType: return CORD_asprintf("Table$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t)); case FunctionType: case ClosureType: return CORD_asprintf("Func$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t)); case PointerType: return CORD_asprintf("Pointer$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t)); @@ -1773,6 +1804,58 @@ CORD compile(env_t *env, ast_t *ast) } } + case Set: { + auto set = Match(ast, Set); + if (!set->items) + return "((table_t){})"; + + type_t *set_type = get_type(env, ast); + type_t *item_type = Match(set_type, SetType)->item_type; + + for (ast_list_t *item = set->items; item; item = item->next) { + if (item->ast->tag == Comprehension) + goto set_comprehension; + } + + { // No comprehension: + CORD code = CORD_all("Set(", + compile_type(item_type), ", ", + compile_type_info(env, item_type)); + + size_t n = 0; + for (ast_list_t *item = set->items; item; item = item->next) + ++n; + CORD_appendf(&code, ", %zu", n); + + for (ast_list_t *item = set->items; item; item = item->next) { + code = CORD_all(code, ",\n\t", compile(env, item->ast)); + } + return CORD_cat(code, ")"); + } + + set_comprehension: + { + static int64_t comp_num = 1; + env_t *scope = fresh_scope(env); + scope->comprehension_var = heap_strf("set$%ld", comp_num++); + + CORD code = CORD_all("({ table_t ", scope->comprehension_var, " = {};"); + set_binding(scope, scope->comprehension_var, new(binding_t, .type=set_type, .code=scope->comprehension_var)); + for (ast_list_t *item = set->items; item; item = item->next) { + if (item->ast->tag == Comprehension) { + code = CORD_all(code, "\n", compile_statement(scope, item->ast)); + } else { + CORD add_item = compile_statement( + scope, WrapAST(item->ast, MethodCall, .name="add", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)), + .args=new(arg_ast_t, .value=item->ast))); + code = CORD_all(code, "\n", add_item); + } + } + code = CORD_all(code, " ", scope->comprehension_var, "; })"); + return code; + } + + } case Comprehension: { ast_t *base = Match(ast, Comprehension)->expr; while (base->tag == Comprehension) @@ -1961,12 +2044,81 @@ CORD compile(env_t *env, ast_t *ast) return CORD_all("Array$reversed(", self, ", ", padded_item_size, ")"); } else code_err(ast, "There is no '%s' method for arrays", call->name); } + case SetType: { + auto set = Match(self_value_t, SetType); + if (streq(call->name, "has")) { + CORD self = compile_to_pointer_depth(env, call->self, 0, false); + arg_t *arg_spec = new(arg_t, .name="key", .type=set->item_type); + return CORD_all("Table$has_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", + compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "add")) { + CORD self = compile_to_pointer_depth(env, call->self, 1, false); + arg_t *arg_spec = new(arg_t, .name="item", .type=set->item_type); + return CORD_all("Table$set_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", NULL, ", + compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "add_all")) { + arg_t *arg_spec = new(arg_t, .name="items", .type=Type(ArrayType, .item_type=Match(self_value_t, SetType)->item_type)); + return CORD_all("({ table_t *set = ", compile_to_pointer_depth(env, call->self, 1, false), "; ", + "array_t to_add = ", compile_arguments(env, ast, arg_spec, call->args), "; ", + "for (int64_t i = 0; i < to_add.length; i++)\n" + "Table$set(set, to_add.data + i*to_add.stride, NULL, ", compile_type_info(env, self_value_t), ");\n", + "(void)0; })"); + } else if (streq(call->name, "remove")) { + CORD self = compile_to_pointer_depth(env, call->self, 1, false); + arg_t *arg_spec = new(arg_t, .name="item", .type=set->item_type); + return CORD_all("Table$remove_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", + compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "remove_all")) { + arg_t *arg_spec = new(arg_t, .name="items", .type=Type(ArrayType, .item_type=Match(self_value_t, SetType)->item_type)); + return CORD_all("({ table_t *set = ", compile_to_pointer_depth(env, call->self, 1, false), "; ", + "array_t to_add = ", compile_arguments(env, ast, arg_spec, call->args), "; ", + "for (int64_t i = 0; i < to_add.length; i++)\n" + "Table$remove(set, to_add.data + i*to_add.stride, ", compile_type_info(env, self_value_t), ");\n", + "(void)0; })"); + } else if (streq(call->name, "clear")) { + CORD self = compile_to_pointer_depth(env, call->self, 1, false); + (void)compile_arguments(env, ast, NULL, call->args); + return CORD_all("Table$clear(", self, ")"); + } else if (streq(call->name, "with")) { + CORD self = compile_to_pointer_depth(env, call->self, 0, false); + arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t); + return CORD_all("Table$with(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), + ", ", compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "overlap")) { + CORD self = compile_to_pointer_depth(env, call->self, 0, false); + arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t); + return CORD_all("Table$overlap(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), + ", ", compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "without")) { + CORD self = compile_to_pointer_depth(env, call->self, 0, false); + arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t); + return CORD_all("Table$without(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), + ", ", compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "is_subset_of")) { + CORD self = compile_to_pointer_depth(env, call->self, 0, false); + arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t, + .next=new(arg_t, .name="strict", .type=Type(BoolType), .default_val=FakeAST(Bool, false))); + return CORD_all("Table$is_subset_of(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), + ", ", compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "is_superset_of")) { + CORD self = compile_to_pointer_depth(env, call->self, 0, false); + arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t, + .next=new(arg_t, .name="strict", .type=Type(BoolType), .default_val=FakeAST(Bool, false))); + return CORD_all("Table$is_superset_of(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), + ", ", compile_type_info(env, self_value_t), ")"); + } else code_err(ast, "There is no '%s' method for tables", call->name); + } case TableType: { auto table = Match(self_value_t, TableType); if (streq(call->name, "get")) { CORD self = compile_to_pointer_depth(env, call->self, 0, false); - arg_t *arg_spec = new(arg_t, .name="key", .type=Type(PointerType, .pointed=table->key_type, .is_stack=true, .is_readonly=true)); - return CORD_all("Table$get(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", + arg_t *arg_spec = new(arg_t, .name="key", .type=table->key_type, .next=new(arg_t, .name="default", .type=table->value_type)); + return CORD_all("Table$get_value_or_default(", self, ", ", compile_type(table->key_type), ", ", compile_type(table->value_type), ", ", + compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "has")) { + CORD self = compile_to_pointer_depth(env, call->self, 0, false); + arg_t *arg_spec = new(arg_t, .name="key", .type=table->key_type); + return CORD_all("Table$has_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")"); } else if (streq(call->name, "set")) { CORD self = compile_to_pointer_depth(env, call->self, 1, false); @@ -2147,6 +2299,16 @@ CORD compile(env_t *env, ast_t *ast) } code_err(ast, "The field '%s' is not a valid field name of %T", f->field, value_t); } + case SetType: { + if (streq(f->field, "items")) { + return CORD_all("({ table_t *t = ", compile_to_pointer_depth(env, f->fielded, 1, false), ";\n" + "ARRAY_INCREF(t->entries);\n" + "t->entries; })"); + } else if (streq(f->field, "fallback")) { + return CORD_all("(", compile_to_pointer_depth(env, f->fielded, 0, false), ").fallback"); + } + code_err(ast, "There is no '%s' field on sets", f->field); + } case TableType: { if (streq(f->field, "keys")) { return CORD_all("({ table_t *t = ", compile_to_pointer_depth(env, f->fielded, 1, false), ";\n" @@ -2219,7 +2381,7 @@ CORD compile(env_t *env, ast_t *ast) if (!promote(env, &key, index_t, key_t)) code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t); file_t *f = indexing->index->file; - return CORD_all("Table_get(", table, ", ", compile_type(key_t), ", ", compile_type(value_t), ", ", + return CORD_all("Table$get_value_or_fail(", table, ", ", compile_type(key_t), ", ", compile_type(value_t), ", ", key, ", ", compile_type_info(env, container_t), ", ", Text$quoted(f->filename, false), ", ", CORD_asprintf("%ld", (int64_t)(indexing->index->start - f->text)), ", ", CORD_asprintf("%ld", (int64_t)(indexing->index->end - f->text)), @@ -2309,12 +2471,16 @@ CORD compile_type_info(env_t *env, type_t *t) } case ArrayType: { type_t *item_t = Match(t, ArrayType)->item_type; - return CORD_asprintf("$ArrayInfo(%r)", compile_type_info(env, item_t)); + return CORD_all("$ArrayInfo(", compile_type_info(env, item_t), ")"); + } + case SetType: { + type_t *item_type = Match(t, SetType)->item_type; + return CORD_all("$SetInfo(", compile_type_info(env, item_type), ")"); } case TableType: { type_t *key_type = Match(t, TableType)->key_type; type_t *value_type = Match(t, TableType)->value_type; - return CORD_asprintf("$TableInfo(%r, %r)", compile_type_info(env, key_type), compile_type_info(env, value_type)); + return CORD_all("$TableInfo(", compile_type_info(env, key_type), ", ", compile_type_info(env, value_type), ")"); } case PointerType: { auto ptr = Match(t, PointerType); diff --git a/environment.c b/environment.c index 69ad9b66..84f6880e 100644 --- a/environment.c +++ b/environment.c @@ -333,6 +333,16 @@ env_t *for_scope(env_t *env, ast_t *ast) } return scope; } + case SetType: { + if (for_->vars) { + if (for_->vars->next) + code_err(for_->vars->next->ast, "This is too many variables for this loop"); + type_t *item_type = Match(iter_t, SetType)->item_type; + const char *name = Match(for_->vars->ast, Var)->name; + set_binding(scope, name, new(binding_t, .type=item_type, .code=CORD_cat("$", name))); + } + return scope; + } case TableType: { const char *vars[2] = {}; int64_t num_vars = 0; @@ -513,6 +513,19 @@ type_ast_t *parse_table_type(parse_ctx_t *ctx, const char *pos) { return NewTypeAST(ctx->file, start, pos, TableTypeAST, .key=key_type, .value=value_type); } +type_ast_t *parse_set_type(parse_ctx_t *ctx, const char *pos) { + const char *start = pos; + if (!match(&pos, "{")) return NULL; + whitespace(&pos); + type_ast_t *item_type = parse_type(ctx, pos); + if (!item_type) return NULL; + pos = item_type->end; + whitespace(&pos); + if (match(&pos, ":")) return NULL; + expect_closing(ctx, &pos, "}", "I wasn't able to parse the rest of this set type"); + return NewTypeAST(ctx->file, start, pos, SetTypeAST, .item=item_type); +} + type_ast_t *parse_func_type(parse_ctx_t *ctx, const char *pos) { const char *start = pos; if (!match_word(&pos, "func")) return NULL; @@ -576,6 +589,7 @@ type_ast_t *parse_type(parse_ctx_t *ctx, const char *pos) { || (type=parse_pointer_type(ctx, pos)) || (type=parse_array_type(ctx, pos)) || (type=parse_table_type(ctx, pos)) + || (type=parse_set_type(ctx, pos)) || (type=parse_type_name(ctx, pos)) || (type=parse_func_type(ctx, pos)) ); @@ -698,7 +712,7 @@ PARSER(parse_table) { key_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a key type for this table"); whitespace(&pos); if (!match(&pos, ":")) - parser_err(ctx, pos, pos, "I expected an ':' for this table type"); + return NULL; value_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a value type for this table"); whitespace(&pos); } @@ -760,6 +774,50 @@ PARSER(parse_table) { return NewAST(ctx->file, start, pos, Table, .key_type=key_type, .value_type=value_type, .entries=entries, .fallback=fallback, .default_value=default_val); } +PARSER(parse_set) { + const char *start = pos; + if (!match(&pos, "{")) return NULL; + + whitespace(&pos); + + ast_list_t *items = NULL; + type_ast_t *item_type = NULL; + if (match(&pos, ":")) { + whitespace(&pos); + item_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a key type for this set"); + whitespace(&pos); + if (match(&pos, ":")) + return NULL; + whitespace(&pos); + } + + for (;;) { + ast_t *item = optional(ctx, &pos, parse_extended_expr); + if (!item) break; + whitespace(&pos); + if (match(&pos, ":")) return NULL; + ast_t *suffixed = parse_comprehension_suffix(ctx, item); + while (suffixed) { + item = suffixed; + pos = suffixed->end; + suffixed = parse_comprehension_suffix(ctx, item); + } + items = new(ast_list_t, .ast=item, .next=items); + if (!match_separator(&pos)) + break; + } + + REVERSE_LIST(items); + + if (!item_type && !items) + return NULL; + + whitespace(&pos); + expect_closing(ctx, &pos, "}", "I wasn't able to parse the rest of this set"); + + return NewAST(ctx->file, start, pos, Set, .item_type=item_type, .items=items); +} + ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs) { if (!lhs) return NULL; const char *pos = lhs->end; @@ -1287,6 +1345,7 @@ PARSER(parse_term_no_suffix) { || (term=parse_lambda(ctx, pos)) || (term=parse_parens(ctx, pos)) || (term=parse_table(ctx, pos)) + || (term=parse_set(ctx, pos)) || (term=parse_var(ctx, pos)) || (term=parse_array(ctx, pos)) || (term=parse_reduction(ctx, pos)) diff --git a/test/sets.tm b/test/sets.tm new file mode 100644 index 00000000..bfec068e --- /dev/null +++ b/test/sets.tm @@ -0,0 +1,36 @@ + +func main(): + >> t1 := {10, 20, 30, 10} + = {10, 20, 30} + >> t1:has(10) + = yes + >> t1:has(-999) + = no + + >> t2 := {30, 40} + + >> t1:with(t2) + >> {10, 20, 30, 40} + + >> t1:without(t2) + >> {10, 20} + + >> t1:overlap(t2) + >> {30} + + + >> {1,2}:is_subset_of({2,3}) + = no + >> {1,2}:is_subset_of({1,2,3}) + = yes + >> {1,2}:is_subset_of({1,2}) + = yes + >> {1,2}:is_subset_of({1,2}, strict=yes) + = no + + >> t1:add_all(t2) + >> t1 + = {10, 20, 30, 40} + >> t1:remove_all(t2) + >> t1 + = {10, 20} diff --git a/test/tables.tm b/test/tables.tm index 132e4301..6dbc1b78 100644 --- a/test/tables.tm +++ b/test/tables.tm @@ -61,3 +61,27 @@ func main(): >> t3:remove(3) >> t3 = {1:10, 2:20} + + do: + >> plain := {1:10, 2:20, 3:30} + >> plain:get(2, -999) + = 20 + >> plain:get(456, -999) + = -999 + >> plain:has(2) + = yes + >> plain:has(456) + = no + + >> fallback := {4:40; fallback=plain} + >> fallback:has(1) + = yes + >> fallback:get(1, -999) + = 10 + + >> default := {5:50; default=0} + >> default:has(28273) + = yes + >> default:get(28273, -999) + = 0 + diff --git a/typecheck.c b/typecheck.c index 10c5820f..ce8b07ed 100644 --- a/typecheck.c +++ b/typecheck.c @@ -54,6 +54,17 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast) padded_type_size(item_t), ARRAY_MAX_STRIDE); return Type(ArrayType, .item_type=item_t); } + case SetTypeAST: { + type_ast_t *item_type = Match(ast, SetTypeAST)->item; + type_t *item_t = parse_type_ast(env, item_type); + if (!item_t) code_err(item_type, "I can't figure out what this type is."); + if (has_stack_memory(item_t)) + code_err(item_type, "Sets can't have stack references because the array may outlive the stack frame."); + if (padded_type_size(item_t) > ARRAY_MAX_STRIDE) + code_err(ast, "This set holds items that take up %ld bytes, but the maximum supported size is %ld bytes. Consider using an set of pointers instead.", + padded_type_size(item_t), ARRAY_MAX_STRIDE); + return Type(SetType, .item_type=item_t); + } case TableTypeAST: { type_ast_t *key_type_ast = Match(ast, TableTypeAST)->key; type_t *key_type = parse_type_ast(env, key_type_ast); @@ -531,6 +542,35 @@ type_t *get_type(env_t *env, ast_t *ast) code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in."); return Type(ArrayType, .item_type=item_type); } + case Set: { + auto set = Match(ast, Set); + type_t *item_type = NULL; + if (set->item_type) { + item_type = parse_type_ast(env, set->item_type); + } else { + for (ast_list_t *item = set->items; item; item = item->next) { + ast_t *item_ast = item->ast; + env_t *scope = env; + while (item_ast->tag == Comprehension) { + auto comp = Match(item_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + item_ast = comp->expr; + } + + type_t *this_item_type = get_type(scope, item_ast); + type_t *item_merged = type_or_type(item_type, this_item_type); + if (!item_merged) + code_err(item_ast, + "This set item has type %T, which is different from earlier set items which have type %T", + this_item_type, item_type); + item_type = item_merged; + } + } + if (has_stack_memory(item_type)) + code_err(ast, "Sets cannot hold stack references because the set may outlive the reference's stack frame."); + return Type(SetType, .item_type=item_type); + } case Table: { auto table = Match(ast, Table); type_t *key_type = NULL, *value_type = NULL; @@ -682,9 +722,24 @@ type_t *get_type(env_t *env, ast_t *ast) else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type; else code_err(ast, "There is no '%s' method for arrays", call->name); } + case SetType: { + if (streq(call->name, "add")) return Type(VoidType); + else if (streq(call->name, "has")) return Type(BoolType); + else if (streq(call->name, "add_all")) return Type(VoidType); + else if (streq(call->name, "remove")) return Type(VoidType); + else if (streq(call->name, "remove_all")) return Type(VoidType); + else if (streq(call->name, "clear")) return Type(VoidType); + else if (streq(call->name, "with")) return self_value_t; + else if (streq(call->name, "overlap")) return self_value_t; + else if (streq(call->name, "without")) return self_value_t; + else if (streq(call->name, "is_subset_of")) return Type(BoolType); + else if (streq(call->name, "is_superset_of")) return Type(BoolType); + else code_err(ast, "There is no '%s' method for sets", call->name); + } case TableType: { auto table = Match(self_value_t, TableType); - if (streq(call->name, "get")) return Type(PointerType, .pointed=table->value_type, .is_readonly=true, .is_optional=true); + if (streq(call->name, "get")) return table->value_type; + else if (streq(call->name, "has")) return Type(BoolType); else if (streq(call->name, "set")) return Type(VoidType); else if (streq(call->name, "remove")) return Type(VoidType); else if (streq(call->name, "clear")) return Type(VoidType); @@ -29,6 +29,10 @@ CORD type_to_cord(type_t *t) { auto table = Match(t, TableType); return CORD_asprintf("{%r:%r}", type_to_cord(table->key_type), type_to_cord(table->value_type)); } + case SetType: { + auto set = Match(t, SetType); + return CORD_asprintf("{%r}", type_to_cord(set->item_type)); + } case ClosureType: { return type_to_cord(Match(t, ClosureType)->fn); } @@ -202,6 +206,7 @@ bool has_heap_memory(type_t *t) switch (t->tag) { case ArrayType: return true; case TableType: return true; + case SetType: return true; case PointerType: return true; case StructType: { for (arg_t *field = Match(t, StructType)->fields; field; field = field->next) { @@ -294,6 +299,11 @@ bool can_promote(type_t *actual, type_t *needed) && can_promote(actual_ret, needed_ret))); } + // Set -> Array promotion + if (needed->tag == ArrayType && actual->tag == SetType + && type_eq(Match(needed, ArrayType)->item_type, Match(actual, SetType)->item_type)) + return true; + return false; } @@ -329,6 +339,7 @@ static bool _can_have_cycles(type_t *t, table_t *seen) auto table = Match(t, TableType); return _can_have_cycles(table->key_type, seen) || _can_have_cycles(table->value_type, seen); } + case SetType: return _can_have_cycles(Match(t, SetType)->item_type, seen); case StructType: { for (arg_t *field = Match(t, StructType)->fields; field; field = field->next) { if (_can_have_cycles(field->type, seen)) @@ -363,6 +374,7 @@ type_t *replace_type(type_t *t, type_t *target, type_t *replacement) #define REPLACED_MEMBER(t, tag, member) ({ t = memcpy(GC_MALLOC(sizeof(type_t)), (t), sizeof(type_t)); Match((struct type_s*)(t), tag)->member = replace_type(Match((t), tag)->member, target, replacement); t; }) switch (t->tag) { case ArrayType: return REPLACED_MEMBER(t, ArrayType, item_type); + case SetType: return REPLACED_MEMBER(t, SetType, item_type); case TableType: { t = REPLACED_MEMBER(t, TableType, key_type); t = REPLACED_MEMBER(t, TableType, value_type); @@ -407,6 +419,7 @@ size_t type_size(type_t *t) case NumType: return Match(t, NumType)->bits/8; case TextType: return sizeof(CORD); case ArrayType: return sizeof(array_t); + case SetType: return sizeof(table_t); case TableType: return sizeof(table_t); case FunctionType: return sizeof(void*); case ClosureType: return sizeof(struct {void *fn, *userdata;}); @@ -458,6 +471,7 @@ size_t type_align(type_t *t) case IntType: return Match(t, IntType)->bits/8; case NumType: return Match(t, NumType)->bits/8; case TextType: return __alignof__(CORD); + case SetType: return __alignof__(table_t); case ArrayType: return __alignof__(array_t); case TableType: return __alignof__(table_t); case FunctionType: return __alignof__(void*); @@ -509,6 +523,13 @@ type_t *get_field_type(type_t *t, const char *field_name) } return NULL; } + case SetType: { + if (streq(field_name, "items")) + return Type(ArrayType, .item_type=Match(t, SetType)->item_type); + else if (streq(field_name, "fallback")) + return Type(PointerType, .pointed=t, .is_readonly=true, .is_optional=true); + return NULL; + } case TableType: { if (streq(field_name, "keys")) return Type(ArrayType, Match(t, TableType)->key_type); @@ -48,6 +48,7 @@ struct type_s { CStringType, TextType, ArrayType, + SetType, TableType, FunctionType, ClosureType, @@ -79,6 +80,9 @@ struct type_s { type_t *item_type; } ArrayType; struct { + type_t *item_type; + } SetType; + struct { type_t *key_type, *value_type; } TableType; struct { |
