From 2ecd8e11fd9edc42f8593edf334dc54d3a2d6930 Mon Sep 17 00:00:00 2001
From: Bruce Hill
Date: Thu, 22 Feb 2024 12:45:12 -0500
Subject: [PATCH] Implement 'when' statement for matching on enums

Add a `when` statement that switches on the tag of an enum value:

    when <subject>
    is [var:]Tag [then] <body>
    [else <body>]

 * ast: new `When` node (subject, linked list of when_clause_t, optional
   else body), printed via when_clauses_to_cord().
 * parse: "when" becomes a keyword; parse_when() reads `is` clauses and
   an optional `else` that sit at the same indentation as the `when`.
 * typecheck: the subject must be an enum type; every clause must name a
   valid tag that was not handled by an earlier clause; without an
   `else` every tag must be handled, and an `else` is rejected when
   every tag is already handled. The statement's type is Void for now.
 * compile: emits a C `switch` on `$subject.$tag`; a clause variable, if
   present, is initialized from `$subject.<Tag>`.

Also: drop CC=gcc from the Makefile, rename a local that shadowed the
loop variable when compiling table entries, pass the position preceding
`else` (instead of the statement start) to expect() for a missing `else`
body in parse_if, and label doctest blocks with a "// Test:" comment.

---
 Makefile    |  1 -
 ast.c       | 18 +++++++++---
 ast.h       | 12 +++++++-
 compile.c   | 42 ++++++++++++++++++++++++----
 parse.c     | 52 ++++++++++++++++++++++++++++++++--
 typecheck.c | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 192 insertions(+), 13 deletions(-)

diff --git a/Makefile b/Makefile
index b24133d..94dd0ec 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,3 @@
-CC=gcc
 PREFIX=/usr/local
 VERSION=0.12.1
 CCONFIG=-std=c11 -Werror -D_XOPEN_SOURCE=700 -D_POSIX_C_SOURCE=200809L -fPIC -ftrapv -fvisibility=hidden \
diff --git a/ast.c b/ast.c
index 753066c..cabb0c6 100644
--- a/ast.c
+++ b/ast.c
@@ -20,6 +20,7 @@ static CORD ast_to_cord(ast_t *ast);
 static CORD ast_list_to_cord(ast_list_t *asts);
 static CORD type_ast_to_cord(type_ast_t *t);
 static CORD arg_list_to_cord(arg_ast_t *args);
+static CORD when_clauses_to_cord(when_clause_t *clauses);
 static CORD tags_to_cord(tag_ast_t *tags);
 
 #define TO_CORD(x) _Generic(x, \
@@ -60,8 +61,17 @@ CORD arg_list_to_cord(arg_ast_t *args) {
             CORD_sprintf(&c, "%r=%s", c, ast_to_cord(args->default_val));
         if (args->next) c = CORD_cat(c, ", ");
     }
-    c = CORD_cat(c, ")");
-    return c;
+    return CORD_cat(c, ")");
+}
+
+CORD when_clauses_to_cord(when_clause_t *clauses) {
+    CORD c = "Clauses(";
+    for (; clauses; clauses = clauses->next) {
+        if (clauses->var) c = CORD_all(c, ast_to_cord(clauses->var), ":");
+        c = CORD_all(c, ast_to_cord(clauses->tag_name), "=", ast_to_cord(clauses->body));
+        if (clauses->next) c = CORD_cat(c, ", ");
+    }
+    return CORD_cat(c, ")");
 }
 
 CORD tags_to_cord(tag_ast_t *tags) {
@@ -72,8 +82,7 @@ CORD tags_to_cord(tag_ast_t *tags) {
         CORD_sprintf(&c, "%r(%r)=%ld", c, arg_list_to_cord(tags->fields), tags->value);
         if (tags->next) c = CORD_cat(c, ", ");
     }
-    c = CORD_cat(c, ")");
-    return c;
+    return CORD_cat(c, ")");
 }
 
 CORD ast_to_cord(ast_t *ast)
