diff --git a/ast.c b/ast.c index acbe109..d39fd39 100644 --- a/ast.c +++ b/ast.c @@ -58,9 +58,7 @@ CORD arg_list_to_xml(arg_ast_t *args) { CORD when_clauses_to_xml(when_clause_t *clauses) { CORD c = CORD_EMPTY; for (; clauses; clauses = clauses->next) { - c = CORD_all(c, "var) c = CORD_all(c, " var=\"", Match(clauses->var, Var)->name, "\""); - c = CORD_all(c, " tag=\"", ast_to_xml(clauses->tag_name), "\">", ast_to_xml(clauses->body), ""); + c = CORD_all(c, "tag_name), "\">", ast_list_to_xml(clauses->args), ast_to_xml(clauses->body), ""); } return c; } diff --git a/ast.h b/ast.h index b0e924f..4307532 100644 --- a/ast.h +++ b/ast.h @@ -27,11 +27,6 @@ typedef struct ast_list_s { struct ast_list_s *next; } ast_list_t; -typedef struct when_clause_s { - ast_t *var, *tag_name, *body; - struct when_clause_s *next; -} when_clause_t; - typedef struct arg_ast_s { const char *name; type_ast_t *type; @@ -39,6 +34,13 @@ typedef struct arg_ast_s { struct arg_ast_s *next; } arg_ast_t; +typedef struct when_clause_s { + ast_t *tag_name; + ast_list_t *args; + ast_t *body; + struct when_clause_s *next; +} when_clause_t; + typedef enum { BINOP_UNKNOWN, BINOP_POWER=100, BINOP_MULT, BINOP_DIVIDE, BINOP_MOD, BINOP_MOD1, BINOP_PLUS, diff --git a/compile.c b/compile.c index dec8ea4..59ee0b0 100644 --- a/compile.c +++ b/compile.c @@ -182,10 +182,26 @@ CORD compile_statement(env_t *env, ast_t *ast) } assert(tag_type); env_t *scope = env; - if (clause->var) { - code = CORD_all(code, compile_type(env, tag_type), " ", compile(env, clause->var), " = subject.", clause_tag_name, ";\n"); - scope = fresh_scope(env); - set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type)); + + auto tag_struct = Match(tag_type, StructType); + if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) { + code = CORD_all(code, compile_type(env, tag_type), " ", compile(env, clause->args->ast), " = subject.", clause_tag_name, ";\n"); + scope = fresh_scope(scope); + set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type)); + } else if (clause->args) { + scope = fresh_scope(scope); + ast_list_t *var = clause->args; + arg_t *field = tag_struct->fields; + while (var || field) { + if (!var) + code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name); + if (!field) + code_err(var->ast, "This is one more field than %T has", subject_t); + code = CORD_all(code, compile_type(env, field->type), " ", compile(env, var->ast), " = subject.", clause_tag_name, ".", field->name, ";\n"); + set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type)); + var = var->next; + field = field->next; + } } code = CORD_all(code, compile_statement(scope, clause->body), "\nbreak;\n}\n"); } diff --git a/parse.c b/parse.c index 720c796..d942148 100644 --- a/parse.c +++ b/parse.c @@ -883,20 +883,27 @@ PARSER(parse_when) { while (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "is")) { pos = tmp; spaces(&pos); - ast_t *tag_name, *var = expect(ctx, start, &pos, parse_var, "I expected a variable or tag name here"); + ast_t *tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here"); spaces(&pos); - if (match(&pos, ":")) { - spaces(&pos); - tag_name = optional(ctx, &pos, parse_var); - } else { - tag_name = var; - var = NULL; + ast_list_t *args = NULL; + if (match(&pos, "(")) { + for (;;) { + whitespace(&pos); + ast_t *arg = optional(ctx, &pos, parse_var); + if (!arg) break; + args = new(ast_list_t, .ast=arg, .next=args); + whitespace(&pos); + if (!match(&pos, ",")) break; + } + whitespace(&pos); + expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments"); + REVERSE_LIST(args); } expect_str(ctx, start, &pos, ":", "I expected a ':' here"); ast_t *body = expect(ctx, start, &pos, parse_opt_indented_block, "I expected a body for this 'when' clause"); - clauses = new(when_clause_t, .var=var, .tag_name=tag_name, .body=body, .next=clauses); + clauses = new(when_clause_t, .tag_name=tag_name, .args=args, .body=body, .next=clauses); tmp = pos; whitespace(&tmp); } diff --git a/test/enums.tm b/test/enums.tm index 1bc2620..9f70041 100644 --- a/test/enums.tm +++ b/test/enums.tm @@ -1,4 +1,19 @@ -enum Foo(Zero, One(x:Int), Two(x,y:Int)) +enum Foo(Zero, One(x:Int), Two(x:Int, y:Int), Three(x:Int, y:Text, z:Bool), Four(x,y,z,w:Int), Last(t:Text)) + +func choose_text(f:Foo)->Text: + >> f + when f is Zero: + return "Zero" + is One(one): + return "One: {one}" + is Two(x, y): + return "Two: x={x}, y={y}" + is Three(three): + return "Three: {three}" + is Four: + return "Four" + else: + return "else: {f}" func main(): >> Foo.Zero @@ -27,9 +42,16 @@ func main(): >> t[Foo.Zero] = "missing" - when x is o:One: - >> o.x - = 123 - else: - fail("Oops") + >> choose_text(Foo.Zero) + = "Zero" + >> choose_text(Foo.One(123)) + = "One: 123" + >> choose_text(Foo.Two(123, 456)) + = "Two: x=123, y=456" + >> choose_text(Foo.Three(123, "hi", yes)) + = "Three: Three(x=123, y=\"hi\", z=yes)" + >> choose_text(Foo.Four(1,2,3,4)) + = "Four" + >> choose_text(Foo.Last("XX")) + = "else: Foo.Last(t=\"XX\")" diff --git a/typecheck.c b/typecheck.c index ad99a3b..b95ef41 100644 --- a/typecheck.c +++ b/typecheck.c @@ -862,9 +862,23 @@ type_t *get_type(env_t *env, ast_t *ast) code_err(clause->tag_name, "This is not a valid tag for the type %T", subject_t); env_t *scope = env; - if (clause->var) { + auto tag_struct = Match(tag_type, StructType); + if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) { scope = fresh_scope(scope); - set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type)); + set_binding(scope, Match(clause->args->ast, Var)->name, new(binding_t, .type=tag_type)); + } else if (clause->args) { + scope = fresh_scope(scope); + ast_list_t *var = clause->args; + arg_t *field = tag_struct->fields; + while (var || field) { + if (!var) + code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); + if (!field) + code_err(var->ast, "This is one more field than %T has", subject_t); + set_binding(scope, Match(var->ast, Var)->name, new(binding_t, .type=field->type)); + var = var->next; + field = field->next; + } } type_t *clause_type = get_type(scope, clause->body); type_t *merged = type_or_type(overall_t, clause_type);