diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2025-03-06 18:37:08 -0500 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2025-03-06 18:37:08 -0500 |
| commit | 73e559fbe4182828742ac1b1d108bcdc42bc46d6 (patch) | |
| tree | 22c4d6245114cf3ba72701dafe18b84ce72e3f66 | |
| parent | 09423f6d42d86c62beefa4607fba41e3698f1850 (diff) | |
Support 'when' for literal values with equality checking
| -rw-r--r-- | ast.c | 2 | ||||
| -rw-r--r-- | ast.h | 4 | ||||
| -rw-r--r-- | compile.c | 120 | ||||
| -rw-r--r-- | parse.c | 30 | ||||
| -rw-r--r-- | typecheck.c | 91 | ||||
| -rw-r--r-- | typecheck.h | 1 |
6 files changed, 148 insertions, 100 deletions
@@ -71,7 +71,7 @@ CORD arg_list_to_xml(arg_ast_t *args) { CORD when_clauses_to_xml(when_clause_t *clauses) { CORD c = CORD_EMPTY; for (; clauses; clauses = clauses->next) { - c = CORD_all(c, "<case tag=\"", ast_to_xml(clauses->tag_name), "\">", ast_list_to_xml(clauses->args), ast_to_xml(clauses->body), "</case>"); + c = CORD_all(c, "<case>", ast_to_xml(clauses->pattern), ast_to_xml(clauses->body), "</case>"); } return c; } @@ -52,9 +52,7 @@ typedef struct arg_ast_s { } arg_ast_t; typedef struct when_clause_s { - ast_t *tag_name; - ast_list_t *args; - ast_t *body; + ast_t *pattern, *body; struct when_clause_s *next; } when_clause_t; @@ -360,9 +360,27 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t add_closed_vars(closed_vars, enclosing_scope, env, when->subject); type_t *subject_t = get_type(env, when->subject); + if (subject_t->tag != EnumType) { + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + add_closed_vars(closed_vars, enclosing_scope, env, clause->pattern); + add_closed_vars(closed_vars, enclosing_scope, env, clause->body); + } + + if (when->else_body) + add_closed_vars(closed_vars, enclosing_scope, env, when->else_body); + return; + } + auto enum_t = Match(subject_t, EnumType); for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - const char *clause_tag_name = Match(clause->tag_name, Var)->name; + const char *clause_tag_name; + if (clause->pattern->tag == Var) + clause_tag_name = Match(clause->pattern, Var)->name; + else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var) + clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name; + else + code_err(clause->pattern, "This is not a valid pattern for a %T enum", subject_t); + type_t *tag_type = NULL; for (tag_t *tag = enum_t->tags; tag; tag = tag->next) { if (streq(tag->name, clause_tag_name)) { @@ -371,26 +389,7 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t } } assert(tag_type); - env_t *scope = env; - - auto tag_struct = Match(tag_type, StructType); - if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) { - scope = fresh_scope(scope); - set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY); - } else if (clause->args) { - scope = fresh_scope(scope); - ast_list_t *var = clause->args; - arg_t *field = tag_struct->fields; - while (var || field) { - if (!var) - code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name); - if (!field) - code_err(var->ast, "This is one more field than %T has", subject_t); - set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY); - var = var->next; - field = field->next; - } - } + env_t *scope = when_clause_scope(env, subject_t, clause); add_closed_vars(closed_vars, enclosing_scope, scope, clause->body); } if (when->else_body) @@ -752,11 +751,44 @@ static CORD _compile_statement(env_t *env, ast_t *ast) auto when = Match(ast, When); type_t *subject_t = get_type(env, when->subject); + if (subject_t->tag != EnumType) { + CORD prefix = CORD_EMPTY, suffix = CORD_EMPTY; + ast_t *subject = when->subject; + if (!is_idempotent(when->subject)) { + prefix = CORD_all("{\n", compile_declaration(subject_t, "_when_subject"), " = ", compile(env, subject), ";\n"); + suffix = "}\n"; + subject = WrapAST(subject, InlineCCode, .type=subject_t, .code="_when_subject"); + } + + CORD code = CORD_EMPTY; + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + ast_t *comparison = WrapAST(clause->pattern, BinaryOp, .lhs=subject, .op=BINOP_EQ, .rhs=clause->pattern); + if (code != CORD_EMPTY) + code = CORD_all(code, "else "); + code = CORD_all(code, "if (", compile(env, comparison), ")", compile_statement(env, clause->body)); + } + if (when->else_body) + code = CORD_all(code, "else ", compile_statement(env, when->else_body)); + code = CORD_all(prefix, code, suffix); + return code; + } + auto enum_t = Match(subject_t, EnumType); CORD code = CORD_all("{ ", compile_type(subject_t), " subject = ", compile(env, when->subject), ";\n" "switch (subject.tag) {"); for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - const char *clause_tag_name = Match(clause->tag_name, Var)->name; + if (clause->pattern->tag == Var) { + const char *clause_tag_name = Match(clause->pattern, Var)->name; + code = CORD_all(code, "case ", namespace_prefix(enum_t->env, enum_t->env->namespace), "tag$", clause_tag_name, ": {\n", + compile_statement(env, clause->body), + "}\n"); + continue; + } + + if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var) + code_err(clause->pattern, "This is not a valid pattern for a %T enum type", subject_t); + + const char *clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name; code = CORD_all(code, "case ", namespace_prefix(enum_t->env, enum_t->env->namespace), "tag$", clause_tag_name, ": {\n"); type_t *tag_type = NULL; for (tag_t *tag = enum_t->tags; tag; tag = tag->next) { @@ -769,22 +801,32 @@ static CORD _compile_statement(env_t *env, ast_t *ast) env_t *scope = env; auto tag_struct = Match(tag_type, StructType); - if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) { - code = CORD_all(code, compile_declaration(tag_type, compile(env, clause->args->ast)), " = subject.$", clause_tag_name, ";\n"); - scope = fresh_scope(scope); - set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY); - } else if (clause->args) { + arg_ast_t *args = Match(clause->pattern, FunctionCall)->args; + if (args && !args->next && tag_struct->fields && tag_struct->fields->next) { + if (args->value->tag != Var) + code_err(args->value, "This is not a valid variable to bind to"); + const char *var_name = Match(args->value, Var)->name; + if (!streq(var_name, "_")) { + code = CORD_all(code, compile_declaration(tag_type, compile(env, args->value)), " = subject.$", clause_tag_name, ";\n"); + scope = fresh_scope(scope); + set_binding(scope, Match(args->value, Var)->name, tag_type, CORD_EMPTY); + } + } else if (args) { scope = fresh_scope(scope); - ast_list_t *var = clause->args; arg_t *field = tag_struct->fields; - while (var || field) { - if (!var) - code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name); + for (arg_ast_t *arg = args; arg || field; arg = arg->next) { + if (!arg) + code_err(ast, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name); if (!field) - code_err(var->ast, "This is one more field than %T has", subject_t); - code = CORD_all(code, compile_declaration(field->type, compile(env, var->ast)), " = subject.$", clause_tag_name, ".$", field->name, ";\n"); - set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY); - var = var->next; + code_err(arg->value, "This is one more field than %T has", subject_t); + if (arg->name) + code_err(arg->value, "Named arguments are not currently supported"); + + const char *var_name = Match(arg->value, Var)->name; + if (!streq(var_name, "_")) { + code = CORD_all(code, compile_declaration(field->type, compile(env, arg->value)), " = subject.$", clause_tag_name, ".$", field->name, ";\n"); + set_binding(scope, Match(arg->value, Var)->name, field->type, CORD_EMPTY); + } field = field->next; } } @@ -1160,7 +1202,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) const char *target = Match(ast, Skip)->target; for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) { bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0; - for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next) + for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var ? var->next : NULL) matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0); if (matched) { @@ -1189,7 +1231,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) const char *target = Match(ast, Stop)->target; for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) { bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0; - for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next) + for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var ? var->next : var) matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0); if (matched) { @@ -3443,12 +3485,12 @@ CORD compile(env_t *env, ast_t *ast) for (when_clause_t *clause = original->clauses; clause; clause = clause->next) { type_t *clause_type = get_clause_type(env, subject_t, clause); if (clause_type->tag == AbortType || clause_type->tag == ReturnType) { - new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=clause->body, .next=new_clauses); + new_clauses = new(when_clause_t, .pattern=clause->pattern, .body=clause->body, .next=new_clauses); } else { ast_t *assign = WrapAST(clause->body, Assign, .targets=new(ast_list_t, .ast=when_var), .values=new(ast_list_t, .ast=clause->body)); - new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=assign, .next=new_clauses); + new_clauses = new(when_clause_t, .pattern=clause->pattern, .body=assign, .next=new_clauses); } } REVERSE_LIST(new_clauses); @@ -1115,37 +1115,13 @@ PARSER(parse_when) { while (get_indent(ctx, tmp) == starting_indent && match_word(&tmp, "is")) { pos = tmp; spaces(&pos); - ast_t *tag_name; - ast_list_t *args; - if (match(&pos, "@")) { - tag_name = NewAST(ctx->file, pos-1, pos, Var, .name="@"); - spaces(&pos); - ast_t *arg = optional(ctx, &pos, parse_var); - args = arg ? new(ast_list_t, .ast=arg) : NULL; - } else { - tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here"); - spaces(&pos); - args = NULL; - if (match(&pos, "(")) { - for (;;) { - whitespace(&pos); - ast_t *arg = optional(ctx, &pos, parse_var); - if (!arg) break; - args = new(ast_list_t, .ast=arg, .next=args); - whitespace(&pos); - if (!match(&pos, ",")) break; - } - whitespace(&pos); - expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments"); - REVERSE_LIST(args); - } - } - + ast_t *pattern = expect(ctx, start, &pos, parse_expr, "I expected a pattern to match here"); + spaces(&pos); tmp = pos; if (!match(&tmp, ":")) parser_err(ctx, tmp, tmp, "I expected a colon ':' after this clause"); ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'when' clause"); - clauses = new(when_clause_t, .tag_name=tag_name, .args=args, .body=body, .next=clauses); + clauses = new(when_clause_t, .pattern=pattern, .body=body, .next=clauses); tmp = pos; whitespace(&tmp); } diff --git a/typecheck.c b/typecheck.c index 99a9a78a..1ab20177 100644 --- a/typecheck.c +++ b/typecheck.c @@ -484,13 +484,18 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name) return b->type; } -type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause) +env_t *when_clause_scope(env_t *env, type_t *subject_t, when_clause_t *clause) { - assert(subject_t->tag == EnumType); - tag_t * const tags = Match(subject_t, EnumType)->tags; + if (clause->pattern->tag == Var || subject_t->tag != EnumType) + return env; + + if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var) + code_err(clause->pattern, "I only support variables and constructors for pattern matching %T types in a 'when' block", subject_t); - const char *tag_name = Match(clause->tag_name, Var)->name; + auto fn = Match(clause->pattern, FunctionCall); + const char *tag_name = Match(fn->fn, Var)->name; type_t *tag_type = NULL; + tag_t * const tags = Match(subject_t, EnumType)->tags; for (tag_t *tag = tags; tag; tag = tag->next) { if (streq(tag->name, tag_name)) { tag_type = tag->type; @@ -499,29 +504,38 @@ type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause) } if (!tag_type) - code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t); + code_err(clause->pattern, "There is no tag '%s' for the type %T", tag_name, subject_t); - // Don't return early so we validate the tags - if (!clause->args) - return get_type(env, clause->body); + if (!fn->args) + return env; env_t *scope = fresh_scope(env); auto tag_struct = Match(tag_type, StructType); - if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) { - set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY); - } else { - ast_list_t *var = clause->args; - arg_t *field = tag_struct->fields; - while (var || field) { - if (!var) - code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); - if (!field) - code_err(var->ast, "This is one more field than %T has", subject_t); - set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY); - var = var->next; - field = field->next; - } - } + if (fn->args && !fn->args->next && tag_struct->fields && tag_struct->fields->next) { + if (fn->args->value->tag != Var) + code_err(fn->args->value, "I expected a variable here"); + set_binding(scope, Match(fn->args->value, Var)->name, tag_type, CORD_EMPTY); + return scope; + } + + arg_t *field = tag_struct->fields; + for (arg_ast_t *var = fn->args; var || field; var = var ? var->next : var) { + if (!var) + code_err(clause->pattern, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name); + if (!field) + code_err(var->value, "This is one more field than %T has", subject_t); + if (var->value->tag != Var) + code_err(var->value, "I expected this to be a plain variable so I could bind it to a value"); + if (!streq(Match(var->value, Var)->name, "_")) + set_binding(scope, Match(var->value, Var)->name, field->type, CORD_EMPTY); + field = field->next; + } + return scope; +} + +type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause) +{ + env_t *scope = when_clause_scope(env, subject_t, clause); return get_type(scope, clause->body); } @@ -1258,10 +1272,19 @@ type_t *get_type(env_t *env, ast_t *ast) case When: { auto when = Match(ast, When); type_t *subject_t = get_type(env, when->subject); - type_t *overall_t = NULL; - if (subject_t->tag != EnumType) - code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t); + if (subject_t->tag != EnumType) { + type_t *t = NULL; + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + t = type_or_type(t, get_type(env, clause->body)); + } + if (when->else_body) + t = type_or_type(t, get_type(env, when->else_body)); + else if (t->tag != OptionalType) + t = Type(OptionalType, .type=t); + return t; + } + type_t *overall_t = NULL; tag_t * const tags = Match(subject_t, EnumType)->tags; typedef struct match_s { @@ -1274,12 +1297,19 @@ type_t *get_type(env_t *env, ast_t *ast) matches = new(match_t, .tag=tag, .handled=false, .next=matches); for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - const char *tag_name = Match(clause->tag_name, Var)->name; + const char *tag_name; + if (clause->pattern->tag == Var) + tag_name = Match(clause->pattern, Var)->name; + else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var) + tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name; + else + code_err(clause->pattern, "This is not a valid pattern for a %T enum", subject_t); + CORD valid_tags = CORD_EMPTY; for (match_t *m = matches; m; m = m->next) { if (streq(m->tag->name, tag_name)) { if (m->handled) - code_err(clause->tag_name, "This tag was already handled earlier"); + code_err(clause->pattern, "This tag was already handled earlier"); m->handled = true; goto found_matching_tag; } @@ -1287,13 +1317,14 @@ type_t *get_type(env_t *env, ast_t *ast) valid_tags = CORD_cat(valid_tags, m->tag->name); } - code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)", + code_err(clause->pattern, "There is no tag '%s' for the type %T (valid tags: %s)", tag_name, subject_t, CORD_to_char_star(valid_tags)); found_matching_tag:; } for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - type_t *clause_type = get_clause_type(env, subject_t, clause); + env_t *clause_scope = when_clause_scope(env, subject_t, clause); + type_t *clause_type = get_type(clause_scope, clause->body); type_t *merged = type_or_type(overall_t, clause_type); if (!merged) code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", diff --git a/typecheck.h b/typecheck.h index 1175e5fd..cc5cb18c 100644 --- a/typecheck.h +++ b/typecheck.h @@ -21,6 +21,7 @@ PUREFUNC bool is_discardable(env_t *env, ast_t *ast); type_t *get_function_def_type(env_t *env, ast_t *ast); type_t *get_arg_type(env_t *env, arg_t *arg); type_t *get_arg_ast_type(env_t *env, arg_ast_t *arg); +env_t *when_clause_scope(env_t *env, type_t *subject_t, when_clause_t *clause); type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause); PUREFUNC bool can_be_mutated(env_t *env, ast_t *ast); type_t *parse_type_string(env_t *env, const char *str); |
