From 6f3b2c073a968e57d787849dce42ff1253ed0102 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sun, 18 Aug 2024 20:58:36 -0400 Subject: [PATCH] Add `enum.tag` as a way to do a boolean test for whether a value has a particular tag or not --- compile.c | 20 ++++++++++++++++++-- enums.c | 24 ++++++++++++------------ test/enums.tm | 6 ++++++ types.c | 8 ++++++++ 4 files changed, 44 insertions(+), 14 deletions(-) diff --git a/compile.c b/compile.c index fb60e38..17d19f3 100644 --- a/compile.c +++ b/compile.c @@ -270,7 +270,7 @@ CORD compile_statement(env_t *env, ast_t *ast) auto enum_t = Match(subject_t, EnumType); CORD code = CORD_all("{ ", compile_type(subject_t), " subject = ", compile(env, when->subject), ";\n" - "switch (subject.$tag) {"); + "switch (subject.tag) {"); for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { const char *clause_tag_name = Match(clause->tag_name, Var)->name; code = CORD_all(code, "case ", namespace_prefix(enum_t->env->libname, enum_t->env->namespace), "tag$", clause_tag_name, ": {\n"); @@ -1077,7 +1077,7 @@ CORD compile_statement(env_t *env, ast_t *ast) } env_t *enum_env = Match(fn->ret, EnumType)->env; - next_fn = CORD_all("(cur=", next_fn, iter_t->tag == ClosureType ? "(next.userdata)" : "()", ").$tag == ", + next_fn = CORD_all("(cur=", next_fn, iter_t->tag == ClosureType ? "(next.userdata)" : "()", ").tag == ", namespace_prefix(enum_env->libname, enum_env->namespace), "tag$Next"); if (for_->empty) { @@ -2610,6 +2610,22 @@ CORD compile(env_t *env, ast_t *ast) } code_err(ast, "The field '%s' is not a valid field name of %T", f->field, value_t); } + case EnumType: { + auto e = Match(value_t, EnumType); + for (tag_t *tag = e->tags; tag; tag = tag->next) { + if (streq(f->field, tag->name)) { + CORD prefix = namespace_prefix(e->env->libname, e->env->namespace); + if (fielded_t->tag == PointerType) { + CORD fielded = compile_to_pointer_depth(env, f->fielded, 1, false); + return CORD_all("((", fielded, ")->tag == ", prefix, "tag$", tag->name, ")"); + } else { + CORD fielded = compile(env, f->fielded); + return CORD_all("((", fielded, ").tag == ", prefix, "tag$", tag->name, ")"); + } + } + } + code_err(ast, "The field '%s' is not a valid tag name of %T", f->field, value_t); + } case ArrayType: { if (streq(f->field, "length")) return CORD_all("Int64_to_Int((", compile_to_pointer_depth(env, f->fielded, 0, false), ").length)"); diff --git a/enums.c b/enums.c index 10446f5..5174054 100644 --- a/enums.c +++ b/enums.c @@ -26,7 +26,7 @@ static CORD compile_str_method(env_t *env, ast_t *ast) CORD full_name = CORD_cat(namespace_prefix(env->libname, env->namespace), def->name); CORD str_func = CORD_all("static CORD ", full_name, "$as_text(", full_name, "_t *obj, bool use_color) {\n" "\tif (!obj) return \"", def->name, "\";\n" - "switch (obj->$tag) {\n"); + "switch (obj->tag) {\n"); for (tag_ast_t *tag = def->tags; tag; tag = tag->next) { if (!tag->fields) { str_func = CORD_all(str_func, "\tcase ", full_name, "$tag$", tag->name, ": return use_color ? \"\\x1b[36;1m", @@ -63,15 +63,15 @@ static CORD compile_compare_method(env_t *env, ast_t *ast) return CORD_all("static int ", full_name, "$compare(const ", full_name, "_t *x, const ", full_name, "_t *y, const TypeInfo *info) {\n" "(void)info;\n" - "return (x->$tag - y->$tag);\n" + "return (x->tag - y->tag);\n" "}\n"); } CORD cmp_func = CORD_all("static int ", full_name, "$compare(const ", full_name, "_t *x, const ", full_name, "_t *y, const TypeInfo *info) {\n" "(void)info;\n" - "int diff = x->$tag - y->$tag;\n" + "int diff = x->tag - y->tag;\n" "if (diff) return diff;\n" - "switch (x->$tag) {\n"); + "switch (x->tag) {\n"); for (tag_ast_t *tag = def->tags; tag; tag = tag->next) { if (tag->fields) { type_t *tag_type = Table$str_get(*env->types, CORD_to_const_char_star(CORD_all(def->name, "$", tag->name))); @@ -94,14 +94,14 @@ static CORD compile_equals_method(env_t *env, ast_t *ast) return CORD_all("static bool ", full_name, "$equal(const ", full_name, "_t *x, const ", full_name, "_t *y, const TypeInfo *info) {\n" "(void)info;\n" - "return (x->$tag == y->$tag);\n" + "return (x->tag == y->tag);\n" "}\n"); } CORD eq_func = CORD_all("static bool ", full_name, "$equal(const ", full_name, "_t *x, const ", full_name, "_t *y, const TypeInfo *info) {\n" "(void)info;\n" - "if (x->$tag != y->$tag) return no;\n" - "switch (x->$tag) {\n"); + "if (x->tag != y->tag) return no;\n" + "switch (x->tag) {\n"); for (tag_ast_t *tag = def->tags; tag; tag = tag->next) { if (tag->fields) { type_t *tag_type = Table$str_get(*env->types, CORD_to_const_char_star(CORD_all(def->name, "$", tag->name))); @@ -124,14 +124,14 @@ static CORD compile_hash_method(env_t *env, ast_t *ast) return CORD_all("static uint32_t ", full_name, "$hash(const ", full_name, "_t *obj, const TypeInfo *info) {\n" "(void)info;\n" "uint32_t hash;\n" - "halfsiphash(&obj->$tag, sizeof(obj->$tag), TOMO_HASH_KEY, (uint8_t*)&hash, sizeof(hash));\n" + "halfsiphash(&obj->tag, sizeof(obj->tag), TOMO_HASH_KEY, (uint8_t*)&hash, sizeof(hash));\n" "return hash;" "\n}\n"); } CORD hash_func = CORD_all("static uint32_t ", full_name, "$hash(const ", full_name, "_t *obj, const TypeInfo *info) {\n" "(void)info;\n" - "uint32_t hashes[2] = {(uint32_t)obj->$tag, 0};\n" - "switch (obj->$tag) {\n"); + "uint32_t hashes[2] = {(uint32_t)obj->tag, 0};\n" + "switch (obj->tag) {\n"); for (tag_ast_t *tag = def->tags; tag; tag = tag->next) { if (tag->fields) { type_t *tag_type = Table$str_get(*env->types, CORD_to_const_char_star(CORD_all(def->name, "$", tag->name))); @@ -164,7 +164,7 @@ void compile_enum_def(env_t *env, ast_t *ast) } if (arg_sig == CORD_EMPTY) arg_sig = "void"; CORD constructor_impl = CORD_all("public inline ", full_name, "_t ", full_name, "$tagged$", tag->name, "(", arg_sig, ") { return (", - full_name, "_t){.$tag=", full_name, "$tag$", tag->name, ", .$", tag->name, "={"); + full_name, "_t){.tag=", full_name, "$tag$", tag->name, ", .$", tag->name, "={"); for (arg_ast_t *field = tag->fields; field; field = field->next) { constructor_impl = CORD_all(constructor_impl, "$", field->name); if (field->next) constructor_impl = CORD_cat(constructor_impl, ", "); @@ -211,7 +211,7 @@ CORD compile_enum_typedef(env_t *env, ast_t *ast) CORD_appendf(&enum_def, "%r$tag$%s = %ld", full_name, tag->name, tag->value); if (tag->next) enum_def = CORD_cat(enum_def, ", "); } - enum_def = CORD_cat(enum_def, "} $tag;\n" + enum_def = CORD_cat(enum_def, "} tag;\n" "union {\n"); for (tag_ast_t *tag = def->tags; tag; tag = tag->next) { CORD field_def = compile_struct_typedef(env, WrapAST(ast, StructDef, .name=CORD_to_const_char_star(CORD_all(def->name, "$", tag->name)), .fields=tag->fields)); diff --git a/test/enums.tm b/test/enums.tm index b734d48..f4af342 100644 --- a/test/enums.tm +++ b/test/enums.tm @@ -23,6 +23,12 @@ func main(): >> Foo.Two(123, 456) = Foo.Two(x=123, y=456) + >> one := Foo.One(123) + >> one.One + = yes + >> one.Two + = no + >> Foo.One(10) == Foo.One(10) = yes diff --git a/types.c b/types.c index cd38830..8cb534d 100644 --- a/types.c +++ b/types.c @@ -594,6 +594,14 @@ type_t *get_field_type(type_t *t, const char *field_name) } return NULL; } + case EnumType: { + auto e = Match(t, EnumType); + for (tag_t *tag = e->tags; tag; tag = tag->next) { + if (streq(field_name, tag->name)) + return Type(BoolType); + } + return NULL; + } case SetType: { if (streq(field_name, "length")) return INT_TYPE;