@@ -117,6 +126,7 @@ CORD ast_to_cord(ast_t *ast)
                     ast_to_cord(data.iter), ast_to_cord(data.body))
     T(While, "(condition=%r, body=%r)", ast_to_cord(data.condition), ast_to_cord(data.body))
     T(If, "(condition=%r, body=%r, else=%r)", ast_to_cord(data.condition), ast_to_cord(data.body), ast_to_cord(data.else_body))
+    T(When, "(subject=%r, clauses=%r, else=%r)", ast_to_cord(data.subject), when_clauses_to_cord(data.clauses), ast_to_cord(data.else_body))
     T(Reduction, "(iter=%r, combination=%r, fallback=%r)", ast_to_cord(data.iter), ast_to_cord(data.combination), ast_to_cord(data.fallback))
     T(Skip, "(%s)", data.target)
     T(Stop, "(%s)", data.target)
diff --git a/ast.h b/ast.h
index 4db7b38..b3f3389 100644
--- a/ast.h
+++ b/ast.h
@@ -24,6 +24,11 @@ typedef struct ast_list_s {
     struct ast_list_s *next;
 } ast_list_t;
 
+typedef struct when_clause_s {
+    ast_t *var, *tag_name, *body;
+    struct when_clause_s *next;
+} when_clause_t;
+
 typedef struct arg_ast_s {
     const char *name;
     type_ast_t *type;
@@ -94,7 +99,7 @@ typedef enum {
     FunctionDef, Lambda,
     FunctionCall, KeywordArg,
     Block,
-    For, While, If,
+    For, While, If, When,
     Reduction,
     Skip, Stop, Pass,
     Return,
@@ -199,6 +204,11 @@ struct ast_s {
         struct {
             ast_t *condition, *body, *else_body;
         } If;
+        struct {
+            ast_t *subject;
+            when_clause_t *clauses;
+            ast_t *else_body;
+        } When;
         struct {
             ast_t *iter, *combination, *fallback;
         } Reduction;
diff --git a/compile.c b/compile.c
index ed063ca..98de661 100644
--- a/compile.c
+++ b/compile.c
@@ -52,7 +52,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
 {
     CORD stmt;
     switch (ast->tag) {
-    case If: case For: case While: case FunctionDef: case Return: case StructDef: case EnumDef:
+    case If: case When: case For: case While: case FunctionDef: case Return: case StructDef: case EnumDef:
     case Declare: case Assign: case UpdateAssign: case DocTest:
         stmt = compile(env, ast);
         break;
@@ -462,8 +462,8 @@ CORD compile(env_t *env, ast_t *ast)
         code = CORD_all(code, "NULL");
 
         for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
-            auto entry = Match(entry->ast, TableEntry);
-            code = CORD_all(code, ",\n\t{", compile(env, entry->key), ", ", compile(env, entry->value), "}");
+            auto e = Match(entry->ast, TableEntry);
+            code = CORD_all(code, ",\n\t{", compile(env, e->key), ", ", compile(env, e->value), "}");
         }
         return CORD_cat(code, ")");
 
@@ -516,7 +516,6 @@ CORD compile(env_t *env, ast_t *ast)
         auto kwarg = Match(ast, KeywordArg);
         return CORD_asprintf(".%s=%r", kwarg->name, compile(env, kwarg->arg));
     }
-    // KeywordArg,
     case If: {
         auto if_ = Match(ast, If);
         CORD code;
@@ -525,6 +524,39 @@ CORD compile(env_t *env, ast_t *ast)
             CORD_sprintf(&code, "%r\nelse %r", code, compile(env, if_->else_body));
         return code;
     }
+    case When: {
+        auto when = Match(ast, When);
+        type_t *subject_t = get_type(env, when->subject);
+        auto enum_t = Match(subject_t, EnumType);
+        CORD code = CORD_all("{ ", compile_type(subject_t), " $subject = ", compile(env, when->subject), ";\n"
+                             "switch ($subject.$tag) {");
+        type_t *result_t = get_type(env, ast);
+        (void)result_t; // the statement's type is not used by the generated code yet
+        for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+            const char *clause_tag_name = Match(clause->tag_name, Var)->name;
+            code = CORD_all(code, "case $tag$", enum_t->name, "$", clause_tag_name, ": {\n");
+            type_t *tag_type = NULL;
+            for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
+                if (streq(tag->name, clause_tag_name)) {
+                    tag_type = tag->type;
+                    break;
+                }
+            }
+            assert(tag_type);
+            env_t *scope = env;
+            if (clause->var) {
+                code = CORD_all(code, compile_type(tag_type), " ", compile(env, clause->var), " = $subject.", clause_tag_name, ";\n");
+                scope = fresh_scope(env);
+                set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type));
+            }
+            code = CORD_all(code, compile(scope, clause->body), "\nbreak;\n}\n");
+        }
+        if (when->else_body) {
+            code = CORD_all(code, "default: {\n", compile(env, when->else_body), "\nbreak;\n}");
+        }
+        code = CORD_all(code, "\n}\n}");
+        return code;
+    }
     case While: {
         auto while_ = Match(ast, While);
         return CORD_asprintf("while (%r) %r", compile(env, while_->condition), compile(env, while_->body));
@@ -749,7 +781,7 @@ CORD compile(env_t *env, ast_t *ast)
                 (int64_t)(test->expr->end - test->expr->file->text));
         } else {
             return CORD_asprintf(
-                "{\n%r $expr = %r;\n"
+                "{ // Test:\n%r $expr = %r;\n"
                 "__doctest(&$expr, %r, %r, %r, %ld, %ld);\n"
                 "}",
                 compile_type(expr_t),
diff --git a/parse.c b/parse.c
index 56e47e5..9418cc7 100644
--- a/parse.c
+++ b/parse.c
@@ -43,7 +43,7 @@ int op_tightness[] = {
 #define MAX_TIGHTNESS 9
 
 static const char *keywords[] = {
-    "yes", "xor", "while", "use", "then", "struct", "stop", "skip", "return",
+    "yes", "xor", "while", "when", "use", "then", "struct", "stop", "skip", "return",
     "or", "not", "no", "mod1", "mod", "in", "if", "func", "for", "extern", "enum", "else", "do",
     "and", "_mix_", "_min_", "_max_",
     NULL,
@@ -71,6 +71,7 @@ static arg_ast_t *parse_args(parse_ctx_t *ctx, const char **pos, bool allow_unna
 static PARSER(parse_for);
 static PARSER(parse_while);
 static PARSER(parse_if);
+static PARSER(parse_when);
 static PARSER(parse_expr);
 static PARSER(parse_extended_expr);
 static PARSER(parse_term_no_suffix);
@@ -803,13 +804,59 @@ PARSER(parse_if) {
     const char *tmp = pos;
     whitespace(&tmp);
     ast_t *else_body = NULL;
+    const char *else_start = pos;
     if (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "else")) {
         pos = tmp;
-        else_body = expect(ctx, start, &pos, parse_opt_indented_block, "I expected a body for this 'else'");
+        else_body = expect(ctx, else_start, &pos, parse_opt_indented_block, "I expected a body for this 'else'");
     }
     return NewAST(ctx->file, start, pos, If, .condition=condition, .body=body, .else_body=else_body);
 }
 
+PARSER(parse_when) {
+    // when <subject> (is [var:]Tag [then] <body>)* [else <body>]
+    const char *start = pos;
+    int64_t starting_indent = get_indent(ctx->file, pos);
+
+    if (!match_word(&pos, "when"))
+        return NULL;
+
+    ast_t *subject = optional(ctx, &pos, parse_declaration);
+    if (!subject) subject = expect(ctx, start, &pos, parse_expr,
+                                   "I expected to find an expression for this 'when'");
+
+    when_clause_t *clauses = NULL;
+    const char *tmp = pos;
+    whitespace(&tmp);
+    while (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "is")) {
+        pos = tmp;
+        spaces(&pos);
+        ast_t *tag_name, *var = expect(ctx, start, &pos, parse_var, "I expected a variable or tag name here");
+        spaces(&pos);
+        if (match(&pos, ":")) {
+            spaces(&pos);
+            tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name after this ':'"); // tag_name must never be NULL: later passes call Match() on it
+        } else {
+            tag_name = var;
+            var = NULL;
+        }
+
+        match_word(&pos, "then"); // optional
+        ast_t *body = expect(ctx, start, &pos, parse_opt_indented_block, "I expected a body for this 'when' clause");
+        clauses = new(when_clause_t, .var=var, .tag_name=tag_name, .body=body, .next=clauses);
+        tmp = pos;
+        whitespace(&tmp);
+    }
+    REVERSE_LIST(clauses); // clauses were prepended above, so restore source order
+
+    ast_t *else_body = NULL;
+    const char *else_start = pos;
+    if (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "else")) {
+        pos = tmp;
+        else_body = expect(ctx, else_start, &pos, parse_opt_indented_block, "I expected a body for this 'else'");
+    }
+    return NewAST(ctx->file, start, pos, When, .subject=subject, .clauses=clauses, .else_body=else_body);
+}
+
 PARSER(parse_for) {
     // for [k,] v in iter [] body
     const char *start = pos;
@@ -1343,6 +1390,7 @@ PARSER(parse_extended_expr) {
         || (expr=optional(ctx, &pos, parse_for))
         || (expr=optional(ctx, &pos, parse_while))
         || (expr=optional(ctx, &pos, parse_if))
+        || (expr=optional(ctx, &pos, parse_when))
         )
         return expr;
 
diff --git a/typecheck.c b/typecheck.c
index ee4d024..781b6ff 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -534,6 +534,86 @@ type_t *get_type(env_t *env, ast_t *ast)
         }
     }
 
+    case When: {
+        auto when = Match(ast, When);
+        type_t *subject_t = get_type(env, when->subject);
+        if (subject_t->tag != EnumType)
+            code_err(when->subject, "'when' statements are only for enum types, not %T", subject_t);
+
+        tag_t * const tags = Match(subject_t, EnumType)->tags;
+
+        typedef struct match_s {
+            const char *name;
+            type_t *type;
+            bool handled;
+            struct match_s *next;
+        } match_t;
+        match_t *matches = NULL;
+        for (tag_t *tag = tags; tag; tag = tag->next)
+            matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches);
+
+        type_t *overall_t = NULL;
+        for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+            const char *tag_name = Match(clause->tag_name, Var)->name;
+            type_t *tag_type = NULL;
+            for (match_t *m = matches; m; m = m->next) {
+                if (streq(m->name, tag_name)) {
+                    if (m->handled)
+                        code_err(clause->tag_name, "This tag was already handled earlier");
+                    m->handled = true;
+                    tag_type = m->type;
+                    break;
+                }
+            }
+
+            if (!tag_type)
+                code_err(clause->tag_name, "This is not a valid tag for the type %T", subject_t);
+
+            env_t *scope = env;
+            if (clause->var) {
+                scope = fresh_scope(scope);
+                set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type));
+            }
+            type_t *clause_type = get_type(scope, clause->body);
+            type_t *merged = type_or_type(overall_t, clause_type);
+            if (!merged)
+                code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
+                         clause_type, overall_t);
+            overall_t = merged;
+        }
+
+        if (when->else_body) {
+            bool any_unhandled = false;
+            for (match_t *m = matches; m; m = m->next) {
+                if (!m->handled) {
+                    any_unhandled = true;
+                    break;
+                }
+            }
+            if (!any_unhandled)
+                code_err(when->else_body, "This 'else' block will never run because every tag is handled");
+
+            type_t *else_t = get_type(env, when->else_body);
+            type_t *merged = type_or_type(overall_t, else_t);
+            if (!merged)
+                code_err(when->else_body,
+                         "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
+                         overall_t, else_t);
+            // return merged;
+            return Type(VoidType);
+        } else {
+            CORD unhandled = CORD_EMPTY;
+            for (match_t *m = matches; m; m = m->next) {
+                if (!m->handled)
+                    unhandled = unhandled ? CORD_all(unhandled, ", ", m->name) : m->name;
+            }
+            if (unhandled)
+                code_err(ast, "This 'when' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled));
+            // return overall_t;
+            return Type(VoidType);
+        }
+    }
+
     case While: case For: return Type(VoidType);
 
     case Unknown: code_err(ast, "I can't figure out the type of: %W", ast);