aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Bruce Hill <bruce@bruce-hill.com> 2024-08-10 15:15:38 -0400
committer Bruce Hill <bruce@bruce-hill.com> 2024-08-10 15:15:38 -0400
commit 8d3d5913129a8ede381462d5ad5e98f9c789e5c8 (patch)
tree 074e1fd4489710af0810e2a901106a7161467021
parent cb6cebf12e2124503f0551bc1bf6b44f68d86746 (diff)
Add Sets to the language
-rw-r--r-- ast.c 5
-rw-r--r-- ast.h 10
-rw-r--r-- builtins/table.c 109
-rw-r--r-- builtins/table.h 30
-rw-r--r-- builtins/types.c 4
-rw-r--r-- builtins/types.h 2
-rw-r--r-- compile.c 182
-rw-r--r-- environment.c 10
-rw-r--r-- parse.c 61
-rw-r--r-- test/sets.tm 36
-rw-r--r-- test/tables.tm 24
-rw-r--r-- typecheck.c 57
-rw-r--r-- types.c 21
-rw-r--r-- types.h 4
14 files changed, 534 insertions, 21 deletions
diff --git a/ast.c b/ast.c
index 73556e49..8f804c5f 100644
--- a/ast.c
+++ b/ast.c
@@ -114,6 +114,9 @@ CORD ast_to_xml(ast_t *ast)
T(Min, "<Min>%r%r%r</Min>", ast_to_xml(data.lhs), ast_to_xml(data.rhs), optional_tagged("key", data.key))
T(Max, "<Max>%r%r%r</Max>", ast_to_xml(data.lhs), ast_to_xml(data.rhs), optional_tagged("key", data.key))
T(Array, "<Array>%r%r</Array>", optional_tagged_type("item-type", data.type), ast_list_to_xml(data.items))
+ T(Set, "<Set>%r%r</Set>",
+ optional_tagged_type("item-type", data.item_type),
+ ast_list_to_xml(data.items))
T(Table, "<Table>%r%r%r%r%r</Table>",
optional_tagged_type("key-type", data.key_type), optional_tagged_type("value-type", data.value_type),
ast_list_to_xml(data.entries), optional_tagged("fallback", data.fallback),
@@ -169,6 +172,7 @@ CORD type_ast_to_xml(type_ast_t *t)
T(PointerTypeAST, "<PointerType is_optional=\"%s\" is_stack=\"%s\" is_readonly=\"%s\">%r</PointerType>",
data.is_optional ? "yes" : "no", data.is_stack ? "yes" : "no", data.is_readonly ? "yes" : "no", type_ast_to_xml(data.pointed))
T(ArrayTypeAST, "<ArrayType>%r</ArrayType>", type_ast_to_xml(data.item))
+ // Set types get their own tag so an XML dump can tell them apart from table types
+ T(SetTypeAST, "<SetType>%r</SetType>", type_ast_to_xml(data.item))
T(TableTypeAST, "<TableType>%r %r</TableType>", type_ast_to_xml(data.key), type_ast_to_xml(data.value))
T(FunctionTypeAST, "<FunctionType>%r %r</FunctionType>", arg_list_to_xml(data.args), type_ast_to_xml(data.ret))
#undef T
@@ -220,6 +224,7 @@ bool type_ast_eq(type_ast_t *x, type_ast_t *y)
&& type_ast_eq(x_info->pointed, y_info->pointed));
}
case ArrayTypeAST: return type_ast_eq(Match(x, ArrayTypeAST)->item, Match(y, ArrayTypeAST)->item);
+ case SetTypeAST: return type_ast_eq(Match(x, SetTypeAST)->item, Match(y, SetTypeAST)->item);
case TableTypeAST: {
auto tx = Match(x, TableTypeAST);
auto ty = Match(y, TableTypeAST);
diff --git a/ast.h b/ast.h
index 976a47dc..c049e17a 100644
--- a/ast.h
+++ b/ast.h
@@ -57,6 +57,7 @@ typedef enum {
VarTypeAST,
PointerTypeAST,
ArrayTypeAST,
+ SetTypeAST,
TableTypeAST,
FunctionTypeAST,
} type_ast_e;
@@ -89,6 +90,9 @@ struct type_ast_s {
type_ast_t *key, *value;
} TableTypeAST;
struct {
+ type_ast_t *item;
+ } SetTypeAST;
+ struct {
arg_ast_t *args;
type_ast_t *ret;
} FunctionTypeAST;
@@ -104,7 +108,7 @@ typedef enum {
BinaryOp, UpdateAssign,
Length, Not, Negative, HeapAllocate, StackReference,
Min, Max,
- Array, Table, TableEntry, Comprehension,
+ Array, Set, Table, TableEntry, Comprehension,
FunctionDef, Lambda,
FunctionCall, MethodCall,
Block,
@@ -178,6 +182,10 @@ struct ast_s {
ast_list_t *items;
} Array;
struct {
+ type_ast_t *item_type;
+ ast_list_t *items;
+ } Set;
+ struct {
type_ast_t *key_type, *value_type;
ast_t *fallback, *default_value;
ast_list_t *entries;
diff --git a/builtins/table.c b/builtins/table.c
index 16b06f6f..5329ec24 100644
--- a/builtins/table.c
+++ b/builtins/table.c
@@ -472,8 +472,12 @@ public CORD Table$as_text(const table_t *t, bool colorize, const TypeInfo *type)
assert(type->tag == TableInfo);
auto table = type->TableInfo;
- if (!t)
- return CORD_all("{", generic_as_text(NULL, false, table.key), ":", generic_as_text(NULL, false, table.value), "}");
+ if (!t) {
+ if (table.value != &$Void)
+ return CORD_all("{", generic_as_text(NULL, false, table.key), ":", generic_as_text(NULL, false, table.value), "}");
+ else
+ return CORD_all("{", generic_as_text(NULL, false, table.key), "}");
+ }
int64_t val_off = value_offset(type);
CORD c = "{";
@@ -482,8 +486,8 @@ public CORD Table$as_text(const table_t *t, bool colorize, const TypeInfo *type)
c = CORD_cat(c, ", ");
void *entry = GET_ENTRY(*t, i);
c = CORD_cat(c, generic_as_text(entry, colorize, table.key));
- c = CORD_cat(c, ":");
- c = CORD_cat(c, generic_as_text(entry + val_off, colorize, table.value));
+ if (table.value != &$Void)
+ c = CORD_all(c, ":", generic_as_text(entry + val_off, colorize, table.value));
}
if (t->fallback) {
@@ -522,6 +526,103 @@ public table_t Table$from_entries(array_t entries, const TypeInfo *type)
return t;
}
+// "Overlap" is the table counterpart of set intersection.
+public table_t Table$overlap(table_t a, table_t b, const TypeInfo *type)
+{
+    // Keep each entry of `a` whose key can also be looked up in `b` with an equal value,
+    // i.e. t[k]==a[k] for all k such that a:has(k), b:has(k), and a[k]==b[k]
+    table_t shared = {};
+    const size_t val_offset = value_offset(type);
+    int64_t idx = 0;
+    while (idx < Table$length(a)) {
+        void *entry = GET_ENTRY(a, idx);
+        void *mine = entry + val_offset;
+        void *theirs = Table$get(b, entry, type);
+        if (theirs != NULL && generic_equal(mine, theirs, type->TableInfo.value))
+            Table$set(&shared, entry, mine, type);
+        idx++;
+    }
+
+    // The fallback of `a` (if any) is overlapped with `b` recursively
+    if (a.fallback != NULL) {
+        shared.fallback = new(table_t);
+        *shared.fallback = Table$overlap(*a.fallback, b, type);
+    }
+
+    // A default value is kept only when both tables have one and the two compare equal
+    bool same_default = a.default_value != NULL && b.default_value != NULL
+        && generic_equal(a.default_value, b.default_value, type->TableInfo.value);
+    if (same_default)
+        shared.default_value = a.default_value;
+
+    return shared;
+}
+
+// "With" is the table counterpart of set union.
+public table_t Table$with(table_t a, table_t b, const TypeInfo *type)
+{
+    // Copy every entry of `a`, then every entry of `b`. Because `b` is written second,
+    // t[k]==b[k] for all k in `b`, and t[k]==a[k] for keys that only `a` has.
+    table_t merged = {};
+    const size_t val_offset = value_offset(type);
+    for (int64_t n = 0; n < Table$length(a); n++) {
+        void *entry = GET_ENTRY(a, n);
+        Table$set(&merged, entry, entry + val_offset, type);
+    }
+    for (int64_t n = 0; n < Table$length(b); n++) {
+        void *entry = GET_ENTRY(b, n);
+        Table$set(&merged, entry, entry + val_offset, type);
+    }
+
+    // Fallbacks: reuse the only one present, or combine them when both tables have one
+    if (!a.fallback) {
+        merged.fallback = b.fallback;
+    } else if (!b.fallback) {
+        merged.fallback = a.fallback;
+    } else {
+        merged.fallback = new(table_t);
+        *merged.fallback = Table$with(*a.fallback, *b.fallback, type);
+    }
+
+    // B's default value takes precedence over A's
+    if (b.default_value)
+        merged.default_value = b.default_value;
+    else
+        merged.default_value = a.default_value;
+
+    return merged;
+}
+
+// "Without" is the table counterpart of set difference.
+public table_t Table$without(table_t a, table_t b, const TypeInfo *type)
+{
+    // Keep each entry of `a` unless `b` yields an equal value for the same key,
+    // i.e. t[k]==a[k] for all k such that not b:has(k) or b[k] != a[k]
+    table_t remaining = {};
+    const size_t val_offset = value_offset(type);
+    int64_t idx = 0;
+    while (idx < Table$length(a)) {
+        void *entry = GET_ENTRY(a, idx);
+        void *mine = entry + val_offset;
+        void *theirs = Table$get(b, entry, type);
+        bool cancelled = theirs && generic_equal(mine, theirs, type->TableInfo.value);
+        if (!cancelled)
+            Table$set(&remaining, entry, mine, type);
+        idx++;
+    }
+
+    // The fallback of `a` (if any) has `b` subtracted from it recursively
+    if (a.fallback != NULL) {
+        remaining.fallback = new(table_t);
+        *remaining.fallback = Table$without(*a.fallback, b, type);
+    }
+
+    // A's default value is dropped only when B has an equal default value
+    if (a.default_value
+        && !(b.default_value && generic_equal(a.default_value, b.default_value, type->TableInfo.value)))
+        remaining.default_value = a.default_value;
+
+    return remaining;
+}
+
+// Report whether every entry key of `a` can be found in `b` via Table$get_raw.
+// With `strict` set, tables with the same number of entries are never subsets of each other.
+public bool Table$is_subset_of(table_t a, table_t b, bool strict, const TypeInfo *type)
+{
+    // Size checks first: a bigger table can't be a subset, and a strict subset must be smaller
+    if (a.entries.length > b.entries.length)
+        return false;
+    if (strict && a.entries.length == b.entries.length)
+        return false;
+
+    int64_t idx = 0;
+    while (idx < Table$length(a)) {
+        if (Table$get_raw(b, GET_ENTRY(a, idx), type) == NULL)
+            return false;
+        idx++;
+    }
+    return true;
+}
+
+// Superset test, expressed as the subset test with the two tables swapped.
+public bool Table$is_superset_of(table_t a, table_t b, bool strict, const TypeInfo *type)
+{
+    bool b_within_a = Table$is_subset_of(b, a, strict, type);
+    return b_within_a;
+}
+
public void *Table$str_get(table_t t, const char *key)
{
void **ret = Table$get(t, &key, &StrToVoidStarTable);
diff --git a/builtins/table.h b/builtins/table.h
index 36f6d75e..3b1c7f98 100644
--- a/builtins/table.h
+++ b/builtins/table.h
@@ -21,15 +21,30 @@
table.fallback = fb; \
table.default_value = def; \
table; })
-#define Table_get(table_expr, key_t, val_t, key_expr, info_expr, filename, start, end) ({ \
+// Build a set (a table_t) from N literal items of C type item_t.
+// The items are first placed in a local array, copied into GC_MALLOC'd memory, and then
+// handed to Table$from_entries as an array_t whose stride is the distance between two
+// consecutive items. The type info passed along is $SetInfo(item_info).
+#define Set(item_t, item_info, N, ...) ({ \
+ item_t ents[N] = {__VA_ARGS__}; \
+ table_t set = Table$from_entries((array_t){ \
+ .data=memcpy(GC_MALLOC(sizeof(ents)), ents, sizeof(ents)), \
+ .length=sizeof(ents)/sizeof(ents[0]), \
+ .stride=(void*)&ents[1] - (void*)&ents[0], \
+ }, $SetInfo(item_info)); \
+ set; })
+
+table_t Table$from_entries(array_t entries, const TypeInfo *type);
+void *Table$get(table_t t, const void *key, const TypeInfo *type);
+// Look up key_expr in table_expr with Table$get and evaluate to a copy of the stored value.
+// When the lookup returns NULL, fail_source is called with the given source span and a
+// message naming the missing key (presumably fail_source does not return -- confirm).
+#define Table$get_value_or_fail(table_expr, key_t, val_t, key_expr, info_expr, filename, start, end) ({ \
const table_t t = table_expr; key_t k = key_expr; const TypeInfo* info = info_expr; \
const val_t *v = Table$get(t, &k, info); \
if (__builtin_expect(v == NULL, 0)) \
fail_source(filename, start, end, "The key %r is not in this table\n", generic_as_text(&k, no, info->TableInfo.key)); \
*v; })
-
-table_t Table$from_entries(array_t entries, const TypeInfo *type);
-void *Table$get(table_t t, const void *key, const TypeInfo *type);
+// Look up key_expr in table_expr with Table$get; evaluates to the stored value when the
+// lookup returns non-NULL, and to default_val otherwise.
+#define Table$get_value_or_default(table_expr, key_t, val_t, key_expr, default_val, info_expr) ({ \
+ const table_t t = table_expr; const key_t k = key_expr; \
+ const val_t *v = Table$get(t, &k, info_expr); \
+ v ? *v : default_val; })
+// Evaluates to true when Table$get returns a non-NULL result for key_expr in table_expr.
+// The key is copied into a local so that its address can be passed to Table$get.
+#define Table$has_value(table_expr, key_expr, info_expr) ({ \
+ const table_t t = table_expr; __typeof(key_expr) k = key_expr; \
+ (Table$get(t, &k, info_expr) != NULL); })
void *Table$get_raw(table_t t, const void *key, const TypeInfo *type);
void *Table$entry(table_t t, int64_t n);
void *Table$reserve(table_t *t, const void *key, const void *value, const TypeInfo *type);
@@ -39,6 +54,13 @@ void Table$set(table_t *t, const void *key, const void *value, const TypeInfo *t
#define Table$reserve_value(t, key_expr, type) ({ __typeof(key_expr) k = key_expr; Table$reserve(t, &k, NULL, type); })
void Table$remove(table_t *t, const void *key, const TypeInfo *type);
#define Table$remove_value(t, key_expr, type) ({ __typeof(key_expr) k = key_expr; Table$remove(t, &k, type); })
+
+table_t Table$overlap(table_t a, table_t b, const TypeInfo *type);
+table_t Table$with(table_t a, table_t b, const TypeInfo *type);
+table_t Table$without(table_t a, table_t b, const TypeInfo *type);
+bool Table$is_subset_of(table_t a, table_t b, bool strict, const TypeInfo *type);
+bool Table$is_superset_of(table_t a, table_t b, bool strict, const TypeInfo *type);
+
void Table$clear(table_t *t);
table_t Table$sorted(table_t t, const TypeInfo *type);
void Table$mark_copy_on_write(table_t *t);
diff --git a/builtins/types.c b/builtins/types.c
index 355473e9..4fb2c523 100644
--- a/builtins/types.c
+++ b/builtins/types.c
@@ -29,8 +29,8 @@ public const TypeInfo $TypeInfo = {
.TypeInfoInfo.type_str="TypeInfo",
};
-public const TypeInfo $Void = {.size=0, .align=0};
-public const TypeInfo $Abort = {.size=0, .align=0};
+public const TypeInfo $Void = {.size=0, .align=0, .tag=EmptyStruct};
+public const TypeInfo $Abort = {.size=0, .align=0, .tag=EmptyStruct};
public CORD Func$as_text(const void *fn, bool colorize, const TypeInfo *type)
{
diff --git a/builtins/types.h b/builtins/types.h
index 864a55c2..4184e8fa 100644
--- a/builtins/types.h
+++ b/builtins/types.h
@@ -58,6 +58,8 @@ typedef struct TypeInfo {
.tag=PointerInfo, .PointerInfo={.sigil=sigil_expr, .pointed=pointed_info, .is_optional=opt}})
#define $ArrayInfo(item_info) &((TypeInfo){.size=sizeof(array_t), .align=__alignof__(array_t), \
.tag=ArrayInfo, .ArrayInfo.item=item_info})
+// TypeInfo for a set: a TableInfo whose key type is the item type and whose value type is $Void
+#define $SetInfo(item_info) &((TypeInfo){.size=sizeof(table_t), .align=__alignof__(table_t), \
+ .tag=TableInfo, .TableInfo.key=item_info, .TableInfo.value=&$Void})
#define $TableInfo(key_expr, value_expr) &((TypeInfo){.size=sizeof(table_t), .align=__alignof__(table_t), \
.tag=TableInfo, .TableInfo.key=key_expr, .TableInfo.value=value_expr})
#define $FunctionInfo(typestr) &((TypeInfo){.size=sizeof(void*), .align=__alignof__(void*), \
diff --git a/compile.c b/compile.c
index 3a632afe..14a72a29 100644
--- a/compile.c
+++ b/compile.c
@@ -55,6 +55,13 @@ static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed)
return true;
}
+ // Set -> Array promotion:
+ if (needed->tag == ArrayType && actual->tag == SetType
+ && type_eq(Match(needed, ArrayType)->item_type, Match(actual, SetType)->item_type)) {
+ *code = CORD_all("(", *code, ").entries");
+ return true;
+ }
+
return false;
}
@@ -126,6 +133,7 @@ CORD compile_type(type_t *t)
return text->lang ? CORD_all(namespace_prefix(text->env->libname, text->env->namespace->parent), text->lang, "_t") : "Text_t";
}
case ArrayType: return "array_t";
+ case SetType: return "table_t";
case TableType: return "table_t";
case FunctionType: {
auto fn = Match(t, FunctionType);
@@ -877,7 +885,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
CORD loop = CORD_all("ARRAY_INCREF(", array_code, ");\n",
for_code, "{\n",
- compile_type(item_t), " ", value,
+ compile_declaration(item_t, value),
" = *(", compile_type(item_t), "*)(", array_code, ".data + (",index,"-1)*", array_code, ".stride);\n",
body, "\n}");
@@ -888,6 +896,28 @@ CORD compile_statement(env_t *env, ast_t *ast)
loop = CORD_all("{\narray_t ",array_code," = ", compile(env, array), ";\n", loop, "\n}");
return loop;
}
+ case SetType: {
+ type_t *item_type = Match(iter_t, SetType)->item_type;
+
+ CORD set = is_idempotent(for_->iter) ? compile(env, for_->iter) : "set";
+ CORD loop = CORD_all("ARRAY_INCREF(", set, ".entries);\n"
+ "for (int64_t i = 0; i < ",set,".entries.length; ++i) {\n");
+
+ if (for_->vars) {
+ if (for_->vars->next)
+ code_err(for_->vars->next->ast, "This is too many variables for this loop");
+ CORD item = compile(env, for_->vars->ast);
+ loop = CORD_all(loop, compile_declaration(item_type, item), " = *(", compile_type(item_type), "*)(",
+ set,".entries.data + i*", set, ".entries.stride);\n");
+ }
+ loop = CORD_all(loop, body, "\n}");
+ if (for_->empty)
+ loop = CORD_all("if (", set, ".entries.length > 0) {\n", loop, "\n} else ", compile_statement(env, for_->empty));
+ loop = CORD_all(loop, stop, "\nARRAY_DECREF(", set, ".entries);\n");
+ if (!is_idempotent(for_->iter))
+ loop = CORD_all("{\ntable_t ",set," = ", compile(env, for_->iter), ";\n", loop, "\n}");
+ return loop;
+ }
case TableType: {
type_t *key_t = Match(iter_t, TableType)->key_type;
type_t *value_t = Match(iter_t, TableType)->value_type;
@@ -910,14 +940,14 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
if (key) {
- loop = CORD_all(loop, compile_type(key_t), " ", key, " = *(", compile_type(key_t), "*)(",
+ loop = CORD_all(loop, compile_declaration(key_t, key), " = *(", compile_type(key_t), "*)(",
table,".entries.data + i*", table, ".entries.stride);\n");
}
if (value) {
size_t value_offset = type_size(key_t);
if (type_align(value_t) > 1 && value_offset % type_align(value_t))
value_offset += type_align(value_t) - (value_offset % type_align(value_t)); // padding
- loop = CORD_all(loop, compile_type(value_t), " ", value, " = *(", compile_type(value_t), "*)(",
+ loop = CORD_all(loop, compile_declaration(value_t, value), " = *(", compile_type(value_t), "*)(",
table,".entries.data + i*", table, ".entries.stride + ", heap_strf("%zu", value_offset), ");\n");
}
loop = CORD_all(loop, body, "\n}");
@@ -1081,6 +1111,7 @@ CORD expr_as_text(env_t *env, CORD expr, type_t *t, CORD color)
return CORD_asprintf("Text$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t));
}
case ArrayType: return CORD_asprintf("Array$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t));
+ case SetType: return CORD_asprintf("Table$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t));
case TableType: return CORD_asprintf("Table$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t));
case FunctionType: case ClosureType: return CORD_asprintf("Func$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t));
case PointerType: return CORD_asprintf("Pointer$as_text(stack(%r), %r, %r)", expr, color, compile_type_info(env, t));
@@ -1773,6 +1804,58 @@ CORD compile(env_t *env, ast_t *ast)
}
}
+ case Set: {
+ auto set = Match(ast, Set);
+ if (!set->items)
+ return "((table_t){})";
+
+ type_t *set_type = get_type(env, ast);
+ type_t *item_type = Match(set_type, SetType)->item_type;
+
+ for (ast_list_t *item = set->items; item; item = item->next) {
+ if (item->ast->tag == Comprehension)
+ goto set_comprehension;
+ }
+
+ { // No comprehension:
+ CORD code = CORD_all("Set(",
+ compile_type(item_type), ", ",
+ compile_type_info(env, item_type));
+
+ size_t n = 0;
+ for (ast_list_t *item = set->items; item; item = item->next)
+ ++n;
+ CORD_appendf(&code, ", %zu", n);
+
+ for (ast_list_t *item = set->items; item; item = item->next) {
+ code = CORD_all(code, ",\n\t", compile(env, item->ast));
+ }
+ return CORD_cat(code, ")");
+ }
+
+ set_comprehension:
+ {
+ static int64_t comp_num = 1;
+ env_t *scope = fresh_scope(env);
+ scope->comprehension_var = heap_strf("set$%ld", comp_num++);
+
+ CORD code = CORD_all("({ table_t ", scope->comprehension_var, " = {};");
+ set_binding(scope, scope->comprehension_var, new(binding_t, .type=set_type, .code=scope->comprehension_var));
+ for (ast_list_t *item = set->items; item; item = item->next) {
+ if (item->ast->tag == Comprehension) {
+ code = CORD_all(code, "\n", compile_statement(scope, item->ast));
+ } else {
+ CORD add_item = compile_statement(
+ scope, WrapAST(item->ast, MethodCall, .name="add", .self=FakeAST(StackReference, FakeAST(Var, scope->comprehension_var)),
+ .args=new(arg_ast_t, .value=item->ast)));
+ code = CORD_all(code, "\n", add_item);
+ }
+ }
+ code = CORD_all(code, " ", scope->comprehension_var, "; })");
+ return code;
+ }
+
+ }
case Comprehension: {
ast_t *base = Match(ast, Comprehension)->expr;
while (base->tag == Comprehension)
@@ -1961,12 +2044,81 @@ CORD compile(env_t *env, ast_t *ast)
return CORD_all("Array$reversed(", self, ", ", padded_item_size, ")");
} else code_err(ast, "There is no '%s' method for arrays", call->name);
}
+ case SetType: {
+ auto set = Match(self_value_t, SetType);
+ if (streq(call->name, "has")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 0, false);
+ arg_t *arg_spec = new(arg_t, .name="key", .type=set->item_type);
+ return CORD_all("Table$has_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
+ compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "add")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 1, false);
+ arg_t *arg_spec = new(arg_t, .name="item", .type=set->item_type);
+ return CORD_all("Table$set_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", NULL, ",
+ compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "add_all")) {
+ arg_t *arg_spec = new(arg_t, .name="items", .type=Type(ArrayType, .item_type=Match(self_value_t, SetType)->item_type));
+ return CORD_all("({ table_t *set = ", compile_to_pointer_depth(env, call->self, 1, false), "; ",
+ "array_t to_add = ", compile_arguments(env, ast, arg_spec, call->args), "; ",
+ "for (int64_t i = 0; i < to_add.length; i++)\n"
+ "Table$set(set, to_add.data + i*to_add.stride, NULL, ", compile_type_info(env, self_value_t), ");\n",
+ "(void)0; })");
+ } else if (streq(call->name, "remove")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 1, false);
+ arg_t *arg_spec = new(arg_t, .name="item", .type=set->item_type);
+ return CORD_all("Table$remove_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
+ compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "remove_all")) {
+ arg_t *arg_spec = new(arg_t, .name="items", .type=Type(ArrayType, .item_type=Match(self_value_t, SetType)->item_type));
+ return CORD_all("({ table_t *set = ", compile_to_pointer_depth(env, call->self, 1, false), "; ",
+ "array_t to_add = ", compile_arguments(env, ast, arg_spec, call->args), "; ",
+ "for (int64_t i = 0; i < to_add.length; i++)\n"
+ "Table$remove(set, to_add.data + i*to_add.stride, ", compile_type_info(env, self_value_t), ");\n",
+ "(void)0; })");
+ } else if (streq(call->name, "clear")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 1, false);
+ (void)compile_arguments(env, ast, NULL, call->args);
+ return CORD_all("Table$clear(", self, ")");
+ } else if (streq(call->name, "with")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 0, false);
+ arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t);
+ return CORD_all("Table$with(", self, ", ", compile_arguments(env, ast, arg_spec, call->args),
+ ", ", compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "overlap")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 0, false);
+ arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t);
+ return CORD_all("Table$overlap(", self, ", ", compile_arguments(env, ast, arg_spec, call->args),
+ ", ", compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "without")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 0, false);
+ arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t);
+ return CORD_all("Table$without(", self, ", ", compile_arguments(env, ast, arg_spec, call->args),
+ ", ", compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "is_subset_of")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 0, false);
+ arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t,
+ .next=new(arg_t, .name="strict", .type=Type(BoolType), .default_val=FakeAST(Bool, false)));
+ return CORD_all("Table$is_subset_of(", self, ", ", compile_arguments(env, ast, arg_spec, call->args),
+ ", ", compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "is_superset_of")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 0, false);
+ arg_t *arg_spec = new(arg_t, .name="other", .type=self_value_t,
+ .next=new(arg_t, .name="strict", .type=Type(BoolType), .default_val=FakeAST(Bool, false)));
+ return CORD_all("Table$is_superset_of(", self, ", ", compile_arguments(env, ast, arg_spec, call->args),
+ ", ", compile_type_info(env, self_value_t), ")");
+ } else code_err(ast, "There is no '%s' method for sets", call->name);
+ }
case TableType: {
auto table = Match(self_value_t, TableType);
if (streq(call->name, "get")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
- arg_t *arg_spec = new(arg_t, .name="key", .type=Type(PointerType, .pointed=table->key_type, .is_stack=true, .is_readonly=true));
- return CORD_all("Table$get(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
+ arg_t *arg_spec = new(arg_t, .name="key", .type=table->key_type, .next=new(arg_t, .name="default", .type=table->value_type));
+ return CORD_all("Table$get_value_or_default(", self, ", ", compile_type(table->key_type), ", ", compile_type(table->value_type), ", ",
+ compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "has")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 0, false);
+ arg_t *arg_spec = new(arg_t, .name="key", .type=table->key_type);
+ return CORD_all("Table$has_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ",
compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "set")) {
CORD self = compile_to_pointer_depth(env, call->self, 1, false);
@@ -2147,6 +2299,16 @@ CORD compile(env_t *env, ast_t *ast)
}
code_err(ast, "The field '%s' is not a valid field name of %T", f->field, value_t);
}
+ case SetType: {
+ if (streq(f->field, "items")) {
+ return CORD_all("({ table_t *t = ", compile_to_pointer_depth(env, f->fielded, 1, false), ";\n"
+ "ARRAY_INCREF(t->entries);\n"
+ "t->entries; })");
+ } else if (streq(f->field, "fallback")) {
+ return CORD_all("(", compile_to_pointer_depth(env, f->fielded, 0, false), ").fallback");
+ }
+ code_err(ast, "There is no '%s' field on sets", f->field);
+ }
case TableType: {
if (streq(f->field, "keys")) {
return CORD_all("({ table_t *t = ", compile_to_pointer_depth(env, f->fielded, 1, false), ";\n"
@@ -2219,7 +2381,7 @@ CORD compile(env_t *env, ast_t *ast)
if (!promote(env, &key, index_t, key_t))
code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t);
file_t *f = indexing->index->file;
- return CORD_all("Table_get(", table, ", ", compile_type(key_t), ", ", compile_type(value_t), ", ",
+ return CORD_all("Table$get_value_or_fail(", table, ", ", compile_type(key_t), ", ", compile_type(value_t), ", ",
key, ", ", compile_type_info(env, container_t), ", ",
Text$quoted(f->filename, false), ", ", CORD_asprintf("%ld", (int64_t)(indexing->index->start - f->text)), ", ",
CORD_asprintf("%ld", (int64_t)(indexing->index->end - f->text)),
@@ -2309,12 +2471,16 @@ CORD compile_type_info(env_t *env, type_t *t)
}
case ArrayType: {
type_t *item_t = Match(t, ArrayType)->item_type;
- return CORD_asprintf("$ArrayInfo(%r)", compile_type_info(env, item_t));
+ return CORD_all("$ArrayInfo(", compile_type_info(env, item_t), ")");
+ }
+ case SetType: {
+ type_t *item_type = Match(t, SetType)->item_type;
+ return CORD_all("$SetInfo(", compile_type_info(env, item_type), ")");
}
case TableType: {
type_t *key_type = Match(t, TableType)->key_type;
type_t *value_type = Match(t, TableType)->value_type;
- return CORD_asprintf("$TableInfo(%r, %r)", compile_type_info(env, key_type), compile_type_info(env, value_type));
+ return CORD_all("$TableInfo(", compile_type_info(env, key_type), ", ", compile_type_info(env, value_type), ")");
}
case PointerType: {
auto ptr = Match(t, PointerType);
diff --git a/environment.c b/environment.c
index 69ad9b66..84f6880e 100644
--- a/environment.c
+++ b/environment.c
@@ -333,6 +333,16 @@ env_t *for_scope(env_t *env, ast_t *ast)
}
return scope;
}
+ case SetType: {
+ if (for_->vars) {
+ if (for_->vars->next)
+ code_err(for_->vars->next->ast, "This is too many variables for this loop");
+ type_t *item_type = Match(iter_t, SetType)->item_type;
+ const char *name = Match(for_->vars->ast, Var)->name;
+ set_binding(scope, name, new(binding_t, .type=item_type, .code=CORD_cat("$", name)));
+ }
+ return scope;
+ }
case TableType: {
const char *vars[2] = {};
int64_t num_vars = 0;
diff --git a/parse.c b/parse.c
index 128a2aca..2b7b5816 100644
--- a/parse.c
+++ b/parse.c
@@ -513,6 +513,19 @@ type_ast_t *parse_table_type(parse_ctx_t *ctx, const char *pos) {
return NewTypeAST(ctx->file, start, pos, TableTypeAST, .key=key_type, .value=value_type);
}
+// Parse a set type of the form `{ItemType}`.
+// Returns NULL when the text doesn't start with '{', when no item type can be parsed,
+// or when a ':' follows the item type.
+type_ast_t *parse_set_type(parse_ctx_t *ctx, const char *pos) {
+    const char *type_start = pos;
+    if (!match(&pos, "{"))
+        return NULL;
+    whitespace(&pos);
+
+    type_ast_t *item = parse_type(ctx, pos);
+    if (item == NULL)
+        return NULL;
+
+    // Continue right after the item type
+    pos = item->end;
+    whitespace(&pos);
+    if (match(&pos, ":"))
+        return NULL;
+
+    expect_closing(ctx, &pos, "}", "I wasn't able to parse the rest of this set type");
+    return NewTypeAST(ctx->file, type_start, pos, SetTypeAST, .item=item);
+}
+
type_ast_t *parse_func_type(parse_ctx_t *ctx, const char *pos) {
const char *start = pos;
if (!match_word(&pos, "func")) return NULL;
@@ -576,6 +589,7 @@ type_ast_t *parse_type(parse_ctx_t *ctx, const char *pos) {
|| (type=parse_pointer_type(ctx, pos))
|| (type=parse_array_type(ctx, pos))
|| (type=parse_table_type(ctx, pos))
+ || (type=parse_set_type(ctx, pos))
|| (type=parse_type_name(ctx, pos))
|| (type=parse_func_type(ctx, pos))
);
@@ -698,7 +712,7 @@ PARSER(parse_table) {
key_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a key type for this table");
whitespace(&pos);
if (!match(&pos, ":"))
- parser_err(ctx, pos, pos, "I expected an ':' for this table type");
+ return NULL;
value_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a value type for this table");
whitespace(&pos);
}
@@ -760,6 +774,50 @@ PARSER(parse_table) {
return NewAST(ctx->file, start, pos, Table, .key_type=key_type, .value_type=value_type, .entries=entries, .fallback=fallback, .default_value=default_val);
}
+// Parse a set literal: `{item, item, ...}`, optionally starting with `:ItemType`.
+// Returns NULL when the text doesn't start with '{', when a ':' follows the item type or
+// an item, or when there is neither an explicit item type nor any item.
+PARSER(parse_set) {
+    const char *start = pos;
+    if (!match(&pos, "{")) return NULL;
+
+    whitespace(&pos);
+
+    ast_list_t *items = NULL;
+    type_ast_t *item_type = NULL;
+    if (match(&pos, ":")) {
+        whitespace(&pos);
+        // Sets have item types (not key types), so the error message says so
+        item_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse an item type for this set");
+        whitespace(&pos);
+        if (match(&pos, ":"))
+            return NULL;
+        whitespace(&pos);
+    }
+
+    for (;;) {
+        ast_t *item = optional(ctx, &pos, parse_extended_expr);
+        if (!item) break;
+        whitespace(&pos);
+        if (match(&pos, ":")) return NULL;
+        // Fold any chain of comprehension suffixes into the item itself
+        ast_t *suffixed = parse_comprehension_suffix(ctx, item);
+        while (suffixed) {
+            item = suffixed;
+            pos = suffixed->end;
+            suffixed = parse_comprehension_suffix(ctx, item);
+        }
+        // Items are prepended here, so the list is reversed once parsing is done
+        items = new(ast_list_t, .ast=item, .next=items);
+        if (!match_separator(&pos))
+            break;
+    }
+
+    REVERSE_LIST(items);
+
+    if (!item_type && !items)
+        return NULL;
+
+    whitespace(&pos);
+    expect_closing(ctx, &pos, "}", "I wasn't able to parse the rest of this set");
+
+    return NewAST(ctx->file, start, pos, Set, .item_type=item_type, .items=items);
+}
+
ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs) {
if (!lhs) return NULL;
const char *pos = lhs->end;
@@ -1287,6 +1345,7 @@ PARSER(parse_term_no_suffix) {
|| (term=parse_lambda(ctx, pos))
|| (term=parse_parens(ctx, pos))
|| (term=parse_table(ctx, pos))
+ || (term=parse_set(ctx, pos))
|| (term=parse_var(ctx, pos))
|| (term=parse_array(ctx, pos))
|| (term=parse_reduction(ctx, pos))
diff --git a/test/sets.tm b/test/sets.tm
new file mode 100644
index 00000000..bfec068e
--- /dev/null
+++ b/test/sets.tm
@@ -0,0 +1,36 @@
+
+func main():
+ >> t1 := {10, 20, 30, 10}
+ = {10, 20, 30}
+ >> t1:has(10)
+ = yes
+ >> t1:has(-999)
+ = no
+
+ >> t2 := {30, 40}
+
+ >> t1:with(t2)
+ = {10, 20, 30, 40}
+
+ >> t1:without(t2)
+ = {10, 20}
+
+ >> t1:overlap(t2)
+ = {30}
+
+
+ >> {1,2}:is_subset_of({2,3})
+ = no
+ >> {1,2}:is_subset_of({1,2,3})
+ = yes
+ >> {1,2}:is_subset_of({1,2})
+ = yes
+ >> {1,2}:is_subset_of({1,2}, strict=yes)
+ = no
+
+ >> t1:add_all(t2)
+ >> t1
+ = {10, 20, 30, 40}
+ >> t1:remove_all(t2)
+ >> t1
+ = {10, 20}
diff --git a/test/tables.tm b/test/tables.tm
index 132e4301..6dbc1b78 100644
--- a/test/tables.tm
+++ b/test/tables.tm
@@ -61,3 +61,27 @@ func main():
>> t3:remove(3)
>> t3
= {1:10, 2:20}
+
+ do:
+ >> plain := {1:10, 2:20, 3:30}
+ >> plain:get(2, -999)
+ = 20
+ >> plain:get(456, -999)
+ = -999
+ >> plain:has(2)
+ = yes
+ >> plain:has(456)
+ = no
+
+ >> fallback := {4:40; fallback=plain}
+ >> fallback:has(1)
+ = yes
+ >> fallback:get(1, -999)
+ = 10
+
+ >> default := {5:50; default=0}
+ >> default:has(28273)
+ = yes
+ >> default:get(28273, -999)
+ = 0
+
diff --git a/typecheck.c b/typecheck.c
index 10c5820f..ce8b07ed 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -54,6 +54,17 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast)
padded_type_size(item_t), ARRAY_MAX_STRIDE);
return Type(ArrayType, .item_type=item_t);
}
+    // Set types: resolve the item type, then reject items that hold stack memory
+    // or whose padded size exceeds ARRAY_MAX_STRIDE
+    case SetTypeAST: {
+        type_ast_t *item_type = Match(ast, SetTypeAST)->item;
+        type_t *item_t = parse_type_ast(env, item_type);
+        if (!item_t) code_err(item_type, "I can't figure out what this type is.");
+        if (has_stack_memory(item_t))
+            code_err(item_type, "Sets can't have stack references because the set may outlive the stack frame.");
+        if (padded_type_size(item_t) > ARRAY_MAX_STRIDE)
+            code_err(ast, "This set holds items that take up %ld bytes, but the maximum supported size is %ld bytes. Consider using a set of pointers instead.",
+                     padded_type_size(item_t), ARRAY_MAX_STRIDE);
+        return Type(SetType, .item_type=item_t);
+    }
case TableTypeAST: {
type_ast_t *key_type_ast = Match(ast, TableTypeAST)->key;
type_t *key_type = parse_type_ast(env, key_type_ast);
@@ -531,6 +542,35 @@ type_t *get_type(env_t *env, ast_t *ast)
code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in.");
return Type(ArrayType, .item_type=item_type);
}
+ case Set: { // Infer the type of a set literal
+ auto set = Match(ast, Set);
+ type_t *item_type = NULL;
+ if (set->item_type) { // an explicit item type annotation is used instead of inferring from the items
+ item_type = parse_type_ast(env, set->item_type);
+ } else {
+ for (ast_list_t *item = set->items; item; item = item->next) {
+ ast_t *item_ast = item->ast;
+ env_t *scope = env;
+ while (item_ast->tag == Comprehension) { // unwrap (possibly nested) comprehensions down to the expression they produce
+ auto comp = Match(item_ast, Comprehension);
+ scope = for_scope(
+ scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars));
+ item_ast = comp->expr;
+ }
+
+ type_t *this_item_type = get_type(scope, item_ast);
+ type_t *item_merged = type_or_type(item_type, this_item_type); // combine this item's type with the running item type
+ if (!item_merged)
+ code_err(item_ast,
+ "This set item has type %T, which is different from earlier set items which have type %T",
+ this_item_type, item_type);
+ item_type = item_merged;
+ }
+ }
+ if (has_stack_memory(item_type)) // NOTE(review): item_type is still NULL here if the literal has no items and no annotation -- confirm the parser rules that out
+ code_err(ast, "Sets cannot hold stack references because the set may outlive the reference's stack frame.");
+ return Type(SetType, .item_type=item_type);
+ }
case Table: {
auto table = Match(ast, Table);
type_t *key_type = NULL, *value_type = NULL;
@@ -682,9 +722,24 @@ type_t *get_type(env_t *env, ast_t *ast)
else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type;
else code_err(ast, "There is no '%s' method for arrays", call->name);
}
+ case SetType: { // Return type of each built-in set method, selected by call->name
+ if (streq(call->name, "add")) return Type(VoidType);
+ else if (streq(call->name, "has")) return Type(BoolType);
+ else if (streq(call->name, "add_all")) return Type(VoidType);
+ else if (streq(call->name, "remove")) return Type(VoidType);
+ else if (streq(call->name, "remove_all")) return Type(VoidType);
+ else if (streq(call->name, "clear")) return Type(VoidType);
+ else if (streq(call->name, "with")) return self_value_t; // with/overlap/without yield the receiver's own set type
+ else if (streq(call->name, "overlap")) return self_value_t;
+ else if (streq(call->name, "without")) return self_value_t;
+ else if (streq(call->name, "is_subset_of")) return Type(BoolType);
+ else if (streq(call->name, "is_superset_of")) return Type(BoolType);
+ else code_err(ast, "There is no '%s' method for sets", call->name); // any other method name is reported as an error
+ }
case TableType: {
auto table = Match(self_value_t, TableType);
- if (streq(call->name, "get")) return Type(PointerType, .pointed=table->value_type, .is_readonly=true, .is_optional=true);
+ if (streq(call->name, "get")) return table->value_type;
+ else if (streq(call->name, "has")) return Type(BoolType);
else if (streq(call->name, "set")) return Type(VoidType);
else if (streq(call->name, "remove")) return Type(VoidType);
else if (streq(call->name, "clear")) return Type(VoidType);
diff --git a/types.c b/types.c
index bb9268aa..5025ee9b 100644
--- a/types.c
+++ b/types.c
@@ -29,6 +29,10 @@ CORD type_to_cord(type_t *t) {
auto table = Match(t, TableType);
return CORD_asprintf("{%r:%r}", type_to_cord(table->key_type), type_to_cord(table->value_type));
}
+ case SetType: { // A set type is rendered as its item type wrapped in braces
+ auto set = Match(t, SetType);
+ return CORD_asprintf("{%r}", type_to_cord(set->item_type));
+ }
case ClosureType: {
return type_to_cord(Match(t, ClosureType)->fn);
}
@@ -202,6 +206,7 @@ bool has_heap_memory(type_t *t)
switch (t->tag) {
case ArrayType: return true;
case TableType: return true;
+ case SetType: return true;
case PointerType: return true;
case StructType: {
for (arg_t *field = Match(t, StructType)->fields; field; field = field->next) {
@@ -294,6 +299,11 @@ bool can_promote(type_t *actual, type_t *needed)
&& can_promote(actual_ret, needed_ret)));
}
+ // Set -> Array promotion: allowed when an array is needed, a set is supplied, and type_eq() accepts their item types
+ if (needed->tag == ArrayType && actual->tag == SetType
+ && type_eq(Match(needed, ArrayType)->item_type, Match(actual, SetType)->item_type))
+ return true;
+
return false;
}
@@ -329,6 +339,7 @@ static bool _can_have_cycles(type_t *t, table_t *seen)
auto table = Match(t, TableType);
return _can_have_cycles(table->key_type, seen) || _can_have_cycles(table->value_type, seen);
}
+ case SetType: return _can_have_cycles(Match(t, SetType)->item_type, seen);
case StructType: {
for (arg_t *field = Match(t, StructType)->fields; field; field = field->next) {
if (_can_have_cycles(field->type, seen))
@@ -363,6 +374,7 @@ type_t *replace_type(type_t *t, type_t *target, type_t *replacement)
#define REPLACED_MEMBER(t, tag, member) ({ t = memcpy(GC_MALLOC(sizeof(type_t)), (t), sizeof(type_t)); Match((struct type_s*)(t), tag)->member = replace_type(Match((t), tag)->member, target, replacement); t; })
switch (t->tag) {
case ArrayType: return REPLACED_MEMBER(t, ArrayType, item_type);
+ case SetType: return REPLACED_MEMBER(t, SetType, item_type);
case TableType: {
t = REPLACED_MEMBER(t, TableType, key_type);
t = REPLACED_MEMBER(t, TableType, value_type);
@@ -407,6 +419,7 @@ size_t type_size(type_t *t)
case NumType: return Match(t, NumType)->bits/8;
case TextType: return sizeof(CORD);
case ArrayType: return sizeof(array_t);
+ case SetType: return sizeof(table_t);
case TableType: return sizeof(table_t);
case FunctionType: return sizeof(void*);
case ClosureType: return sizeof(struct {void *fn, *userdata;});
@@ -458,6 +471,7 @@ size_t type_align(type_t *t)
case IntType: return Match(t, IntType)->bits/8;
case NumType: return Match(t, NumType)->bits/8;
case TextType: return __alignof__(CORD);
+ case SetType: return __alignof__(table_t);
case ArrayType: return __alignof__(array_t);
case TableType: return __alignof__(table_t);
case FunctionType: return __alignof__(void*);
@@ -509,6 +523,13 @@ type_t *get_field_type(type_t *t, const char *field_name)
}
return NULL;
}
+ case SetType: { // Fields that can be looked up on a set type
+ if (streq(field_name, "items")) // typed as an array of the set's item type
+ return Type(ArrayType, .item_type=Match(t, SetType)->item_type);
+ else if (streq(field_name, "fallback")) // typed as an optional, read-only pointer to this same set type
+ return Type(PointerType, .pointed=t, .is_readonly=true, .is_optional=true);
+ return NULL; // any other field name has no type on a set
+ }
case TableType: {
if (streq(field_name, "keys"))
return Type(ArrayType, Match(t, TableType)->key_type);
diff --git a/types.h b/types.h
index 67fd8833..8e4b7976 100644
--- a/types.h
+++ b/types.h
@@ -48,6 +48,7 @@ struct type_s {
CStringType,
TextType,
ArrayType,
+ SetType,
TableType,
FunctionType,
ClosureType,
@@ -79,6 +80,9 @@ struct type_s {
type_t *item_type;
} ArrayType;
struct {
+ type_t *item_type;
+ } SetType;
+ struct {
type_t *key_type, *value_type;
} TableType;
struct {