Implement 'when' statement for matching on enums

This commit is contained in:
Bruce Hill 2024-02-22 12:45:12 -05:00
parent d915c5f5a2
commit 2ecd8e11fd
6 changed files with 192 additions and 13 deletions

View File

@ -1,4 +1,3 @@
CC=gcc
PREFIX=/usr/local
VERSION=0.12.1
CCONFIG=-std=c11 -Werror -D_XOPEN_SOURCE=700 -D_POSIX_C_SOURCE=200809L -fPIC -ftrapv -fvisibility=hidden \

18
ast.c
View File

@ -20,6 +20,7 @@ static CORD ast_to_cord(ast_t *ast);
static CORD ast_list_to_cord(ast_list_t *asts);
static CORD type_ast_to_cord(type_ast_t *t);
static CORD arg_list_to_cord(arg_ast_t *args);
static CORD when_clauses_to_cord(when_clause_t *clauses);
static CORD tags_to_cord(tag_ast_t *tags);
#define TO_CORD(x) _Generic(x, \
@ -60,8 +61,17 @@ CORD arg_list_to_cord(arg_ast_t *args) {
CORD_sprintf(&c, "%r=%s", c, ast_to_cord(args->default_val));
if (args->next) c = CORD_cat(c, ", ");
}
c = CORD_cat(c, ")");
return c;
return CORD_cat(c, ")");
}
CORD when_clauses_to_cord(when_clause_t *clauses) {
CORD c = "Clauses(";
for (; clauses; clauses = clauses->next) {
if (clauses->var) c = CORD_all(c, ast_to_cord(clauses->var), ":");
c = CORD_all(c, ast_to_cord(clauses->tag_name), "=", ast_to_cord(clauses->body));
if (clauses->next) c = CORD_cat(c, ", ");
}
return CORD_cat(c, ")");
}
CORD tags_to_cord(tag_ast_t *tags) {
@ -72,8 +82,7 @@ CORD tags_to_cord(tag_ast_t *tags) {
CORD_sprintf(&c, "%r(%r)=%ld", c, arg_list_to_cord(tags->fields), tags->value);
if (tags->next) c = CORD_cat(c, ", ");
}
c = CORD_cat(c, ")");
return c;
return CORD_cat(c, ")");
}
CORD ast_to_cord(ast_t *ast)
@ -117,6 +126,7 @@ CORD ast_to_cord(ast_t *ast)
ast_to_cord(data.iter), ast_to_cord(data.body))
T(While, "(condition=%r, body=%r)", ast_to_cord(data.condition), ast_to_cord(data.body))
T(If, "(condition=%r, body=%r, else=%r)", ast_to_cord(data.condition), ast_to_cord(data.body), ast_to_cord(data.else_body))
T(When, "(subject=%r, clauses=%r, else=%r)", ast_to_cord(data.subject), when_clauses_to_cord(data.clauses), ast_to_cord(data.else_body))
T(Reduction, "(iter=%r, combination=%r, fallback=%r)", ast_to_cord(data.iter), ast_to_cord(data.combination), ast_to_cord(data.fallback))
T(Skip, "(%s)", data.target)
T(Stop, "(%s)", data.target)

12
ast.h
View File

