aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compile.c20
-rw-r--r--enums.c24
-rw-r--r--test/enums.tm6
-rw-r--r--types.c8
4 files changed, 44 insertions, 14 deletions
diff --git a/compile.c b/compile.c
index fb60e387..17d19f30 100644
--- a/compile.c
+++ b/compile.c
@@ -270,7 +270,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
auto enum_t = Match(subject_t, EnumType);
CORD code = CORD_all("{ ", compile_type(subject_t), " subject = ", compile(env, when->subject), ";\n"
- "switch (subject.$tag) {");
+ "switch (subject.tag) {");
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
const char *clause_tag_name = Match(clause->tag_name, Var)->name;
code = CORD_all(code, "case ", namespace_prefix(enum_t->env->libname, enum_t->env->namespace), "tag$", clause_tag_name, ": {\n");
@@ -1077,7 +1077,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
env_t *enum_env = Match(fn->ret, EnumType)->env;
- next_fn = CORD_all("(cur=", next_fn, iter_t->tag == ClosureType ? "(next.userdata)" : "()", ").$tag == ",
+ next_fn = CORD_all("(cur=", next_fn, iter_t->tag == ClosureType ? "(next.userdata)" : "()", ").tag == ",
namespace_prefix(enum_env->libname, enum_env->namespace), "tag$Next");
if (for_->empty) {
@@ -2610,6 +2610,22 @@ CORD compile(env_t *env, ast_t *ast)
}
code_err(ast, "The field '%s' is not a valid field name of %T", f->field, value_t);
}
+ case EnumType: {
+ auto e = Match(value_t, EnumType);
+ for (tag_t *tag = e->tags; tag; tag = tag->next) {
+ if (streq(f->field, tag->name)) {
+ CORD prefix = namespace_prefix(e->env->libname, e->env->namespace);
+ if (fielded_t->tag == PointerType) {
+ CORD fielded = compile_to_pointer_depth(env, f->fielded, 1, false);
+ return CORD_all("((", fielded, ")->tag == ", prefix, "tag$", tag->name, ")");
+ } else {
+ CORD fielded = compile(env, f->fielded);
+ return CORD_all("((", fielded, ").tag == ", prefix, "tag$", tag->name, ")");
+ }
+ }
+ }
+ code_err(ast, "The field '%s' is not a valid tag name of %T", f->field, value_t);
+ }
case ArrayType: {
if (streq(f->field, "length"))
return CORD_all("Int64_to_Int((", compile_to_pointer_depth(env, f->fielded, 0, false), ").length)");
diff --git a/enums.c b/enums.c
index 10446f52..51740549 100644
--- a/enums.c
+++ b/enums.c
@@ -26,7 +26,7 @@ static CORD compile_str_method(env_t *env, ast_t *ast)
CORD full_name = CORD_cat(namespace_prefix(env->libname, env->namespace), def->name);
CORD str_func = CORD_all("static CORD ", full_name, "$as_text(", full_name, "_t *obj, bool use_color) {\n"
"\tif (!obj) return \"", def->name, "\";\n"
- "switch (obj->$tag) {\n");
+ "switch (obj->tag) {\n");
for (tag_ast_t *tag = def->tags; tag; tag = tag->next) {
if (!tag->fields) {
str_func = CORD_all(str_func, "\tcase ", full_name, "$tag$", tag->name, ": return use_color ? \"\\x1b[36;1m",
@@ -63,15 +63,15 @@ static CORD compile_compare_method(env_t *env, ast_t *ast)
return CORD_all("static int ", full_name, "$compare(const ", full_name, "_t *x, const ", full_name,
"_t *y, const TypeInfo *info) {\n"
"(void)info;\n"
- "return (x->$tag - y->$tag);\n"
+ "return (x->tag - y->tag);\n"
"}\n");
}
CORD cmp_func = CORD_all("static int ", full_name, "$compare(const ", full_name, "_t *x, const ", full_name,
"_t *y, const TypeInfo *info) {\n"
"(void)info;\n"
- "int diff = x->$tag - y->$tag;\n"
+ "int diff = x->tag - y->tag;\n"
"if (diff) return diff;\n"
- "switch (x->$tag) {\n");
+ "switch (x->tag) {\n");
for (tag_ast_t *tag = def->tags; tag; tag = tag->next) {
if (tag->fields) {
type_t *tag_type = Table$str_get(*env->types, CORD_to_const_char_star(CORD_all(def->name, "$", tag->name)));
@@ -94,14 +94,14 @@ static CORD compile_equals_method(env_t *env, ast_t *ast)
return CORD_all("static bool ", full_name, "$equal(const ", full_name, "_t *x, const ", full_name,
"_t *y, const TypeInfo *info) {\n"
"(void)info;\n"
- "return (x->$tag == y->$tag);\n"
+ "return (x->tag == y->tag);\n"
"}\n");
}
CORD eq_func = CORD_all("static bool ", full_name, "$equal(const ", full_name, "_t *x, const ", full_name,
"_t *y, const TypeInfo *info) {\n"
"(void)info;\n"
- "if (x->$tag != y->$tag) return no;\n"
- "switch (x->$tag) {\n");
+ "if (x->tag != y->tag) return no;\n"
+ "switch (x->tag) {\n");
for (tag_ast_t *tag = def->tags; tag; tag = tag->next) {
if (tag->fields) {
type_t *tag_type = Table$str_get(*env->types, CORD_to_const_char_star(CORD_all(def->name, "$", tag->name)));
@@ -124,14 +124,14 @@ static CORD compile_hash_method(env_t *env, ast_t *ast)
return CORD_all("static uint32_t ", full_name, "$hash(const ", full_name, "_t *obj, const TypeInfo *info) {\n"
"(void)info;\n"
"uint32_t hash;\n"
- "halfsiphash(&obj->$tag, sizeof(obj->$tag), TOMO_HASH_KEY, (uint8_t*)&hash, sizeof(hash));\n"
+ "halfsiphash(&obj->tag, sizeof(obj->tag), TOMO_HASH_KEY, (uint8_t*)&hash, sizeof(hash));\n"
"return hash;"
"\n}\n");
}
CORD hash_func = CORD_all("static uint32_t ", full_name, "$hash(const ", full_name, "_t *obj, const TypeInfo *info) {\n"
"(void)info;\n"
- "uint32_t hashes[2] = {(uint32_t)obj->$tag, 0};\n"
- "switch (obj->$tag) {\n");
+ "uint32_t hashes[2] = {(uint32_t)obj->tag, 0};\n"
+ "switch (obj->tag) {\n");
for (tag_ast_t *tag = def->tags; tag; tag = tag->next) {
if (tag->fields) {
type_t *tag_type = Table$str_get(*env->types, CORD_to_const_char_star(CORD_all(def->name, "$", tag->name)));
@@ -164,7 +164,7 @@ void compile_enum_def(env_t *env, ast_t *ast)
}
if (arg_sig == CORD_EMPTY) arg_sig = "void";
CORD constructor_impl = CORD_all("public inline ", full_name, "_t ", full_name, "$tagged$", tag->name, "(", arg_sig, ") { return (",
- full_name, "_t){.$tag=", full_name, "$tag$", tag->name, ", .$", tag->name, "={");
+ full_name, "_t){.tag=", full_name, "$tag$", tag->name, ", .$", tag->name, "={");
for (arg_ast_t *field = tag->fields; field; field = field->next) {
constructor_impl = CORD_all(constructor_impl, "$", field->name);
if (field->next) constructor_impl = CORD_cat(constructor_impl, ", ");
@@ -211,7 +211,7 @@ CORD compile_enum_typedef(env_t *env, ast_t *ast)
CORD_appendf(&enum_def, "%r$tag$%s = %ld", full_name, tag->name, tag->value);
if (tag->next) enum_def = CORD_cat(enum_def, ", ");
}
- enum_def = CORD_cat(enum_def, "} $tag;\n"
+ enum_def = CORD_cat(enum_def, "} tag;\n"
"union {\n");
for (tag_ast_t *tag = def->tags; tag; tag = tag->next) {
CORD field_def = compile_struct_typedef(env, WrapAST(ast, StructDef, .name=CORD_to_const_char_star(CORD_all(def->name, "$", tag->name)), .fields=tag->fields));
diff --git a/test/enums.tm b/test/enums.tm
index b734d487..f4af342d 100644
--- a/test/enums.tm
+++ b/test/enums.tm
@@ -23,6 +23,12 @@ func main():
>> Foo.Two(123, 456)
= Foo.Two(x=123, y=456)
+ >> one := Foo.One(123)
+ >> one.One
+ = yes
+ >> one.Two
+ = no
+
>> Foo.One(10) == Foo.One(10)
= yes
diff --git a/types.c b/types.c
index cd388304..8cb534d7 100644
--- a/types.c
+++ b/types.c
@@ -594,6 +594,14 @@ type_t *get_field_type(type_t *t, const char *field_name)
}
return NULL;
}
+ case EnumType: {
+ auto e = Match(t, EnumType);
+ for (tag_t *tag = e->tags; tag; tag = tag->next) {
+ if (streq(field_name, tag->name))
+ return Type(BoolType);
+ }
+ return NULL;
+ }
case SetType: {
if (streq(field_name, "length"))
return INT_TYPE;