Implement 'when' statement for matching on enums
This commit is contained in:
parent
d915c5f5a2
commit
2ecd8e11fd
1
Makefile
1
Makefile
@ -1,4 +1,3 @@
|
||||
CC=gcc
|
||||
PREFIX=/usr/local
|
||||
VERSION=0.12.1
|
||||
CCONFIG=-std=c11 -Werror -D_XOPEN_SOURCE=700 -D_POSIX_C_SOURCE=200809L -fPIC -ftrapv -fvisibility=hidden \
|
||||
|
18
ast.c
18
ast.c
@ -20,6 +20,7 @@ static CORD ast_to_cord(ast_t *ast);
|
||||
static CORD ast_list_to_cord(ast_list_t *asts);
|
||||
static CORD type_ast_to_cord(type_ast_t *t);
|
||||
static CORD arg_list_to_cord(arg_ast_t *args);
|
||||
static CORD when_clauses_to_cord(when_clause_t *clauses);
|
||||
static CORD tags_to_cord(tag_ast_t *tags);
|
||||
|
||||
#define TO_CORD(x) _Generic(x, \
|
||||
@ -60,8 +61,17 @@ CORD arg_list_to_cord(arg_ast_t *args) {
|
||||
CORD_sprintf(&c, "%r=%s", c, ast_to_cord(args->default_val));
|
||||
if (args->next) c = CORD_cat(c, ", ");
|
||||
}
|
||||
c = CORD_cat(c, ")");
|
||||
return c;
|
||||
return CORD_cat(c, ")");
|
||||
}
|
||||
|
||||
CORD when_clauses_to_cord(when_clause_t *clauses) {
|
||||
CORD c = "Clauses(";
|
||||
for (; clauses; clauses = clauses->next) {
|
||||
if (clauses->var) c = CORD_all(c, ast_to_cord(clauses->var), ":");
|
||||
c = CORD_all(c, ast_to_cord(clauses->tag_name), "=", ast_to_cord(clauses->body));
|
||||
if (clauses->next) c = CORD_cat(c, ", ");
|
||||
}
|
||||
return CORD_cat(c, ")");
|
||||
}
|
||||
|
||||
CORD tags_to_cord(tag_ast_t *tags) {
|
||||
@ -72,8 +82,7 @@ CORD tags_to_cord(tag_ast_t *tags) {
|
||||
CORD_sprintf(&c, "%r(%r)=%ld", c, arg_list_to_cord(tags->fields), tags->value);
|
||||
if (tags->next) c = CORD_cat(c, ", ");
|
||||
}
|
||||
c = CORD_cat(c, ")");
|
||||
return c;
|
||||
return CORD_cat(c, ")");
|
||||
}
|
||||
|
||||
CORD ast_to_cord(ast_t *ast)
|
||||
@ -117,6 +126,7 @@ CORD ast_to_cord(ast_t *ast)
|
||||
ast_to_cord(data.iter), ast_to_cord(data.body))
|
||||
T(While, "(condition=%r, body=%r)", ast_to_cord(data.condition), ast_to_cord(data.body))
|
||||
T(If, "(condition=%r, body=%r, else=%r)", ast_to_cord(data.condition), ast_to_cord(data.body), ast_to_cord(data.else_body))
|
||||
T(When, "(subject=%r, clauses=%r, else=%r)", ast_to_cord(data.subject), when_clauses_to_cord(data.clauses), ast_to_cord(data.else_body))
|
||||
T(Reduction, "(iter=%r, combination=%r, fallback=%r)", ast_to_cord(data.iter), ast_to_cord(data.combination), ast_to_cord(data.fallback))
|
||||
T(Skip, "(%s)", data.target)
|
||||
T(Stop, "(%s)", data.target)
|
||||
|
12
ast.h
12
ast.h
@ -24,6 +24,11 @@ typedef struct ast_list_s {
|
||||
struct ast_list_s *next;
|
||||
} ast_list_t;
|
||||
|
||||
typedef struct when_clause_s {
|
||||
ast_t *var, *tag_name, *body;
|
||||
struct when_clause_s *next;
|
||||
} when_clause_t;
|
||||
|
||||
typedef struct arg_ast_s {
|
||||
const char *name;
|
||||
type_ast_t *type;
|
||||
@ -94,7 +99,7 @@ typedef enum {
|
||||
FunctionDef, Lambda,
|
||||
FunctionCall, KeywordArg,
|
||||
Block,
|
||||
For, While, If,
|
||||
For, While, If, When,
|
||||
Reduction,
|
||||
Skip, Stop, Pass,
|
||||
Return,
|
||||
@ -199,6 +204,11 @@ struct ast_s {
|
||||
struct {
|
||||
ast_t *condition, *body, *else_body;
|
||||
} If;
|
||||
struct {
|
||||
ast_t *subject;
|
||||
when_clause_t *clauses;
|
||||
ast_t *else_body;
|
||||
} When;
|
||||
struct {
|
||||
ast_t *iter, *combination, *fallback;
|
||||
} Reduction;
|
||||
|
42
compile.c
42
compile.c
@ -52,7 +52,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
|
||||
{
|
||||
CORD stmt;
|
||||
switch (ast->tag) {
|
||||
case If: case For: case While: case FunctionDef: case Return: case StructDef: case EnumDef:
|
||||
case If: case When: case For: case While: case FunctionDef: case Return: case StructDef: case EnumDef:
|
||||
case Declare: case Assign: case UpdateAssign: case DocTest:
|
||||
stmt = compile(env, ast);
|
||||
break;
|
||||
@ -462,8 +462,8 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
code = CORD_all(code, "NULL");
|
||||
|
||||
for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
|
||||
auto entry = Match(entry->ast, TableEntry);
|
||||
code = CORD_all(code, ",\n\t{", compile(env, entry->key), ", ", compile(env, entry->value), "}");
|
||||
auto e = Match(entry->ast, TableEntry);
|
||||
code = CORD_all(code, ",\n\t{", compile(env, e->key), ", ", compile(env, e->value), "}");
|
||||
}
|
||||
return CORD_cat(code, ")");
|
||||
|
||||
@ -516,7 +516,6 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
auto kwarg = Match(ast, KeywordArg);
|
||||
return CORD_asprintf(".%s=%r", kwarg->name, compile(env, kwarg->arg));
|
||||
}
|
||||
// KeywordArg,
|
||||
case If: {
|
||||
auto if_ = Match(ast, If);
|
||||
CORD code;
|
||||
@ -525,6 +524,39 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
CORD_sprintf(&code, "%r\nelse %r", code, compile(env, if_->else_body));
|
||||
return code;
|
||||
}
|
||||
case When: {
|
||||
auto when = Match(ast, When);
|
||||
type_t *subject_t = get_type(env, when->subject);
|
||||
auto enum_t = Match(subject_t, EnumType);
|
||||
CORD code = CORD_all("{ ", compile_type(subject_t), " $subject = ", compile(env, when->subject), ";\n"
|
||||
"switch ($subject.$tag) {");
|
||||
type_t *result_t = get_type(env, ast);
|
||||
(void)result_t;
|
||||
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
|
||||
const char *clause_tag_name = Match(clause->tag_name, Var)->name;
|
||||
code = CORD_all(code, "case $tag$", enum_t->name, "$", clause_tag_name, ": {\n");
|
||||
type_t *tag_type = NULL;
|
||||
for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
|
||||
if (streq(tag->name, clause_tag_name)) {
|
||||
tag_type = tag->type;
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(tag_type);
|
||||
env_t *scope = env;
|
||||
if (clause->var) {
|
||||
code = CORD_all(code, compile_type(tag_type), " ", compile(env, clause->var), " = $subject.", clause_tag_name, ";\n");
|
||||
scope = fresh_scope(env);
|
||||
set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type));
|
||||
}
|
||||
code = CORD_all(code, compile(scope, clause->body), "\nbreak;\n}\n");
|
||||
}
|
||||
if (when->else_body) {
|
||||
code = CORD_all(code, "default: {\n", compile(env, when->else_body), "\nbreak;\n}");
|
||||
}
|
||||
code = CORD_all(code, "\n}\n}");
|
||||
return code;
|
||||
}
|
||||
case While: {
|
||||
auto while_ = Match(ast, While);
|
||||
return CORD_asprintf("while (%r) %r", compile(env, while_->condition), compile(env, while_->body));
|
||||
@ -749,7 +781,7 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
(int64_t)(test->expr->end - test->expr->file->text));
|
||||
} else {
|
||||
return CORD_asprintf(
|
||||
"{\n%r $expr = %r;\n"
|
||||
"{ // Test:\n%r $expr = %r;\n"
|
||||
"__doctest(&$expr, %r, %r, %r, %ld, %ld);\n"
|
||||
"}",
|
||||
compile_type(expr_t),
|
||||
|
52
parse.c
52
parse.c
@ -43,7 +43,7 @@ int op_tightness[] = {
|
||||
#define MAX_TIGHTNESS 9
|
||||
|
||||
static const char *keywords[] = {
|
||||
"yes", "xor", "while", "use", "then", "struct", "stop", "skip", "return",
|
||||
"yes", "xor", "while", "when", "use", "then", "struct", "stop", "skip", "return",
|
||||
"or", "not", "no", "mod1", "mod", "in", "if", "func", "for", "extern",
|
||||
"enum", "else", "do", "and", "_mix_", "_min_", "_max_",
|
||||
NULL,
|
||||
@ -71,6 +71,7 @@ static arg_ast_t *parse_args(parse_ctx_t *ctx, const char **pos, bool allow_unna
|
||||
static PARSER(parse_for);
|
||||
static PARSER(parse_while);
|
||||
static PARSER(parse_if);
|
||||
static PARSER(parse_when);
|
||||
static PARSER(parse_expr);
|
||||
static PARSER(parse_extended_expr);
|
||||
static PARSER(parse_term_no_suffix);
|
||||
@ -803,13 +804,59 @@ PARSER(parse_if) {
|
||||
const char *tmp = pos;
|
||||
whitespace(&tmp);
|
||||
ast_t *else_body = NULL;
|
||||
const char *else_start = pos;
|
||||
if (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "else")) {
|
||||
pos = tmp;
|
||||
else_body = expect(ctx, start, &pos, parse_opt_indented_block, "I expected a body for this 'else'");
|
||||
else_body = expect(ctx, else_start, &pos, parse_opt_indented_block, "I expected a body for this 'else'");
|
||||
}
|
||||
return NewAST(ctx->file, start, pos, If, .condition=condition, .body=body, .else_body=else_body);
|
||||
}
|
||||
|
||||
PARSER(parse_when) {
|
||||
// when <expr> (is var : Tag [then] <body>)* [else <body>]
|
||||
const char *start = pos;
|
||||
int64_t starting_indent = get_indent(ctx->file, pos);
|
||||
|
||||
if (!match_word(&pos, "when"))
|
||||
return NULL;
|
||||
|
||||
ast_t *subject = optional(ctx, &pos, parse_declaration);
|
||||
if (!subject) subject = expect(ctx, start, &pos, parse_expr,
|
||||
"I expected to find an expression for this 'when'");
|
||||
|
||||
when_clause_t *clauses = NULL;
|
||||
const char *tmp = pos;
|
||||
whitespace(&tmp);
|
||||
while (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "is")) {
|
||||
pos = tmp;
|
||||
spaces(&pos);
|
||||
ast_t *tag_name, *var = expect(ctx, start, &pos, parse_var, "I expected a variable or tag name here");
|
||||
spaces(&pos);
|
||||
if (match(&pos, ":")) {
|
||||
spaces(&pos);
|
||||
tag_name = optional(ctx, &pos, parse_var);
|
||||
} else {
|
||||
tag_name = var;
|
||||
var = NULL;
|
||||
}
|
||||
|
||||
match_word(&pos, "then"); // optional
|
||||
ast_t *body = expect(ctx, start, &pos, parse_opt_indented_block, "I expected a body for this 'when' clause");
|
||||
clauses = new(when_clause_t, .var=var, .tag_name=tag_name, .body=body, .next=clauses);
|
||||
tmp = pos;
|
||||
whitespace(&tmp);
|
||||
}
|
||||
REVERSE_LIST(clauses);
|
||||
|
||||
ast_t *else_body = NULL;
|
||||
const char *else_start = pos;
|
||||
if (get_indent(ctx->file, tmp) == starting_indent && match_word(&tmp, "else")) {
|
||||
pos = tmp;
|
||||
else_body = expect(ctx, else_start, &pos, parse_opt_indented_block, "I expected a body for this 'else'");
|
||||
}
|
||||
return NewAST(ctx->file, start, pos, When, .subject=subject, .clauses=clauses, .else_body=else_body);
|
||||
}
|
||||
|
||||
PARSER(parse_for) {
|
||||
// for [k,] v in iter [<indent>] body
|
||||
const char *start = pos;
|
||||
@ -1343,6 +1390,7 @@ PARSER(parse_extended_expr) {
|
||||
|| (expr=optional(ctx, &pos, parse_for))
|
||||
|| (expr=optional(ctx, &pos, parse_while))
|
||||
|| (expr=optional(ctx, &pos, parse_if))
|
||||
|| (expr=optional(ctx, &pos, parse_when))
|
||||
)
|
||||
return expr;
|
||||
|
||||
|
80
typecheck.c
80
typecheck.c
@ -534,6 +534,86 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
}
|
||||
}
|
||||
|
||||
case When: {
|
||||
auto when = Match(ast, When);
|
||||
type_t *subject_t = get_type(env, when->subject);
|
||||
if (subject_t->tag != EnumType)
|
||||
code_err(when->subject, "'when' statements are only for enum types, not %T", subject_t);
|
||||
|
||||
tag_t * const tags = Match(subject_t, EnumType)->tags;
|
||||
|
||||
typedef struct match_s {
|
||||
const char *name;
|
||||
type_t *type;
|
||||
bool handled;
|
||||
struct match_s *next;
|
||||
} match_t;
|
||||
match_t *matches = NULL;
|
||||
for (tag_t *tag = tags; tag; tag = tag->next)
|
||||
matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches);
|
||||
|
||||
type_t *overall_t = NULL;
|
||||
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
|
||||
const char *tag_name = Match(clause->tag_name, Var)->name;
|
||||
type_t *tag_type = NULL;
|
||||
for (match_t *m = matches; m; m = m->next) {
|
||||
if (streq(m->name, tag_name)) {
|
||||
if (m->handled)
|
||||
code_err(clause->tag_name, "This tag was already handled earlier");
|
||||
m->handled = true;
|
||||
tag_type = m->type;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!tag_type)
|
||||
code_err(clause->tag_name, "This is not a valid tag for the type %T", subject_t);
|
||||
|
||||
env_t *scope = env;
|
||||
if (clause->var) {
|
||||
scope = fresh_scope(scope);
|
||||
set_binding(scope, Match(clause->var, Var)->name, new(binding_t, .type=tag_type));
|
||||
}
|
||||
type_t *clause_type = get_type(scope, clause->body);
|
||||
type_t *merged = type_or_type(overall_t, clause_type);
|
||||
if (!merged)
|
||||
code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
|
||||
clause_type, overall_t);
|
||||
overall_t = merged;
|
||||
}
|
||||
|
||||
if (when->else_body) {
|
||||
bool any_unhandled = false;
|
||||
for (match_t *m = matches; m; m = m->next) {
|
||||
if (!m->handled) {
|
||||
any_unhandled = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!any_unhandled)
|
||||
code_err(when->else_body, "This 'else' block will never run because every tag is handled");
|
||||
|
||||
type_t *else_t = get_type(env, when->else_body);
|
||||
type_t *merged = type_or_type(overall_t, else_t);
|
||||
if (!merged)
|
||||
code_err(when->else_body,
|
||||
"I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.",
|
||||
overall_t, else_t);
|
||||
// return merged;
|
||||
return Type(VoidType);
|
||||
} else {
|
||||
CORD unhandled = CORD_EMPTY;
|
||||
for (match_t *m = matches; m; m = m->next) {
|
||||
if (!m->handled)
|
||||
unhandled = unhandled ? CORD_all(unhandled, ", ", m->name) : m->name;
|
||||
}
|
||||
if (unhandled)
|
||||
code_err(ast, "This 'while' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled));
|
||||
// return overall_t;
|
||||
return Type(VoidType);
|
||||
}
|
||||
}
|
||||
|
||||
case While: case For: return Type(VoidType);
|
||||
|
||||
case Unknown: code_err(ast, "I can't figure out the type of: %W", ast);
|
||||
|
Loading…
Reference in New Issue
Block a user