@ -24,6 +24,11 @@ typedef struct ast_list_s {
struct ast_list_s *next;
} ast_list_t;
typedef struct when_clause_s {
ast_t *var, *tag_name, *body;
struct when_clause_s *next;
} when_clause_t;
typedef struct arg_ast_s {
const char *name;
type_ast_t *type;
@ -94,7 +99,7 @@ typedef enum {
FunctionDef, Lambda,
FunctionCall, KeywordArg,
Block,
For, While, If,
For, While, If, When,
Reduction,
Skip, Stop, Pass,
Return,
@ -199,6 +204,11 @@ struct ast_s {
struct {
ast_t *condition, *body, *else_body;
} If;
struct {
ast_t *subject;
when_clause_t *clauses;
ast_t *else_body;
} When;
struct {
ast_t *iter, *combination, *fallback;
} Reduction;

View File

@ -52,7 +52,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
{
CORD stmt;
switch (ast->tag) {
case If: case For: case While: case FunctionDef: case Return: case StructDef: case EnumDef:
case If: case When: case For: case While: case FunctionDef: case Return: case StructDef: case EnumDef:
case Declare: case Assign: case UpdateAssign: case DocTest:
stmt = compile(env, ast);
break;
@ -462,8 +462,8 @@ CORD compile(env_t *env, ast_t *ast)
code = CORD_all(code, "NULL");
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
auto entry = Match(entry->ast, TableEntry);
code = CORD_all(code, ",\n\t{", compile(env, entry->key), ", ", compile(env, entry->value), "}");
auto e = Match(entry->ast, TableEntry);
code = CORD_all(code, ",\n\t{", compile(env, e->key), ", ", compile(env, e->value), "}");
}
return CORD_cat(code, ")");
@ -516,7 +516,6 @@ CORD compile(env_t *env, ast_t *ast)
auto kwarg = Match(ast, KeywordArg);
return CORD_asprintf(".%s=%r", kwarg->name, compile(env, kwarg->arg));
}
// KeywordArg,
case If: {
auto if_ = Match(ast, If);
CORD code;
@ -525,6 +524,39 @@ CORD compile(env_t *env, ast_t *ast)
CORD_sprintf(&code, "%r\nelse %r", code, compile(env, if_->else_body));
return code;
}
case When: {
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
auto enum_t = Match(subject_t, EnumType);
CORD code = CORD_all("{ ", compile_type(subject_t), " $subject = ", compile(env, when->subject), ";\n"
"switch ($subject.$tag) {");
type_t *result_t = get_type(env, ast);
(void)result_t;
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
const char *clause_tag_name = Match(clause->tag_name, Var)->name;
code = CORD_all(code, "case $tag$", enum_t->name, "$", clause_tag_name, ": {\n");
type_t *tag_type = NULL;
for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
if (streq(tag->name, clause_tag_name)) {
tag_type = tag->type;
break;
}
}
assert(tag_type);
env_t *scope = env;
if (clause->var) {
code = CORD_all(code, compile_type(tag_type), " ", compile(env, clause->var), " = $subject.", clause_tag_name, ";\n");
scope = fresh_scope(env);
set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type));
}
code = CORD_all(code, compile(scope, clause->body), "\nbreak;\n}\n");
}
if (when->else_body) {
code = CORD_all(code, "default: {\n", compile(env, when->else_body), "\nbreak;\n}");
}
code = CORD_all(code, "\n}\n}");
return code;
}
case While: {
auto while_ = Match(ast, While);
return CORD_asprintf("while (%r) %r", compile(env, while_->condition), compile(env, while_->body));
@ -749,7 +781,7 @@ CORD compile(env_t *env, ast_t *ast)
(int64_t)(test->expr->end - test->expr->file->text));
} else {
return CORD_asprintf(
"{\n%r $expr = %r;\n"
"{ // Test:\n%r $expr = %r;\n"
"__doctest(&$expr, %r, %r, %r, %ld, %ld);\n"
"}",
compile_type(expr_t),

52
parse.c
View File

@ -43,7 +43,7 @@ int op_tightness[] = {
#define MAX_TIGHTNESS 9
static const char *keywords[] = {
"yes", "xor", "while", "use", "then", "struct", "stop", "skip", "return",
"yes", "xor", "while", "when", "use", "then", "struct", "stop", "skip", "return",
"or", "not", "no", "mod1", "mod", "in", "if", "func", "for", "extern",
"enum", "else", "do", "and", "_mix_", "_min_", "_max_",
NULL,
@ -71,6 +71,7 @@ static arg_ast_t *parse_args(parse_ctx_t *ctx, const char **pos, bool allow_unna
static PARSER(parse_for);
static PARSER(parse_while);
static PARSER(parse_if);
static PARSER(parse_when);
static PARSER(parse_expr);
static PARSER(parse_extended_expr);
static PARSER(parse_term_no_suffix);
@ -803,13 +804,59 @@ PARSER(parse_if) {
const char *tmp = pos;
whitespace(&tmp);
ast_t *else_body = NULL;
const char *else_start = pos;
if (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "else")) {
pos = tmp;
else_body = expect(ctx, start, &pos, parse_opt_indented_block, "I expected a body for this 'else'");
else_body = expect(ctx, else_start, &pos, parse_opt_indented_block, "I expected a body for this 'else'");
}
return NewAST(ctx->file, start, pos, If, .condition=condition, .body=body, .else_body=else_body);
}
PARSER(parse_when) {
// when <expr> (is var : Tag [then] <body>)* [else <body>]
const char *start = pos;
int64_t starting_indent = get_indent(ctx->file, pos);
if (!match_word(&pos, "when"))
return NULL;
ast_t *subject = optional(ctx, &pos, parse_declaration);
if (!subject) subject = expect(ctx, start, &pos, parse_expr,
"I expected to find an expression for this 'when'");
when_clause_t *clauses = NULL;
const char *tmp = pos;
whitespace(&tmp);
while (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "is")) {
pos = tmp;
spaces(&pos);
ast_t *tag_name, *var = expect(ctx, start, &pos, parse_var, "I expected a variable or tag name here");
spaces(&pos);
if (match(&pos, ":")) {
spaces(&pos);
tag_name = optional(ctx, &pos, parse_var);
} else {
tag_name = var;
var = NULL;
}
match_word(&pos, "then"); // optional
ast_t *body = expect(ctx, start, &pos, parse_opt_indented_block, "I expected a body for this 'when' clause");
clauses = new(when_clause_t, .var=var, .tag_name=tag_name, .body=body, .next=clauses);
tmp = pos;
whitespace(&tmp);
}
REVERSE_LIST(clauses);
ast_t *else_body = NULL;
const char *else_start = pos;
if (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "else")) {
pos = tmp;
else_body = expect(ctx, else_start, &pos, parse_opt_indented_block, "I expected a body for this 'else'");
}
return NewAST(ctx->file, start, pos, When, .subject=subject, .clauses=clauses, .else_body=else_body);
}
PARSER(parse_for) {
// for [k,] v in iter [<indent>] body
const char *start = pos;
@ -1343,6 +1390,7 @@ PARSER(parse_extended_expr) {
|| (expr=optional(ctx, &pos, parse_for))
|| (expr=optional(ctx, &pos, parse_while))
|| (expr=optional(ctx, &pos, parse_if))
|| (expr=optional(ctx, &pos, parse_when))
)
return expr;

View File

@ -534,6 +534,86 @@ type_t *get_type(env_t *env, ast_t *ast)
}
}
case When: {
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
if (subject_t->tag != EnumType)
code_err(when->subject, "'when' statements are only for enum types, not %T", subject_t);
tag_t * const tags = Match(subject_t, EnumType)->tags;
typedef struct match_s {
const char *name;
type_t *type;
bool handled;
struct match_s *next;
} match_t;
match_t *matches = NULL;
for (tag_t *tag = tags; tag; tag = tag->next)
matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches);
type_t *overall_t = NULL;
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
const char *tag_name = Match(clause->tag_name, Var)->name;
type_t *tag_type = NULL;
for (match_t *m = matches; m; m = m->next) {
if (streq(m->name, tag_name)) {
if (m->handled)
code_err(clause->tag_name, "This tag was already handled earlier");
m->handled = true;
tag_type = m->type;
break;
}
}
if (!tag_type)
code_err(clause->tag_name, "This is not a valid tag for the type %T", subject_t);
env_t *scope = env;
if (clause->var) {
scope = fresh_scope(scope);
set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type));
}
type_t *clause_type = get_type(scope, clause->body);
type_t *merged = type_or_type(overall_t, clause_type);
if (!merged)
code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
clause_type, overall_t);
overall_t = merged;
}
if (when->else_body) {
bool any_unhandled = false;
for (match_t *m = matches; m; m = m->next) {
if (!m->handled) {
any_unhandled = true;
break;
}
}
if (!any_unhandled)
code_err(when->else_body, "This 'else' block will never run because every tag is handled");
type_t *else_t = get_type(env, when->else_body);
type_t *merged = type_or_type(overall_t, else_t);
if (!merged)
code_err(when->else_body,
"I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
overall_t, else_t);
// return merged;
return Type(VoidType);
} else {
CORD unhandled = CORD_EMPTY;
for (match_t *m = matches; m; m = m->next) {
if (!m->handled)
unhandled = unhandled ? CORD_all(unhandled, ", ", m->name) : m->name;
}
if (unhandled)
code_err(ast, "This 'while' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled));
// return overall_t;
return Type(VoidType);
}
}
case While: case For: return Type(VoidType);
case Unknown: code_err(ast, "I can't figure out the type of: %W", ast);