diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2025-08-24 18:13:53 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2025-08-24 18:13:53 -0400 |
| commit | e56f4b447e785a7b0ad95d9f2c21e36e3d2bed2c (patch) | |
| tree | 786f56793781977d9dbf899fd59e2f904dee481e /src/compile | |
| parent | 222fbcd0027f085c48e93c6e70445699eec79d96 (diff) | |
Split 'when' into its own file
Diffstat (limited to 'src/compile')
| -rw-r--r-- | src/compile/expressions.c | 40 | ||||
| -rw-r--r-- | src/compile/statements.c | 123 | ||||
| -rw-r--r-- | src/compile/whens.c | 173 | ||||
| -rw-r--r-- | src/compile/whens.h | 9 |
4 files changed, 188 insertions, 157 deletions
diff --git a/src/compile/expressions.c b/src/compile/expressions.c index e485e450..404beff2 100644 --- a/src/compile/expressions.c +++ b/src/compile/expressions.c @@ -10,6 +10,7 @@ #include "../typecheck.h" #include "binops.h" #include "blocks.h" +#include "conditionals.h" #include "declarations.h" #include "enums.h" #include "functions.h" @@ -24,6 +25,7 @@ #include "tables.h" #include "text.h" #include "types.h" +#include "whens.h" public Text_t compile_maybe_incref(env_t *env, ast_t *ast, type_t *t) { @@ -320,42 +322,8 @@ Text_t compile(env_t *env, ast_t *ast) { case ExplicitlyTyped: { return compile_to_type(env, Match(ast, ExplicitlyTyped)->ast, get_type(env, ast)); } - case When: { - DeclareMatch(original, ast, When); - ast_t *when_var = WrapAST(ast, Var, .name = "when"); - when_clause_t *new_clauses = NULL; - type_t *subject_t = get_type(env, original->subject); - for (when_clause_t *clause = original->clauses; clause; clause = clause->next) { - type_t *clause_type = get_clause_type(env, subject_t, clause); - if (clause_type->tag == AbortType || clause_type->tag == ReturnType) { - new_clauses = - new (when_clause_t, .pattern = clause->pattern, .body = clause->body, .next = new_clauses); - } else { - ast_t *assign = WrapAST(clause->body, Assign, .targets = new (ast_list_t, .ast = when_var), - .values = new (ast_list_t, .ast = clause->body)); - new_clauses = new (when_clause_t, .pattern = clause->pattern, .body = assign, .next = new_clauses); - } - } - REVERSE_LIST(new_clauses); - ast_t *else_body = original->else_body; - if (else_body) { - type_t *clause_type = get_type(env, else_body); - if (clause_type->tag != AbortType && clause_type->tag != ReturnType) { - else_body = WrapAST(else_body, Assign, .targets = new (ast_list_t, .ast = when_var), - .values = new (ast_list_t, .ast = else_body)); - } - } - - type_t *t = get_type(env, ast); - env_t *when_env = fresh_scope(env); - set_binding(when_env, "when", t, Text("when")); - return Texts("({ ", compile_declaration(t, Text("when")), ";\n", - compile_statement(when_env, WrapAST(ast, When, .subject = original->subject, - .clauses = new_clauses, .else_body = else_body)), - "when; })"); - } - case If: { - } + case When: return compile_when_statement(env, ast); + case If: return compile_if_expression(env, ast); case Reduction: { DeclareMatch(reduction, ast, Reduction); ast_e op = reduction->op; diff --git a/src/compile/statements.c b/src/compile/statements.c index 642b900d..dde5facf 100644 --- a/src/compile/statements.c +++ b/src/compile/statements.c @@ -25,6 +25,7 @@ #include "statements.h" #include "text.h" #include "types.h" +#include "whens.h" typedef ast_t *(*comprehension_body_t)(ast_t *, ast_t *); @@ -37,127 +38,7 @@ Text_t with_source_info(env_t *env, ast_t *ast, Text_t code) { static Text_t _compile_statement(env_t *env, ast_t *ast) { switch (ast->tag) { - case When: { - // Typecheck to verify exhaustiveness: - type_t *result_t = get_type(env, ast); - (void)result_t; - - DeclareMatch(when, ast, When); - type_t *subject_t = get_type(env, when->subject); - - if (subject_t->tag != EnumType) { - Text_t prefix = EMPTY_TEXT, suffix = EMPTY_TEXT; - ast_t *subject = when->subject; - if (!is_idempotent(when->subject)) { - prefix = Texts("{\n", compile_declaration(subject_t, Text("_when_subject")), " = ", - compile(env, subject), ";\n"); - suffix = Text("}\n"); - subject = LiteralCode(Text("_when_subject"), .type = subject_t); - } - - Text_t code = EMPTY_TEXT; - for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - ast_t *comparison = WrapAST(clause->pattern, Equals, .lhs = subject, .rhs = clause->pattern); - (void)get_type(env, comparison); - if (code.length > 0) code = Texts(code, "else "); - code = Texts(code, "if (", compile(env, comparison), ")", compile_statement(env, clause->body)); - } - if (when->else_body) code = Texts(code, "else ", compile_statement(env, when->else_body)); - code = Texts(prefix, code, suffix); - return code; - } - - DeclareMatch(enum_t, subject_t, EnumType); - - Text_t code; - if (enum_has_fields(subject_t)) - code = Texts("WHEN(", compile_type(subject_t), ", ", compile(env, when->subject), ", _when_subject, {\n"); - else code = Texts("switch(", compile(env, when->subject), ") {\n"); - - for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - if (clause->pattern->tag == Var) { - const char *clause_tag_name = Match(clause->pattern, Var)->name; - type_t *clause_type = clause->body ? get_type(env, clause->body) : Type(VoidType); - code = Texts( - code, "case ", namespace_name(enum_t->env, enum_t->env->namespace, Texts("tag$", clause_tag_name)), - ": {\n", compile_inline_block(env, clause->body), - (clause_type->tag == ReturnType || clause_type->tag == AbortType) ? EMPTY_TEXT : Text("break;\n"), - "}\n"); - continue; - } - - if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var) - code_err(clause->pattern, "This is not a valid pattern for a ", type_to_str(subject_t), " enum type"); - - const char *clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name; - code = Texts(code, "case ", - namespace_name(enum_t->env, enum_t->env->namespace, Texts("tag$", clause_tag_name)), ": {\n"); - type_t *tag_type = NULL; - for (tag_t *tag = enum_t->tags; tag; tag = tag->next) { - if (streq(tag->name, clause_tag_name)) { - tag_type = tag->type; - break; - } - } - assert(tag_type); - env_t *scope = env; - - DeclareMatch(tag_struct, tag_type, StructType); - arg_ast_t *args = Match(clause->pattern, FunctionCall)->args; - if (args && !args->next && tag_struct->fields && tag_struct->fields->next) { - if (args->value->tag != Var) code_err(args->value, "This is not a valid variable to bind to"); - const char *var_name = Match(args->value, Var)->name; - if (!streq(var_name, "_")) { - Text_t var = Texts("_$", var_name); - code = Texts(code, compile_declaration(tag_type, var), " = _when_subject.", - valid_c_name(clause_tag_name), ";\n"); - scope = fresh_scope(scope); - set_binding(scope, Match(args->value, Var)->name, tag_type, EMPTY_TEXT); - } - } else if (args) { - scope = fresh_scope(scope); - arg_t *field = tag_struct->fields; - for (arg_ast_t *arg = args; arg || field; arg = arg->next) { - if (!arg) - code_err(ast, "The field ", type_to_str(subject_t), ".", clause_tag_name, ".", field->name, - " wasn't accounted for"); - if (!field) code_err(arg->value, "This is one more field than ", type_to_str(subject_t), " has"); - if (arg->name) code_err(arg->value, "Named arguments are not currently supported"); - - const char *var_name = Match(arg->value, Var)->name; - if (!streq(var_name, "_")) { - Text_t var = Texts("_$", var_name); - code = Texts(code, compile_declaration(field->type, var), " = _when_subject.", - valid_c_name(clause_tag_name), ".", valid_c_name(field->name), ";\n"); - set_binding(scope, Match(arg->value, Var)->name, field->type, var); - } - field = field->next; - } - } - if (clause->body->tag == Block) { - ast_list_t *statements = Match(clause->body, Block)->statements; - if (!statements || (statements->ast->tag == Pass && !statements->next)) - code = Texts(code, "break;\n}\n"); - else code = Texts(code, compile_inline_block(scope, clause->body), "\nbreak;\n}\n"); - } else { - code = Texts(code, compile_statement(scope, clause->body), "\nbreak;\n}\n"); - } - } - if (when->else_body) { - if (when->else_body->tag == Block) { - ast_list_t *statements = Match(when->else_body, Block)->statements; - if (!statements || (statements->ast->tag == Pass && !statements->next)) - code = Texts(code, "default: break;"); - else code = Texts(code, "default: {\n", compile_inline_block(env, when->else_body), "\nbreak;\n}\n"); - } else { - code = Texts(code, "default: {\n", compile_statement(env, when->else_body), "\nbreak;\n}\n"); - } - } else { - code = Texts(code, "default: errx(1, \"Invalid tag!\");\n"); - } - code = Texts(code, "\n}", enum_has_fields(subject_t) ? Text(")") : EMPTY_TEXT, "\n"); - return code; - } + case When: return compile_when_statement(env, ast); case DocTest: { DeclareMatch(test, ast, DocTest); type_t *expr_t = get_type(env, test->expr); diff --git a/src/compile/whens.c b/src/compile/whens.c new file mode 100644 index 00000000..6af2ecf1 --- /dev/null +++ b/src/compile/whens.c @@ -0,0 +1,173 @@ +// This file defines how to compile 'when' statements/expressions + +#include "whens.h" +#include "../ast.h" +#include "../config.h" +#include "../environment.h" +#include "../naming.h" +#include "../stdlib/datatypes.h" +#include "../stdlib/text.h" +#include "../stdlib/util.h" +#include "../typecheck.h" +#include "blocks.h" +#include "declarations.h" +#include "expressions.h" +#include "statements.h" +#include "types.h" + +public +Text_t compile_when_statement(env_t *env, ast_t *ast) { + // Typecheck to verify exhaustiveness: + type_t *result_t = get_type(env, ast); + (void)result_t; + + DeclareMatch(when, ast, When); + type_t *subject_t = get_type(env, when->subject); + + if (subject_t->tag != EnumType) { + Text_t prefix = EMPTY_TEXT, suffix = EMPTY_TEXT; + ast_t *subject = when->subject; + if (!is_idempotent(when->subject)) { + prefix = Texts("{\n", compile_declaration(subject_t, Text("_when_subject")), " = ", compile(env, subject), + ";\n"); + suffix = Text("}\n"); + subject = LiteralCode(Text("_when_subject"), .type = subject_t); + } + + Text_t code = EMPTY_TEXT; + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + ast_t *comparison = WrapAST(clause->pattern, Equals, .lhs = subject, .rhs = clause->pattern); + (void)get_type(env, comparison); + if (code.length > 0) code = Texts(code, "else "); + code = Texts(code, "if (", compile(env, comparison), ")", compile_statement(env, clause->body)); + } + if (when->else_body) code = Texts(code, "else ", compile_statement(env, when->else_body)); + code = Texts(prefix, code, suffix); + return code; + } + + DeclareMatch(enum_t, subject_t, EnumType); + + Text_t code; + if (enum_has_fields(subject_t)) + code = Texts("WHEN(", compile_type(subject_t), ", ", compile(env, when->subject), ", _when_subject, {\n"); + else code = Texts("switch(", compile(env, when->subject), ") {\n"); + + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + if (clause->pattern->tag == Var) { + const char *clause_tag_name = Match(clause->pattern, Var)->name; + type_t *clause_type = clause->body ? get_type(env, clause->body) : Type(VoidType); + code = Texts( + code, "case ", namespace_name(enum_t->env, enum_t->env->namespace, Texts("tag$", clause_tag_name)), + ": {\n", compile_inline_block(env, clause->body), + (clause_type->tag == ReturnType || clause_type->tag == AbortType) ? EMPTY_TEXT : Text("break;\n"), + "}\n"); + continue; + } + + if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var) + code_err(clause->pattern, "This is not a valid pattern for a ", type_to_str(subject_t), " enum type"); + + const char *clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name; + code = Texts(code, "case ", namespace_name(enum_t->env, enum_t->env->namespace, Texts("tag$", clause_tag_name)), + ": {\n"); + type_t *tag_type = NULL; + for (tag_t *tag = enum_t->tags; tag; tag = tag->next) { + if (streq(tag->name, clause_tag_name)) { + tag_type = tag->type; + break; + } + } + assert(tag_type); + env_t *scope = env; + + DeclareMatch(tag_struct, tag_type, StructType); + arg_ast_t *args = Match(clause->pattern, FunctionCall)->args; + if (args && !args->next && tag_struct->fields && tag_struct->fields->next) { + if (args->value->tag != Var) code_err(args->value, "This is not a valid variable to bind to"); + const char *var_name = Match(args->value, Var)->name; + if (!streq(var_name, "_")) { + Text_t var = Texts("_$", var_name); + code = Texts(code, compile_declaration(tag_type, var), " = _when_subject.", + valid_c_name(clause_tag_name), ";\n"); + scope = fresh_scope(scope); + set_binding(scope, Match(args->value, Var)->name, tag_type, EMPTY_TEXT); + } + } else if (args) { + scope = fresh_scope(scope); + arg_t *field = tag_struct->fields; + for (arg_ast_t *arg = args; arg || field; arg = arg->next) { + if (!arg) + code_err(ast, "The field ", type_to_str(subject_t), ".", clause_tag_name, ".", field->name, + " wasn't accounted for"); + if (!field) code_err(arg->value, "This is one more field than ", type_to_str(subject_t), " has"); + if (arg->name) code_err(arg->value, "Named arguments are not currently supported"); + + const char *var_name = Match(arg->value, Var)->name; + if (!streq(var_name, "_")) { + Text_t var = Texts("_$", var_name); + code = Texts(code, compile_declaration(field->type, var), " = _when_subject.", + valid_c_name(clause_tag_name), ".", valid_c_name(field->name), ";\n"); + set_binding(scope, Match(arg->value, Var)->name, field->type, var); + } + field = field->next; + } + } + if (clause->body->tag == Block) { + ast_list_t *statements = Match(clause->body, Block)->statements; + if (!statements || (statements->ast->tag == Pass && !statements->next)) code = Texts(code, "break;\n}\n"); + else code = Texts(code, compile_inline_block(scope, clause->body), "\nbreak;\n}\n"); + } else { + code = Texts(code, compile_statement(scope, clause->body), "\nbreak;\n}\n"); + } + } + if (when->else_body) { + if (when->else_body->tag == Block) { + ast_list_t *statements = Match(when->else_body, Block)->statements; + if (!statements || (statements->ast->tag == Pass && !statements->next)) + code = Texts(code, "default: break;"); + else code = Texts(code, "default: {\n", compile_inline_block(env, when->else_body), "\nbreak;\n}\n"); + } else { + code = Texts(code, "default: {\n", compile_statement(env, when->else_body), "\nbreak;\n}\n"); + } + } else { + code = Texts(code, "default: errx(1, \"Invalid tag!\");\n"); + } + code = Texts(code, "\n}", enum_has_fields(subject_t) ? Text(")") : EMPTY_TEXT, "\n"); + return code; +} + +public +Text_t compile_when_expression(env_t *env, ast_t *ast) { + DeclareMatch(original, ast, When); + ast_t *when_var = WrapAST(ast, Var, .name = "when"); + when_clause_t *new_clauses = NULL; + type_t *subject_t = get_type(env, original->subject); + for (when_clause_t *clause = original->clauses; clause; clause = clause->next) { + type_t *clause_type = get_clause_type(env, subject_t, clause); + if (clause_type->tag == AbortType || clause_type->tag == ReturnType) { + new_clauses = new (when_clause_t, .pattern = clause->pattern, .body = clause->body, .next = new_clauses); + } else { + ast_t *assign = WrapAST(clause->body, Assign, .targets = new (ast_list_t, .ast = when_var), + .values = new (ast_list_t, .ast = clause->body)); + new_clauses = new (when_clause_t, .pattern = clause->pattern, .body = assign, .next = new_clauses); + } + } + REVERSE_LIST(new_clauses); + ast_t *else_body = original->else_body; + if (else_body) { + type_t *clause_type = get_type(env, else_body); + if (clause_type->tag != AbortType && clause_type->tag != ReturnType) { + else_body = WrapAST(else_body, Assign, .targets = new (ast_list_t, .ast = when_var), + .values = new (ast_list_t, .ast = else_body)); + } + } + + type_t *t = get_type(env, ast); + env_t *when_env = fresh_scope(env); + set_binding(when_env, "when", t, Text("when")); + return Texts("({ ", compile_declaration(t, Text("when")), ";\n", + compile_statement(when_env, WrapAST(ast, When, .subject = original->subject, .clauses = new_clauses, + .else_body = else_body)), + "when; })"); +} diff --git a/src/compile/whens.h b/src/compile/whens.h new file mode 100644 index 00000000..473124d5 --- /dev/null +++ b/src/compile/whens.h @@ -0,0 +1,9 @@ +// This file defines how to compile 'when' statements/expressions +#pragma once + +#include "../ast.h" +#include "../environment.h" +#include "../stdlib/datatypes.h" + +Text_t compile_when_statement(env_t *env, ast_t *ast); +Text_t compile_when_expression(env_t *env, ast_t *ast); |
