aboutsummaryrefslogtreecommitdiff
path: root/compile.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2025-03-06 18:37:08 -0500
committerBruce Hill <bruce@bruce-hill.com>2025-03-06 18:37:08 -0500
commit73e559fbe4182828742ac1b1d108bcdc42bc46d6 (patch)
tree22c4d6245114cf3ba72701dafe18b84ce72e3f66 /compile.c
parent09423f6d42d86c62beefa4607fba41e3698f1850 (diff)
Support 'when' for literal values with equality checking
Diffstat (limited to 'compile.c')
-rw-r--r--compile.c120
1 files changed, 81 insertions, 39 deletions
diff --git a/compile.c b/compile.c
index aa8f6456..30e20e5b 100644
--- a/compile.c
+++ b/compile.c
@@ -360,9 +360,27 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t
add_closed_vars(closed_vars, enclosing_scope, env, when->subject);
type_t *subject_t = get_type(env, when->subject);
+ if (subject_t->tag != EnumType) {
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ add_closed_vars(closed_vars, enclosing_scope, env, clause->pattern);
+ add_closed_vars(closed_vars, enclosing_scope, env, clause->body);
+ }
+
+ if (when->else_body)
+ add_closed_vars(closed_vars, enclosing_scope, env, when->else_body);
+ return;
+ }
+
auto enum_t = Match(subject_t, EnumType);
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- const char *clause_tag_name = Match(clause->tag_name, Var)->name;
+ const char *clause_tag_name;
+ if (clause->pattern->tag == Var)
+ clause_tag_name = Match(clause->pattern, Var)->name;
+ else if (clause->pattern->tag == FunctionCall && Match(clause->pattern, FunctionCall)->fn->tag == Var)
+ clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name;
+ else
+ code_err(clause->pattern, "This is not a valid pattern for a %T enum", subject_t);
+
type_t *tag_type = NULL;
for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
if (streq(tag->name, clause_tag_name)) {
@@ -371,26 +389,7 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t
}
}
assert(tag_type);
- env_t *scope = env;
-
- auto tag_struct = Match(tag_type, StructType);
- if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
- scope = fresh_scope(scope);
- set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY);
- } else if (clause->args) {
- scope = fresh_scope(scope);
- ast_list_t *var = clause->args;
- arg_t *field = tag_struct->fields;
- while (var || field) {
- if (!var)
- code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name);
- if (!field)
- code_err(var->ast, "This is one more field than %T has", subject_t);
- set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY);
- var = var->next;
- field = field->next;
- }
- }
+ env_t *scope = when_clause_scope(env, subject_t, clause);
add_closed_vars(closed_vars, enclosing_scope, scope, clause->body);
}
if (when->else_body)
@@ -752,11 +751,44 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
+ if (subject_t->tag != EnumType) {
+ CORD prefix = CORD_EMPTY, suffix = CORD_EMPTY;
+ ast_t *subject = when->subject;
+ if (!is_idempotent(when->subject)) {
+ prefix = CORD_all("{\n", compile_declaration(subject_t, "_when_subject"), " = ", compile(env, subject), ";\n");
+ suffix = "}\n";
+ subject = WrapAST(subject, InlineCCode, .type=subject_t, .code="_when_subject");
+ }
+
+ CORD code = CORD_EMPTY;
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ ast_t *comparison = WrapAST(clause->pattern, BinaryOp, .lhs=subject, .op=BINOP_EQ, .rhs=clause->pattern);
+ if (code != CORD_EMPTY)
+ code = CORD_all(code, "else ");
+ code = CORD_all(code, "if (", compile(env, comparison), ")", compile_statement(env, clause->body));
+ }
+ if (when->else_body)
+ code = CORD_all(code, "else ", compile_statement(env, when->else_body));
+ code = CORD_all(prefix, code, suffix);
+ return code;
+ }
+
auto enum_t = Match(subject_t, EnumType);
CORD code = CORD_all("{ ", compile_type(subject_t), " subject = ", compile(env, when->subject), ";\n"
"switch (subject.tag) {");
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- const char *clause_tag_name = Match(clause->tag_name, Var)->name;
+ if (clause->pattern->tag == Var) {
+ const char *clause_tag_name = Match(clause->pattern, Var)->name;
+ code = CORD_all(code, "case ", namespace_prefix(enum_t->env, enum_t->env->namespace), "tag$", clause_tag_name, ": {\n",
+ compile_statement(env, clause->body),
+ "}\n");
+ continue;
+ }
+
+ if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var)
+ code_err(clause->pattern, "This is not a valid pattern for a %T enum type", subject_t);
+
+ const char *clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name;
code = CORD_all(code, "case ", namespace_prefix(enum_t->env, enum_t->env->namespace), "tag$", clause_tag_name, ": {\n");
type_t *tag_type = NULL;
for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
@@ -769,22 +801,32 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
env_t *scope = env;
auto tag_struct = Match(tag_type, StructType);
- if (clause->args && !clause->args->next && tag_struct->fields && tag_struct->fields->next) {
- code = CORD_all(code, compile_declaration(tag_type, compile(env, clause->args->ast)), " = subject.$", clause_tag_name, ";\n");
- scope = fresh_scope(scope);
- set_binding(scope, Match(clause->args->ast, Var)->name, tag_type, CORD_EMPTY);
- } else if (clause->args) {
+ arg_ast_t *args = Match(clause->pattern, FunctionCall)->args;
+ if (args && !args->next && tag_struct->fields && tag_struct->fields->next) {
+ if (args->value->tag != Var)
+ code_err(args->value, "This is not a valid variable to bind to");
+ const char *var_name = Match(args->value, Var)->name;
+ if (!streq(var_name, "_")) {
+ code = CORD_all(code, compile_declaration(tag_type, compile(env, args->value)), " = subject.$", clause_tag_name, ";\n");
+ scope = fresh_scope(scope);
+ set_binding(scope, Match(args->value, Var)->name, tag_type, CORD_EMPTY);
+ }
+ } else if (args) {
scope = fresh_scope(scope);
- ast_list_t *var = clause->args;
arg_t *field = tag_struct->fields;
- while (var || field) {
- if (!var)
- code_err(clause->tag_name, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name);
+ for (arg_ast_t *arg = args; arg || field; arg = arg->next) {
+ if (!arg)
+ code_err(ast, "The field %T.%s.%s wasn't accounted for", subject_t, clause_tag_name, field->name);
if (!field)
- code_err(var->ast, "This is one more field than %T has", subject_t);
- code = CORD_all(code, compile_declaration(field->type, compile(env, var->ast)), " = subject.$", clause_tag_name, ".$", field->name, ";\n");
- set_binding(scope, Match(var->ast, Var)->name, field->type, CORD_EMPTY);
- var = var->next;
+ code_err(arg->value, "This is one more field than %T has", subject_t);
+ if (arg->name)
+ code_err(arg->value, "Named arguments are not currently supported");
+
+ const char *var_name = Match(arg->value, Var)->name;
+ if (!streq(var_name, "_")) {
+ code = CORD_all(code, compile_declaration(field->type, compile(env, arg->value)), " = subject.$", clause_tag_name, ".$", field->name, ";\n");
+ set_binding(scope, Match(arg->value, Var)->name, field->type, CORD_EMPTY);
+ }
field = field->next;
}
}
@@ -1160,7 +1202,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
const char *target = Match(ast, Skip)->target;
for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0;
- for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next)
+ for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var ? var->next : NULL)
matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0);
if (matched) {
@@ -1189,7 +1231,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
const char *target = Match(ast, Stop)->target;
for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0;
- for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next)
+ for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var ? var->next : var)
matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0);
if (matched) {
@@ -3443,12 +3485,12 @@ CORD compile(env_t *env, ast_t *ast)
for (when_clause_t *clause = original->clauses; clause; clause = clause->next) {
type_t *clause_type = get_clause_type(env, subject_t, clause);
if (clause_type->tag == AbortType || clause_type->tag == ReturnType) {
- new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=clause->body, .next=new_clauses);
+ new_clauses = new(when_clause_t, .pattern=clause->pattern, .body=clause->body, .next=new_clauses);
} else {
ast_t *assign = WrapAST(clause->body, Assign,
.targets=new(ast_list_t, .ast=when_var),
.values=new(ast_list_t, .ast=clause->body));
- new_clauses = new(when_clause_t, .tag_name=clause->tag_name, .args=clause->args, .body=assign, .next=new_clauses);
+ new_clauses = new(when_clause_t, .pattern=clause->pattern, .body=assign, .next=new_clauses);
}
}
REVERSE_LIST(new_clauses);