aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/compile/expressions.c40
-rw-r--r--src/compile/statements.c123
-rw-r--r--src/compile/whens.c173
-rw-r--r--src/compile/whens.h9
4 files changed, 188 insertions, 157 deletions
diff --git a/src/compile/expressions.c b/src/compile/expressions.c
index e485e450..404beff2 100644
--- a/src/compile/expressions.c
+++ b/src/compile/expressions.c
@@ -10,6 +10,7 @@
#include "../typecheck.h"
#include "binops.h"
#include "blocks.h"
+#include "conditionals.h"
#include "declarations.h"
#include "enums.h"
#include "functions.h"
@@ -24,6 +25,7 @@
#include "tables.h"
#include "text.h"
#include "types.h"
+#include "whens.h"
public
Text_t compile_maybe_incref(env_t *env, ast_t *ast, type_t *t) {
@@ -320,42 +322,8 @@ Text_t compile(env_t *env, ast_t *ast) {
case ExplicitlyTyped: {
return compile_to_type(env, Match(ast, ExplicitlyTyped)->ast, get_type(env, ast));
}
- case When: {
- DeclareMatch(original, ast, When);
- ast_t *when_var = WrapAST(ast, Var, .name = "when");
- when_clause_t *new_clauses = NULL;
- type_t *subject_t = get_type(env, original->subject);
- for (when_clause_t *clause = original->clauses; clause; clause = clause->next) {
- type_t *clause_type = get_clause_type(env, subject_t, clause);
- if (clause_type->tag == AbortType || clause_type->tag == ReturnType) {
- new_clauses =
- new (when_clause_t, .pattern = clause->pattern, .body = clause->body, .next = new_clauses);
- } else {
- ast_t *assign = WrapAST(clause->body, Assign, .targets = new (ast_list_t, .ast = when_var),
- .values = new (ast_list_t, .ast = clause->body));
- new_clauses = new (when_clause_t, .pattern = clause->pattern, .body = assign, .next = new_clauses);
- }
- }
- REVERSE_LIST(new_clauses);
- ast_t *else_body = original->else_body;
- if (else_body) {
- type_t *clause_type = get_type(env, else_body);
- if (clause_type->tag != AbortType && clause_type->tag != ReturnType) {
- else_body = WrapAST(else_body, Assign, .targets = new (ast_list_t, .ast = when_var),
- .values = new (ast_list_t, .ast = else_body));
- }
- }
-
- type_t *t = get_type(env, ast);
- env_t *when_env = fresh_scope(env);
- set_binding(when_env, "when", t, Text("when"));
- return Texts("({ ", compile_declaration(t, Text("when")), ";\n",
- compile_statement(when_env, WrapAST(ast, When, .subject = original->subject,
- .clauses = new_clauses, .else_body = else_body)),
- "when; })");
- }
- case If: {
- }
+ case When: return compile_when_statement(env, ast);
+ case If: return compile_if_expression(env, ast);
case Reduction: {
DeclareMatch(reduction, ast, Reduction);
ast_e op = reduction->op;
diff --git a/src/compile/statements.c b/src/compile/statements.c
index 642b900d..dde5facf 100644
--- a/src/compile/statements.c
+++ b/src/compile/statements.c
@@ -25,6 +25,7 @@
#include "statements.h"
#include "text.h"
#include "types.h"
+#include "whens.h"
typedef ast_t *(*comprehension_body_t)(ast_t *, ast_t *);
@@ -37,127 +38,7 @@ Text_t with_source_info(env_t *env, ast_t *ast, Text_t code) {
static Text_t _compile_statement(env_t *env, ast_t *ast) {
switch (ast->tag) {
- case When: {
- // Typecheck to verify exhaustiveness:
- type_t *result_t = get_type(env, ast);
- (void)result_t;
-
- DeclareMatch(when, ast, When);
- type_t *subject_t = get_type(env, when->subject);
-
- if (subject_t->tag != EnumType) {
- Text_t prefix = EMPTY_TEXT, suffix = EMPTY_TEXT;
- ast_t *subject = when->subject;
- if (!is_idempotent(when->subject)) {
- prefix = Texts("{\n", compile_declaration(subject_t, Text("_when_subject")), " = ",
- compile(env, subject), ";\n");
- suffix = Text("}\n");
- subject = LiteralCode(Text("_when_subject"), .type = subject_t);
- }
-
- Text_t code = EMPTY_TEXT;
- for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- ast_t *comparison = WrapAST(clause->pattern, Equals, .lhs = subject, .rhs = clause->pattern);
- (void)get_type(env, comparison);
- if (code.length > 0) code = Texts(code, "else ");
- code = Texts(code, "if (", compile(env, comparison), ")", compile_statement(env, clause->body));
- }
- if (when->else_body) code = Texts(code, "else ", compile_statement(env, when->else_body));
- code = Texts(prefix, code, suffix);
- return code;
- }
-
- DeclareMatch(enum_t, subject_t, EnumType);
-
- Text_t code;
- if (enum_has_fields(subject_t))
- code = Texts("WHEN(", compile_type(subject_t), ", ", compile(env, when->subject), ", _when_subject, {\n");
- else code = Texts("switch(", compile(env, when->subject), ") {\n");
-
- for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
- if (clause->pattern->tag == Var) {
- const char *clause_tag_name = Match(clause->pattern, Var)->name;
- type_t *clause_type = clause->body ? get_type(env, clause->body) : Type(VoidType);
- code = Texts(
- code, "case ", namespace_name(enum_t->env, enum_t->env->namespace, Texts("tag$", clause_tag_name)),
- ": {\n", compile_inline_block(env, clause->body),
- (clause_type->tag == ReturnType || clause_type->tag == AbortType) ? EMPTY_TEXT : Text("break;\n"),
- "}\n");
- continue;
- }
-
- if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var)
- code_err(clause->pattern, "This is not a valid pattern for a ", type_to_str(subject_t), " enum type");
-
- const char *clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name;
- code = Texts(code, "case ",
- namespace_name(enum_t->env, enum_t->env->namespace, Texts("tag$", clause_tag_name)), ": {\n");
- type_t *tag_type = NULL;
- for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
- if (streq(tag->name, clause_tag_name)) {
- tag_type = tag->type;
- break;
- }
- }
- assert(tag_type);
- env_t *scope = env;
-
- DeclareMatch(tag_struct, tag_type, StructType);
- arg_ast_t *args = Match(clause->pattern, FunctionCall)->args;
- if (args && !args->next && tag_struct->fields && tag_struct->fields->next) {
- if (args->value->tag != Var) code_err(args->value, "This is not a valid variable to bind to");
- const char *var_name = Match(args->value, Var)->name;
- if (!streq(var_name, "_")) {
- Text_t var = Texts("_$", var_name);
- code = Texts(code, compile_declaration(tag_type, var), " = _when_subject.",
- valid_c_name(clause_tag_name), ";\n");
- scope = fresh_scope(scope);
- set_binding(scope, Match(args->value, Var)->name, tag_type, EMPTY_TEXT);
- }
- } else if (args) {
- scope = fresh_scope(scope);
- arg_t *field = tag_struct->fields;
- for (arg_ast_t *arg = args; arg || field; arg = arg->next) {
- if (!arg)
- code_err(ast, "The field ", type_to_str(subject_t), ".", clause_tag_name, ".", field->name,
- " wasn't accounted for");
- if (!field) code_err(arg->value, "This is one more field than ", type_to_str(subject_t), " has");
- if (arg->name) code_err(arg->value, "Named arguments are not currently supported");
-
- const char *var_name = Match(arg->value, Var)->name;
- if (!streq(var_name, "_")) {
- Text_t var = Texts("_$", var_name);
- code = Texts(code, compile_declaration(field->type, var), " = _when_subject.",
- valid_c_name(clause_tag_name), ".", valid_c_name(field->name), ";\n");
- set_binding(scope, Match(arg->value, Var)->name, field->type, var);
- }
- field = field->next;
- }
- }
- if (clause->body->tag == Block) {
- ast_list_t *statements = Match(clause->body, Block)->statements;
- if (!statements || (statements->ast->tag == Pass && !statements->next))
- code = Texts(code, "break;\n}\n");
- else code = Texts(code, compile_inline_block(scope, clause->body), "\nbreak;\n}\n");
- } else {
- code = Texts(code, compile_statement(scope, clause->body), "\nbreak;\n}\n");
- }
- }
- if (when->else_body) {
- if (when->else_body->tag == Block) {
- ast_list_t *statements = Match(when->else_body, Block)->statements;
- if (!statements || (statements->ast->tag == Pass && !statements->next))
- code = Texts(code, "default: break;");
- else code = Texts(code, "default: {\n", compile_inline_block(env, when->else_body), "\nbreak;\n}\n");
- } else {
- code = Texts(code, "default: {\n", compile_statement(env, when->else_body), "\nbreak;\n}\n");
- }
- } else {
- code = Texts(code, "default: errx(1, \"Invalid tag!\");\n");
- }
- code = Texts(code, "\n}", enum_has_fields(subject_t) ? Text(")") : EMPTY_TEXT, "\n");
- return code;
- }
+ case When: return compile_when_statement(env, ast);
case DocTest: {
DeclareMatch(test, ast, DocTest);
type_t *expr_t = get_type(env, test->expr);
diff --git a/src/compile/whens.c b/src/compile/whens.c
new file mode 100644
index 00000000..6af2ecf1
--- /dev/null
+++ b/src/compile/whens.c
@@ -0,0 +1,173 @@
+// This file defines how to compile 'when' statements/expressions
+
+#include "whens.h"
+#include "../ast.h"
+#include "../config.h"
+#include "../environment.h"
+#include "../naming.h"
+#include "../stdlib/datatypes.h"
+#include "../stdlib/text.h"
+#include "../stdlib/util.h"
+#include "../typecheck.h"
+#include "blocks.h"
+#include "declarations.h"
+#include "expressions.h"
+#include "statements.h"
+#include "types.h"
+
+public
+Text_t compile_when_statement(env_t *env, ast_t *ast) {
+ // Typecheck to verify exhaustiveness:
+ type_t *result_t = get_type(env, ast);
+ (void)result_t;
+
+ DeclareMatch(when, ast, When);
+ type_t *subject_t = get_type(env, when->subject);
+
+ if (subject_t->tag != EnumType) {
+ Text_t prefix = EMPTY_TEXT, suffix = EMPTY_TEXT;
+ ast_t *subject = when->subject;
+ if (!is_idempotent(when->subject)) {
+ prefix = Texts("{\n", compile_declaration(subject_t, Text("_when_subject")), " = ", compile(env, subject),
+ ";\n");
+ suffix = Text("}\n");
+ subject = LiteralCode(Text("_when_subject"), .type = subject_t);
+ }
+
+ Text_t code = EMPTY_TEXT;
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ ast_t *comparison = WrapAST(clause->pattern, Equals, .lhs = subject, .rhs = clause->pattern);
+ (void)get_type(env, comparison);
+ if (code.length > 0) code = Texts(code, "else ");
+ code = Texts(code, "if (", compile(env, comparison), ")", compile_statement(env, clause->body));
+ }
+ if (when->else_body) code = Texts(code, "else ", compile_statement(env, when->else_body));
+ code = Texts(prefix, code, suffix);
+ return code;
+ }
+
+ DeclareMatch(enum_t, subject_t, EnumType);
+
+ Text_t code;
+ if (enum_has_fields(subject_t))
+ code = Texts("WHEN(", compile_type(subject_t), ", ", compile(env, when->subject), ", _when_subject, {\n");
+ else code = Texts("switch(", compile(env, when->subject), ") {\n");
+
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ if (clause->pattern->tag == Var) {
+ const char *clause_tag_name = Match(clause->pattern, Var)->name;
+ type_t *clause_type = clause->body ? get_type(env, clause->body) : Type(VoidType);
+ code = Texts(
+ code, "case ", namespace_name(enum_t->env, enum_t->env->namespace, Texts("tag$", clause_tag_name)),
+ ": {\n", compile_inline_block(env, clause->body),
+ (clause_type->tag == ReturnType || clause_type->tag == AbortType) ? EMPTY_TEXT : Text("break;\n"),
+ "}\n");
+ continue;
+ }
+
+ if (clause->pattern->tag != FunctionCall || Match(clause->pattern, FunctionCall)->fn->tag != Var)
+ code_err(clause->pattern, "This is not a valid pattern for a ", type_to_str(subject_t), " enum type");
+
+ const char *clause_tag_name = Match(Match(clause->pattern, FunctionCall)->fn, Var)->name;
+ code = Texts(code, "case ", namespace_name(enum_t->env, enum_t->env->namespace, Texts("tag$", clause_tag_name)),
+ ": {\n");
+ type_t *tag_type = NULL;
+ for (tag_t *tag = enum_t->tags; tag; tag = tag->next) {
+ if (streq(tag->name, clause_tag_name)) {
+ tag_type = tag->type;
+ break;
+ }
+ }
+ assert(tag_type);
+ env_t *scope = env;
+
+ DeclareMatch(tag_struct, tag_type, StructType);
+ arg_ast_t *args = Match(clause->pattern, FunctionCall)->args;
+ if (args && !args->next && tag_struct->fields && tag_struct->fields->next) {
+ if (args->value->tag != Var) code_err(args->value, "This is not a valid variable to bind to");
+ const char *var_name = Match(args->value, Var)->name;
+ if (!streq(var_name, "_")) {
+ Text_t var = Texts("_$", var_name);
+ code = Texts(code, compile_declaration(tag_type, var), " = _when_subject.",
+ valid_c_name(clause_tag_name), ";\n");
+ scope = fresh_scope(scope);
+ set_binding(scope, Match(args->value, Var)->name, tag_type, EMPTY_TEXT);
+ }
+ } else if (args) {
+ scope = fresh_scope(scope);
+ arg_t *field = tag_struct->fields;
+ for (arg_ast_t *arg = args; arg || field; arg = arg->next) {
+ if (!arg)
+ code_err(ast, "The field ", type_to_str(subject_t), ".", clause_tag_name, ".", field->name,
+ " wasn't accounted for");
+ if (!field) code_err(arg->value, "This is one more field than ", type_to_str(subject_t), " has");
+ if (arg->name) code_err(arg->value, "Named arguments are not currently supported");
+
+ const char *var_name = Match(arg->value, Var)->name;
+ if (!streq(var_name, "_")) {
+ Text_t var = Texts("_$", var_name);
+ code = Texts(code, compile_declaration(field->type, var), " = _when_subject.",
+ valid_c_name(clause_tag_name), ".", valid_c_name(field->name), ";\n");
+ set_binding(scope, Match(arg->value, Var)->name, field->type, var);
+ }
+ field = field->next;
+ }
+ }
+ if (clause->body->tag == Block) {
+ ast_list_t *statements = Match(clause->body, Block)->statements;
+ if (!statements || (statements->ast->tag == Pass && !statements->next)) code = Texts(code, "break;\n}\n");
+ else code = Texts(code, compile_inline_block(scope, clause->body), "\nbreak;\n}\n");
+ } else {
+ code = Texts(code, compile_statement(scope, clause->body), "\nbreak;\n}\n");
+ }
+ }
+ if (when->else_body) {
+ if (when->else_body->tag == Block) {
+ ast_list_t *statements = Match(when->else_body, Block)->statements;
+ if (!statements || (statements->ast->tag == Pass && !statements->next))
+ code = Texts(code, "default: break;");
+ else code = Texts(code, "default: {\n", compile_inline_block(env, when->else_body), "\nbreak;\n}\n");
+ } else {
+ code = Texts(code, "default: {\n", compile_statement(env, when->else_body), "\nbreak;\n}\n");
+ }
+ } else {
+ code = Texts(code, "default: errx(1, \"Invalid tag!\");\n");
+ }
+ code = Texts(code, "\n}", enum_has_fields(subject_t) ? Text(")") : EMPTY_TEXT, "\n");
+ return code;
+}
+
+public
+Text_t compile_when_expression(env_t *env, ast_t *ast) {
+ DeclareMatch(original, ast, When);
+ ast_t *when_var = WrapAST(ast, Var, .name = "when");
+ when_clause_t *new_clauses = NULL;
+ type_t *subject_t = get_type(env, original->subject);
+ for (when_clause_t *clause = original->clauses; clause; clause = clause->next) {
+ type_t *clause_type = get_clause_type(env, subject_t, clause);
+ if (clause_type->tag == AbortType || clause_type->tag == ReturnType) {
+ new_clauses = new (when_clause_t, .pattern = clause->pattern, .body = clause->body, .next = new_clauses);
+ } else {
+ ast_t *assign = WrapAST(clause->body, Assign, .targets = new (ast_list_t, .ast = when_var),
+ .values = new (ast_list_t, .ast = clause->body));
+ new_clauses = new (when_clause_t, .pattern = clause->pattern, .body = assign, .next = new_clauses);
+ }
+ }
+ REVERSE_LIST(new_clauses);
+ ast_t *else_body = original->else_body;
+ if (else_body) {
+ type_t *clause_type = get_type(env, else_body);
+ if (clause_type->tag != AbortType && clause_type->tag != ReturnType) {
+ else_body = WrapAST(else_body, Assign, .targets = new (ast_list_t, .ast = when_var),
+ .values = new (ast_list_t, .ast = else_body));
+ }
+ }
+
+ type_t *t = get_type(env, ast);
+ env_t *when_env = fresh_scope(env);
+ set_binding(when_env, "when", t, Text("when"));
+ return Texts("({ ", compile_declaration(t, Text("when")), ";\n",
+ compile_statement(when_env, WrapAST(ast, When, .subject = original->subject, .clauses = new_clauses,
+ .else_body = else_body)),
+ "when; })");
+}
diff --git a/src/compile/whens.h b/src/compile/whens.h
new file mode 100644
index 00000000..473124d5
--- /dev/null
+++ b/src/compile/whens.h
@@ -0,0 +1,9 @@
+// This file defines how to compile 'when' statements/expressions
+#pragma once
+
+#include "../ast.h"
+#include "../environment.h"
+#include "../stdlib/datatypes.h"
+
+Text_t compile_when_statement(env_t *env, ast_t *ast);
+Text_t compile_when_expression(env_t *env, ast_t *ast);