Support 'when' for literal values with equality checking
This commit is contained in:
parent
09423f6d42
commit
73e559fbe4
2
ast.c
2
ast.c
@ -71,7 +71,7 @@ CORD arg_list_to_xml(arg_ast_t *args) {
|
||||
CORD when_clauses_to_xml(when_clause_t *clauses) {
|
||||
CORD c = CORD_EMPTY;
|
||||
for (; clauses; clauses = clauses->next) {
|
||||
c = CORD_all(c, "<case tag=\"", ast_to_xml(clauses->tag_name), "\">", ast_list_to_xml(clauses->args), ast_to_xml(clauses->body), "</case>");
|
||||
c = CORD_all(c, "<case>", ast_to_xml(clauses->pattern), ast_to_xml(clauses->body), "</case>");
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
4
ast.h
4
ast.h
@ -52,9 +52,7 @@ typedef struct arg_ast_s {
|
||||
} arg_ast_t;
|
||||
|
||||
typedef struct when_clause_s {
|
||||
ast_t *tag_name;
|
||||
ast_list_t *args;
|
||||
ast_t *body;
|
||||
ast_t *pattern, *body;
|
||||
struct when_clause_s *next;
|
||||
} when_clause_t;
|
||||
|
||||
|
120
compile.c
120
compile.c
@ -360,9 +360,27 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t
|
||||
add_closed_vars(closed_vars, enclosing_scope, env, when->subject);
|
||||
type_t *subject_t = get_type(env, when->subject);
|
||||
|
||||
if (subject_t->tag != EnumType) {
|
||||
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
|
||||
add_closed_vars(closed_vars, enclosing_scope, env, clause->pattern);
|
||||
add_closed_vars(closed_vars, enclosing_scope, env, clause->body);
|
||||
}
|
||||
|
||||
if (when->else_body)
|
||||
add_closed_vars(closed_vars, enclosing_scope, env, when->else_body);
|
||||
return;
|
||||
}
|
||||
|
||||
auto enum_t = Match(subject_t, EnumType);
|
||||
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
|
||||
const char *clause_tag_name = Match(clause->tag_name, Var)->name;
|
||||
const char *clause_tag_name;
|
||||
if (clause->pattern->tag == Var)
|
||||
clause_tag_name = Match(clause->pattern, Var)->name;
|
||||
else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var)
|
||||
clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name;
|
||||
else
|
||||
code_err(clause->pattern, "This is not a valid pattern for a %T enum", subject_t);
|
||||
|
||||
type_t *tag_type = NULL;
|
||||
for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
|
||||
if (streq(tag->name, clause_tag_name)) {
|
||||
@ -371,26 +389,7 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t
|
||||
}
|
||||
}
|
||||
assert(tag_type);
|
||||
env_t *scope = env;
|
||||
|
||||
auto tag_struct = Match(tag_type, StructType);
|
||||
if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
|
||||
scope = fresh_scope(scope);
|
||||
set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY);
|
||||
} else if (clause->args) {
|
||||
scope = fresh_scope(scope);
|
||||
ast_list_t *var = clause->args;
|
||||
arg_t *field = tag_struct->fields;
|
||||
while (var || field) {
|
||||
if (!var)
|
||||
code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name);
|
||||
if (!field)
|
||||
code_err(var->ast, "This is one more field than %T has", subject_t);
|
||||
set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY);
|
||||
var = var->next;
|
||||
field = field->next;
|
||||
}
|
||||
}
|
||||
env_t *scope = when_clause_scope(env, subject_t, clause);
|
||||
add_closed_vars(closed_vars, enclosing_scope, scope, clause->body);
|
||||
}
|
||||
if (when->else_body)
|
||||
@ -752,11 +751,44 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
|
||||
auto when = Match(ast, When);
|
||||
type_t *subject_t = get_type(env, when->subject);
|
||||
|
||||
if (subject_t->tag != EnumType) {
|
||||
CORD prefix = CORD_EMPTY, suffix = CORD_EMPTY;
|
||||
ast_t *subject = when->subject;
|
||||
if (!is_idempotent(when->subject)) {
|
||||
prefix = CORD_all("{\n", compile_declaration(subject_t, "_when_subject"), " = ", compile(env, subject), ";\n");
|
||||
suffix = "}\n";
|
||||
subject = WrapAST(subject, InlineCCode, .type=subject_t, .code="_when_subject");
|
||||
}
|
||||
|
||||
CORD code = CORD_EMPTY;
|
||||
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
|
||||
ast_t *comparison = WrapAST(clause->pattern, BinaryOp, .lhs=subject, .op=BINOP_EQ, .rhs=clause->pattern);
|
||||
if (code != CORD_EMPTY)
|
||||
code = CORD_all(code, "else ");
|
||||
code = CORD_all(code, "if (", compile(env, comparison), ")", compile_statement(env, clause->body));
|
||||
}
|
||||
if (when->else_body)
|
||||
code = CORD_all(code, "else ", compile_statement(env, when->else_body));
|
||||
code = CORD_all(prefix, code, suffix);
|
||||
return code;
|
||||
}
|
||||
|
||||
auto enum_t = Match(subject_t, EnumType);
|
||||
CORD code = CORD_all("{ ", compile_type(subject_t), " subject = ", compile(env, when->subject), ";\n"
|
||||
"switch (subject.tag) {");
|
||||
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
|
||||
const char *clause_tag_name = Match(clause->tag_name, Var)->name;
|
||||
if (clause->pattern->tag == Var) {
|
||||
const char *clause_tag_name = Match(clause->pattern, Var)->name;
|
||||
code = CORD_all(code, "case ", namespace_prefix(enum_t->env, enum_t->env->namespace), "tag$", clause_tag_name, ": {\n",
|
||||
compile_statement(env, clause->body),
|
||||
"}\n");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var)
|
||||
code_err(clause->pattern, "This is not a valid pattern for a %T enum type", subject_t);
|
||||
|
||||
const char *clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name;
|
||||
code = CORD_all(code, "case ", namespace_prefix(enum_t->env, enum_t->env->namespace), "tag$", clause_tag_name, ": {\n");
|
||||
type_t *tag_type = NULL;
|
||||
for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
|
||||
@ -769,22 +801,32 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
|
||||
env_t *scope = env;
|
||||
|
||||
auto tag_struct = Match(tag_type, StructType);
|
||||
if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
|
||||
code = CORD_all(code, compile_declaration(tag_type, compile(env, clause->args->ast)), " = subject.$", clause_tag_name, ";\n");
|
||||
arg_ast_t *args = Match(clause->pattern, FunctionCall)->args;
|
||||
if (args && !args->next && tag_struct->fields && tag_struct->fields->next) {
|
||||
if (args->value->tag != Var)
|
||||
code_err(args->value, "This is not a valid variable to bind to");
|
||||
const char *var_name = Match(args->value, Var)->name;
|
||||
if (!streq(var_name, "_")) {
|
||||
code = CORD_all(code, compile_declaration(tag_type, compile(env, args->value)), " = subject.$", clause_tag_name, ";\n");
|
||||
scope = fresh_scope(scope);
|
||||
set_binding(scope, Match(args->value, Var)->name, tag_type, CORD_EMPTY);
|
||||
}
|
||||
} else if (args) {
|
||||
scope = fresh_scope(scope);
|
||||
set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY);
|
||||
} else if (clause->args) {
|
||||
scope = fresh_scope(scope);
|
||||
ast_list_t *var = clause->args;
|
||||
arg_t *field = tag_struct->fields;
|
||||
while (var || field) {
|
||||
if (!var)
|
||||
code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name);
|
||||
for (arg_ast_t *arg = args; arg || field; arg = arg->next) {
|
||||
if (!arg)
|
||||
code_err(ast, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name);
|
||||
if (!field)
|
||||
code_err(var->ast, "This is one more field than %T has", subject_t);
|
||||
code = CORD_all(code, compile_declaration(field->type, compile(env, var->ast)), " = subject.$", clause_tag_name, ".$", field->name, ";\n");
|
||||
set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY);
|
||||
var = var->next;
|
||||
code_err(arg->value, "This is one more field than %T has", subject_t);
|
||||
if (arg->name)
|
||||
code_err(arg->value, "Named arguments are not currently supported");
|
||||
|
||||
const char *var_name = Match(arg->value, Var)->name;
|
||||
if (!streq(var_name, "_")) {
|
||||
code = CORD_all(code, compile_declaration(field->type, compile(env, arg->value)), " = subject.$", clause_tag_name, ".$", field->name, ";\n");
|
||||
set_binding(scope, Match(arg->value, Var)->name, field->type, CORD_EMPTY);
|
||||
}
|
||||
field = field->next;
|
||||
}
|
||||
}
|
||||
@ -1160,7 +1202,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
|
||||
const char *target = Match(ast, Skip)->target;
|
||||
for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
|
||||
bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0;
|
||||
for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next)
|
||||
for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var ? var->next : NULL)
|
||||
matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0);
|
||||
|
||||
if (matched) {
|
||||
@ -1189,7 +1231,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
|
||||
const char *target = Match(ast, Stop)->target;
|
||||
for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
|
||||
bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0;
|
||||
for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next)
|
||||
for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var ? var->next : var)
|
||||
matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0);
|
||||
|
||||
if (matched) {
|
||||
@ -3443,12 +3485,12 @@ CORD compile(env_t *env, ast_t *ast)
|
||||
for (when_clause_t *clause = original->clauses; clause; clause = clause->next) {
|
||||
type_t *clause_type = get_clause_type(env, subject_t, clause);
|
||||
if (clause_type->tag == AbortType || clause_type->tag == ReturnType) {
|
||||
new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=clause->body, .next=new_clauses);
|
||||
new_clauses = new(when_clause_t, .pattern=clause->pattern, .body=clause->body, .next=new_clauses);
|
||||
} else {
|
||||
ast_t *assign = WrapAST(clause->body, Assign,
|
||||
.targets=new(ast_list_t, .ast=when_var),
|
||||
.values=new(ast_list_t, .ast=clause->body));
|
||||
new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=assign, .next=new_clauses);
|
||||
new_clauses = new(when_clause_t, .pattern=clause->pattern, .body=assign, .next=new_clauses);
|
||||
}
|
||||
}
|
||||
REVERSE_LIST(new_clauses);
|
||||
|
30
parse.c
30
parse.c
@ -1115,37 +1115,13 @@ PARSER(parse_when) {
|
||||
while (get_indent(ctx, tmp) == starting_indent && match_word(&tmp, "is")) {
|
||||
pos = tmp;
|
||||
spaces(&pos);
|
||||
ast_t *tag_name;
|
||||
ast_list_t *args;
|
||||
if (match(&pos, "@")) {
|
||||
tag_name = NewAST(ctx->file, pos-1, pos, Var, .name="@");
|
||||
spaces(&pos);
|
||||
ast_t *arg = optional(ctx, &pos, parse_var);
|
||||
args = arg ? new(ast_list_t, .ast=arg) : NULL;
|
||||
} else {
|
||||
tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here");
|
||||
spaces(&pos);
|
||||
args = NULL;
|
||||
if (match(&pos, "(")) {
|
||||
for (;;) {
|
||||
whitespace(&pos);
|
||||
ast_t *arg = optional(ctx, &pos, parse_var);
|
||||
if (!arg) break;
|
||||
args = new(ast_list_t, .ast=arg, .next=args);
|
||||
whitespace(&pos);
|
||||
if (!match(&pos, ",")) break;
|
||||
}
|
||||
whitespace(&pos);
|
||||
expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments");
|
||||
REVERSE_LIST(args);
|
||||
}
|
||||
}
|
||||
|
||||
ast_t *pattern = expect(ctx, start, &pos, parse_expr, "I expected a pattern to match here");
|
||||
spaces(&pos);
|
||||
tmp = pos;
|
||||
if (!match(&tmp, ":"))
|
||||
parser_err(ctx, tmp, tmp, "I expected a colon ':' after this clause");
|
||||
ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'when' clause");
|
||||
clauses = new(when_clause_t, .tag_name=tag_name, .args=args, .body=body, .next=clauses);
|
||||
clauses = new(when_clause_t, .pattern=pattern, .body=body, .next=clauses);
|
||||
tmp = pos;
|
||||
whitespace(&tmp);
|
||||
}
|
||||
|
89
typecheck.c
89
typecheck.c
@ -484,13 +484,18 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name)
|
||||
return b->type;
|
||||
}
|
||||
|
||||
type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause)
|
||||
env_t *when_clause_scope(env_t *env, type_t *subject_t, when_clause_t *clause)
|
||||
{
|
||||
assert(subject_t->tag == EnumType);
|
||||
tag_t * const tags = Match(subject_t, EnumType)->tags;
|
||||
if (clause->pattern->tag == Var || subject_t->tag != EnumType)
|
||||
return env;
|
||||
|
||||
const char *tag_name = Match(clause->tag_name, Var)->name;
|
||||
if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var)
|
||||
code_err(clause->pattern, "I only support variables and constructors for pattern matching %T types in a 'when' block", subject_t);
|
||||
|
||||
auto fn = Match(clause->pattern, FunctionCall);
|
||||
const char *tag_name = Match(fn->fn, Var)->name;
|
||||
type_t *tag_type = NULL;
|
||||
tag_t * const tags = Match(subject_t, EnumType)->tags;
|
||||
for (tag_t *tag = tags; tag; tag = tag->next) {
|
||||
if (streq(tag->name, tag_name)) {
|
||||
tag_type = tag->type;
|
||||
@ -499,29 +504,38 @@ type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause)
|
||||
}
|
||||
|
||||
if (!tag_type)
|
||||
code_err(clause->tag_name, "There is no tag '%s' for the type %T", tag_name, subject_t);
|
||||
code_err(clause->pattern, "There is no tag '%s' for the type %T", tag_name, subject_t);
|
||||
|
||||
// Don't return early so we validate the tags
|
||||
if (!clause->args)
|
||||
return get_type(env, clause->body);
|
||||
if (!fn->args)
|
||||
return env;
|
||||
|
||||
env_t *scope = fresh_scope(env);
|
||||
auto tag_struct = Match(tag_type, StructType);
|
||||
if (!clause->args->next && tag_struct->fields && tag_struct->fields->next) {
|
||||
set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY);
|
||||
} else {
|
||||
ast_list_t *var = clause->args;
|
||||
arg_t *field = tag_struct->fields;
|
||||
while (var || field) {
|
||||
if (!var)
|
||||
code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
|
||||
if (!field)
|
||||
code_err(var->ast, "This is one more field than %T has", subject_t);
|
||||
set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY);
|
||||
var = var->next;
|
||||
field = field->next;
|
||||
}
|
||||
if (fn->args && !fn->args->next && tag_struct->fields && tag_struct->fields->next) {
|
||||
if (fn->args->value->tag != Var)
|
||||
code_err(fn->args->value, "I expected a variable here");
|
||||
set_binding(scope, Match(fn->args->value, Var)->name, tag_type, CORD_EMPTY);
|
||||
return scope;
|
||||
}
|
||||
|
||||
arg_t *field = tag_struct->fields;
|
||||
for (arg_ast_t *var = fn->args; var || field; var = var ? var->next : var) {
|
||||
if (!var)
|
||||
code_err(clause->pattern, "The field %T.%s.%s wasn't accounted for", subject_t, tag_name, field->name);
|
||||
if (!field)
|
||||
code_err(var->value, "This is one more field than %T has", subject_t);
|
||||
if (var->value->tag != Var)
|
||||
code_err(var->value, "I expected this to be a plain variable so I could bind it to a value");
|
||||
if (!streq(Match(var->value, Var)->name, "_"))
|
||||
set_binding(scope, Match(var->value, Var)->name, field->type, CORD_EMPTY);
|
||||
field = field->next;
|
||||
}
|
||||
return scope;
|
||||
}
|
||||
|
||||
type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause)
|
||||
{
|
||||
env_t *scope = when_clause_scope(env, subject_t, clause);
|
||||
return get_type(scope, clause->body);
|
||||
}
|
||||
|
||||
@ -1258,10 +1272,19 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
case When: {
|
||||
auto when = Match(ast, When);
|
||||
type_t *subject_t = get_type(env, when->subject);
|
||||
type_t *overall_t = NULL;
|
||||
if (subject_t->tag != EnumType)
|
||||
code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
|
||||
if (subject_t->tag != EnumType) {
|
||||
type_t *t = NULL;
|
||||
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
|
||||
t = type_or_type(t, get_type(env, clause->body));
|
||||
}
|
||||
if (when->else_body)
|
||||
t = type_or_type(t, get_type(env, when->else_body));
|
||||
else if (t->tag != OptionalType)
|
||||
t = Type(OptionalType, .type=t);
|
||||
return t;
|
||||
}
|
||||
|
||||
type_t *overall_t = NULL;
|
||||
tag_t * const tags = Match(subject_t, EnumType)->tags;
|
||||
|
||||
typedef struct match_s {
|
||||
@ -1274,12 +1297,19 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
matches = new(match_t, .tag=tag, .handled=false, .next=matches);
|
||||
|
||||
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
|
||||
const char *tag_name = Match(clause->tag_name, Var)->name;
|
||||
const char *tag_name;
|
||||
if (clause->pattern->tag == Var)
|
||||
tag_name = Match(clause->pattern, Var)->name;
|
||||
else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var)
|
||||
tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name;
|
||||
else
|
||||
code_err(clause->pattern, "This is not a valid pattern for a %T enum", subject_t);
|
||||
|
||||
CORD valid_tags = CORD_EMPTY;
|
||||
for (match_t *m = matches; m; m = m->next) {
|
||||
if (streq(m->tag->name, tag_name)) {
|
||||
if (m->handled)
|
||||
code_err(clause->tag_name, "This tag was already handled earlier");
|
||||
code_err(clause->pattern, "This tag was already handled earlier");
|
||||
m->handled = true;
|
||||
goto found_matching_tag;
|
||||
}
|
||||
@ -1287,13 +1317,14 @@ type_t *get_type(env_t *env, ast_t *ast)
|
||||
valid_tags = CORD_cat(valid_tags, m->tag->name);
|
||||
}
|
||||
|
||||
code_err(clause->tag_name, "There is no tag '%s' for the type %T (valid tags: %s)",
|
||||
code_err(clause->pattern, "There is no tag '%s' for the type %T (valid tags: %s)",
|
||||
tag_name, subject_t, CORD_to_char_star(valid_tags));
|
||||
found_matching_tag:;
|
||||
}
|
||||
|
||||
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
|
||||
type_t *clause_type = get_clause_type(env, subject_t, clause);
|
||||
env_t *clause_scope = when_clause_scope(env, subject_t, clause);
|
||||
type_t *clause_type = get_type(clause_scope, clause->body);
|
||||
type_t *merged = type_or_type(overall_t, clause_type);
|
||||
if (!merged)
|
||||
code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T",
|
||||
|
@ -21,6 +21,7 @@ PUREFUNC bool is_discardable(env_t *env, ast_t *ast);
|
||||
type_t *get_function_def_type(env_t *env, ast_t *ast);
|
||||
type_t *get_arg_type(env_t *env, arg_t *arg);
|
||||
type_t *get_arg_ast_type(env_t *env, arg_ast_t *arg);
|
||||
env_t *when_clause_scope(env_t *env, type_t *subject_t, when_clause_t *clause);
|
||||
type_t *get_clause_type(env_t *env, type_t *subject_t, when_clause_t *clause);
|
||||
PUREFUNC bool can_be_mutated(env_t *env, ast_t *ast);
|
||||
type_t *parse_type_string(env_t *env, const char *str);
|
||||
|
Loading…
Reference in New Issue
Block a user