diff --git a/ast.c b/ast.c
index acbe109..d39fd39 100644
--- a/ast.c
+++ b/ast.c
@@ -58,9 +58,7 @@ CORD arg_list_to_xml(arg_ast_t *args) {
CORD when_clauses_to_xml(when_clause_t *clauses) {
CORD c = CORD_EMPTY;
for (; clauses; clauses = clauses->next) {
- c = CORD_all(c, "var) c = CORD_all(c, " var=\"", Match(clauses->var, Var)->name, "\"");
- c = CORD_all(c, " tag=\"", ast_to_xml(clauses->tag_name), "\">", ast_to_xml(clauses->body), "");
+ c = CORD_all(c, "tag_name), "\">", ast_list_to_xml(clauses->args), ast_to_xml(clauses->body), "");
}
return c;
}
diff --git a/ast.h b/ast.h
index b0e924f..4307532 100644
--- a/ast.h
+++ b/ast.h
@@ -27,11 +27,6 @@ typedef struct ast_list_s {
struct ast_list_s *next;
} ast_list_t;
-typedef struct when_clause_s {
- ast_t *var, *tag_name, *body;
- struct when_clause_s *next;
-} when_clause_t;
-
typedef struct arg_ast_s {
const char *name;
type_ast_t *type;
@@ -39,6 +34,13 @@ typedef struct arg_ast_s {
struct arg_ast_s *next;
} arg_ast_t;
+typedef struct when_clause_s {
+ ast_t *tag_name;
+ ast_list_t *args;
+ ast_t *body;
+ struct when_clause_s *next;
+} when_clause_t;
+
typedef enum {
BINOP_UNKNOWN,
BINOP_POWER=100, BINOP_MULT, BINOP_DIVIDE, BINOP_MOD, BINOP_MOD1, BINOP_PLUS,
diff --git a/compile.c b/compile.c
index dec8ea4..59ee0b0 100644
--- a/compile.c
+++ b/compile.c
@@ -182,10 +182,26 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
assert(tag_type);
env_t *scope = env;
- if (clause->var) {
- code = CORD_all(code, compile_type(env, tag_type), " ", compile(env, clause->var), " = subject.", clause_tag_name, ";\n");
- scope = fresh_scope(env);
- set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type));
+
+ auto tag_struct = Match(tag_type, StructType);
+ if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
+ code = CORD_all(code, compile_type(env, tag_type), " ", compile(env, clause->args->ast), " = subject.", clause_tag_name, ";\n");
+ scope = fresh_scope(scope);
+ set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
+ } else if (clause->args) {
+ scope = fresh_scope(scope);
+ ast_list_t *var = clause->args;
+ arg_t *field = tag_struct->fields;
+ while (var || field) {
+ if (!var)
+ code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name);
+ if (!field)
+ code_err(var->ast, "This is one more field than %T has", subject_t);
+ code = CORD_all(code, compile_type(env, field->type), " ", compile(env, var->ast), " = subject.", clause_tag_name, ".", field->name, ";\n");
+ set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
+ var = var->next;
+ field = field->next;
+ }
}
code = CORD_all(code, compile_statement(scope, clause->body), "\nbreak;\n}\n");
}
diff --git a/parse.c b/parse.c
index 720c796..d942148 100644
--- a/parse.c
+++ b/parse.c
@@ -883,20 +883,27 @@ PARSER(parse_when) {
while (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "is")) {
pos = tmp;
spaces(&pos);
- ast_t *tag_name, *var = expect(ctx, start, &pos, parse_var, "I expected a variable or tag name here");
+ ast_t *tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here");
spaces(&pos);
- if (match(&pos, ":")) {
- spaces(&pos);
- tag_name = optional(ctx, &pos, parse_var);
- } else {
- tag_name = var;
- var = NULL;
+ ast_list_t *args = NULL;
+ if (match(&pos, "(")) {
+ for (;;) {
+ whitespace(&pos);
+ ast_t *arg = optional(ctx, &pos, parse_var);
+ if (!arg) break;
+ args = new(ast_list_t, .ast=arg, .next=args);
+ whitespace(&pos);
+ if (!match(&pos, ",")) break;
+ }
+ whitespace(&pos);
+ expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments");
+ REVERSE_LIST(args);
}
expect_str(ctx, start, &pos, ":", "I expected a ':' here");
ast_t *body = expect(ctx, start, &pos, parse_opt_indented_block, "I expected a body for this 'when' clause");
- clauses = new(when_clause_t, .var=var, .tag_name=tag_name, .body=body, .next=clauses);
+ clauses = new(when_clause_t, .tag_name=tag_name, .args=args, .body=body, .next=clauses);
tmp = pos;
whitespace(&tmp);
}
diff --git a/test/enums.tm b/test/enums.tm
index 1bc2620..9f70041 100644
--- a/test/enums.tm
+++ b/test/enums.tm
@@ -1,4 +1,19 @@
-enum Foo(Zero, One(x:Int), Two(x,y:Int))
+enum Foo(Zero, One(x:Int), Two(x:Int, y:Int), Three(x:Int, y:Text, z:Bool), Four(x,y,z,w:Int), Last(t:Text))
+
+func choose_text(f:Foo)->Text:
+ >> f
+ when f is Zero:
+ return "Zero"
+ is One(one):
+ return "One: {one}"
+ is Two(x, y):
+ return "Two: x={x}, y={y}"
+ is Three(three):
+ return "Three: {three}"
+ is Four:
+ return "Four"
+ else:
+ return "else: {f}"
func main():
>> Foo.Zero
@@ -27,9 +42,16 @@ func main():
>> t[Foo.Zero]
= "missing"
- when x is o:One:
- >> o.x
- = 123
- else:
- fail("Oops")
+ >> choose_text(Foo.Zero)
+ = "Zero"
+ >> choose_text(Foo.One(123))
+ = "One: 123"
+ >> choose_text(Foo.Two(123, 456))
+ = "Two: x=123, y=456"
+ >> choose_text(Foo.Three(123, "hi", yes))
+ = "Three: Three(x=123, y=\"hi\", z=yes)"
+ >> choose_text(Foo.Four(1,2,3,4))
+ = "Four"
+ >> choose_text(Foo.Last("XX"))
+ = "else: Foo.Last(t=\"XX\")"
diff --git a/typecheck.c b/typecheck.c
index ad99a3b..b95ef41 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -862,9 +862,23 @@ type_t *get_type(env_t *env, ast_t *ast)
code_err(clause->tag_name, "This is not a valid tag for the type %T", subject_t);
env_t *scope = env;
- if (clause->var) {
+ auto tag_struct = Match(tag_type, StructType);
+ if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
scope = fresh_scope(scope);
- set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type));
+ set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type));
+ } else if (clause->args) {
+ scope = fresh_scope(scope);
+ ast_list_t *var = clause->args;
+ arg_t *field = tag_struct->fields;
+ while (var || field) {
+ if (!var)
+ code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
+ if (!field)
+ code_err(var->ast, "This is one more field than %T has", subject_t);
+ set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type));
+ var = var->next;
+ field = field->next;
+ }
}
type_t *clause_type = get_type(scope, clause->body);
type_t *merged = type_or_type(overall_t, clause_type);