aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-05-24 00:03:46 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-05-24 00:03:46 -0400
commit308946e794f05da9f7010797f5911bcf4e131c3e (patch)
tree8ab109d49a2b0314faacc577e21de1c60e227788
parentbf3cdc3dfa4dcd0d17f182fe875d718cb4a4272f (diff)
Switch optional detection to use 'when .. is @..' instead of 'if .. :=
..'; also fixed a bug with stack memory in doctests
-rw-r--r--compile.c126
-rw-r--r--parse.c40
-rw-r--r--typecheck.c39
3 files changed, 94 insertions, 111 deletions
diff --git a/compile.c b/compile.c
index f3a13309..8a42527b 100644
--- a/compile.c
+++ b/compile.c
@@ -173,13 +173,31 @@ CORD compile_statement(env_t *env, ast_t *ast)
{
switch (ast->tag) {
case When: {
+ // Typecheck to verify exhaustiveness:
+ type_t *result_t = get_type(env, ast);
+ (void)result_t;
+
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
+
+ if (subject_t->tag == PointerType) {
+ ast_t *var = when->clauses->args->ast;
+ CORD var_code = compile(env, var);
+ env_t *non_null_scope = fresh_scope(env);
+ auto ptr = Match(subject_t, PointerType);
+ type_t *non_optional_t = Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack,
+ .is_readonly=ptr->is_readonly, .is_optional=false);
+ set_binding(non_null_scope, Match(var, Var)->name, new(binding_t, .type=non_optional_t, .code=var_code));
+ return CORD_all(
+ "{\n",
+ compile_declaration(env, subject_t, var_code), " = ", compile(env, when->subject), ";\n"
+ "if (", var_code, ")\n", compile_statement(non_null_scope, when->clauses->body),
+ "\nelse\n", compile_statement(env, when->else_body), "\n}");
+ }
+
auto enum_t = Match(subject_t, EnumType);
CORD code = CORD_all("{ ", compile_type(env, subject_t), " subject = ", compile(env, when->subject), ";\n"
"switch (subject.$tag) {");
- type_t *result_t = get_type(env, ast);
- (void)result_t;
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
const char *clause_tag_name = Match(clause->tag_name, Var)->name;
code = CORD_all(code, "case ", enum_t->env->file_prefix, enum_t->name, "$tag$", clause_tag_name, ": {\n");
@@ -252,7 +270,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
CORD var = CORD_all("$", Match(decl->var, Var)->name);
return CORD_asprintf(
"%r;\n"
- "test(({ %s = %r; &%r;}), %r, %r, %r, %ld, %ld);\n",
+ "test(({ %r = %r; &%r;}), %r, %r, %r, %ld, %ld);\n",
compile_declaration(env, get_type(env, decl->value), var),
var,
compile(env, decl->value),
@@ -301,7 +319,8 @@ CORD compile_statement(env_t *env, ast_t *ast)
for (ast_list_t *target = assign->targets; target; target = target->next)
code = CORD_all(code, compile_assignment(env, target->ast, CORD_asprintf("$%ld", i++)));
- CORD_appendf(&code, "&$1; }), %r, %r, %r, %ld, %ld);",
+ CORD_appendf(&code, "(%r[1]){$1}; }), %r, %r, %r, %ld, %ld);",
+ compile_type(env, get_type(env, assign->targets->ast)),
compile_type_info(env, get_type(env, assign->targets->ast)),
compile(env, WrapAST(test->expr, TextLiteral, .cord=test->output)),
compile(env, WrapAST(test->expr, TextLiteral, .cord=test->expr->file->filename)),
@@ -318,8 +337,8 @@ CORD compile_statement(env_t *env, ast_t *ast)
(int64_t)(test->expr->end - test->expr->file->text));
} else {
return CORD_asprintf(
- "test(({ %r = %r; &expr; }), %r, %r, %r, %ld, %ld);",
- compile_declaration(env, expr_t, "expr"),
+ "test((%r[1]){%r}, %r, %r, %r, %ld, %ld);",
+ compile_type(env, expr_t),
compile(env, test->expr),
compile_type_info(env, expr_t),
compile(env, WrapAST(test->expr, TextLiteral, .cord=output)),
@@ -737,48 +756,21 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
case If: {
auto if_ = Match(ast, If);
- if (if_->condition->tag == Declare) {
- auto decl = Match(if_->condition, Declare);
- env_t *true_scope = fresh_scope(env);
- const char *name = Match(decl->var, Var)->name;
- CORD var_code = CORD_cat(env->scope_prefix ? env->scope_prefix : "$", name);
- type_t *var_t = get_type(env, decl->value);
- if (var_t->tag == PointerType) {
- auto ptr = Match(var_t, PointerType);
- if (!ptr->is_optional)
- code_err(if_->condition, "This pointer will always be non-null, so it should not be used in a conditional.");
- var_t = Type(PointerType, .pointed=ptr->pointed, .is_optional=false, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
- } else {
- code_err(if_->condition, "Only optional pointer types can be used in 'if var := ...' statements (this is a %T)", var_t);
- }
- set_binding(true_scope, name, new(binding_t, .type=var_t, .code=var_code));
- CORD code = CORD_all("{\n",
- compile_type(env, var_t), " ", var_code, " = ", compile(env, decl->value), ";\n"
- "if (", var_code, ") ", compile_statement(true_scope, if_->body));
- if (if_->else_body)
- code = CORD_all(code, "\nelse ", compile_statement(env, if_->else_body));
- code = CORD_cat(code, "\n}");
- return code;
- } else {
- type_t *cond_t = get_type(env, if_->condition);
- if (cond_t->tag == PointerType) {
- if (!Match(cond_t, PointerType)->is_optional)
- code_err(if_->condition, "This pointer will always be non-null, so it should not be used in a conditional.");
- } else if (cond_t->tag != BoolType) {
- code_err(if_->condition, "Only boolean values and optional pointers can be used in conditionals (this is a %T)", cond_t);
- }
- CORD code;
- CORD_sprintf(&code, "if (%r) %r", compile(env, if_->condition), compile_statement(env, if_->body));
- if (if_->else_body)
- code = CORD_all(code, "\nelse ", compile_statement(env, if_->else_body));
- return code;
+ type_t *cond_t = get_type(env, if_->condition);
+ if (cond_t->tag == PointerType) {
+ if (!Match(cond_t, PointerType)->is_optional)
+ code_err(if_->condition, "This pointer will always be non-null, so it should not be used in a conditional.");
+ } else if (cond_t->tag != BoolType) {
+ code_err(if_->condition, "Only boolean values and optional pointers can be used in conditionals (this is a %T)", cond_t);
}
+ CORD code;
+ CORD_sprintf(&code, "if (%r) %r", compile(env, if_->condition), compile_statement(env, if_->body));
+ if (if_->else_body)
+ code = CORD_all(code, "\nelse ", compile_statement(env, if_->else_body));
+ return code;
}
case Block: {
ast_list_t *stmts = Match(ast, Block)->statements;
- if (stmts && !stmts->next)
- return compile_statement(env, stmts->ast);
-
CORD code = "{\n";
env = fresh_scope(env);
for (ast_list_t *stmt = stmts; stmt; stmt = stmt->next)
@@ -1019,7 +1011,7 @@ CORD compile(env_t *env, ast_t *ast)
switch (ast->tag) {
case Nil: {
type_t *t = parse_type_ast(env, Match(ast, Nil)->type);
- return CORD_all("((", compile_type(env, t), "*)NULL)");
+ return CORD_all("((", compile_type(env, t), ")NULL)");
}
case Bool: return Match(ast, Bool)->b ? "yes" : "no";
case Var: {
@@ -1771,39 +1763,17 @@ CORD compile(env_t *env, ast_t *ast)
if (t->tag == VoidType || t->tag == AbortType)
code_err(ast, "This expression has a %T type, but it needs to have a real value", t);
- if (if_->condition->tag == Declare) {
- CORD condition = Match(Match(if_->condition, Declare)->var, Var)->name;
- CORD decl = compile_statement(env, if_->condition);
- env_t *true_scope = fresh_scope(env);
- prebind_statement(true_scope, if_->condition);
- bind_statement(true_scope, if_->condition);
- type_t *true_type = get_type(true_scope, if_->body);
- type_t *false_type = get_type(env, if_->else_body);
- if (true_type->tag == AbortType) {
- return CORD_all("({ ", decl, "\nif (", condition, ") ", compile_statement(true_scope, if_->body), " ",
- compile(env, if_->else_body), "; })");
- } else if (false_type->tag == AbortType) {
- return CORD_all("({ ", decl, "\nif (!(", condition, ")) ", compile_statement(env, if_->else_body), " ",
- compile(true_scope, if_->body), "; })");
-
- } else {
- return CORD_all("({ ", decl, "\n(", condition, ") ? ",
- compile(true_scope, if_->body), " : ",
- compile(env, if_->else_body), "; })");
- }
- } else {
- type_t *true_type = get_type(env, if_->body);
- type_t *false_type = get_type(env, if_->else_body);
- if (true_type->tag == AbortType)
- return CORD_all("({ if (", compile(env, if_->condition), ") ", compile_statement(env, if_->body),
- "\n", compile(env, if_->else_body), "; })");
- else if (false_type->tag == AbortType)
- return CORD_all("({ if (!(", compile(env, if_->condition), ")) ", compile_statement(env, if_->else_body),
- "\n", compile(env, if_->body), "; })");
- else
- return CORD_all("((", compile(env, if_->condition), ") ? ",
- compile(env, if_->body), " : ", compile(env, if_->else_body), ")");
- }
+ type_t *true_type = get_type(env, if_->body);
+ type_t *false_type = get_type(env, if_->else_body);
+ if (true_type->tag == AbortType)
+ return CORD_all("({ if (", compile(env, if_->condition), ") ", compile_statement(env, if_->body),
+ "\n", compile(env, if_->else_body), "; })");
+ else if (false_type->tag == AbortType)
+ return CORD_all("({ if (!(", compile(env, if_->condition), ")) ", compile_statement(env, if_->else_body),
+ "\n", compile(env, if_->body), "; })");
+ else
+ return CORD_all("((", compile(env, if_->condition), ") ? ",
+ compile(env, if_->body), " : ", compile(env, if_->else_body), ")");
}
case Reduction: {
auto reduction = Match(ast, Reduction);
diff --git a/parse.c b/parse.c
index fbb82871..701cdc1c 100644
--- a/parse.c
+++ b/parse.c
@@ -875,9 +875,8 @@ PARSER(parse_if) {
if (!match_word(&pos, "if"))
return NULL;
- ast_t *condition = optional(ctx, &pos, parse_declaration);
- if (!condition) condition = expect(ctx, start, &pos, parse_expr,
- "I expected to find an expression for this 'if'");
+ ast_t *condition = expect(ctx, start, &pos, parse_expr,
+ "I expected to find a condition for this 'if'");
ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'if' statement");
@@ -913,21 +912,30 @@ PARSER(parse_when) {
while (get_indent(ctx, tmp) == starting_indent && match_word(&tmp, "is")) {
pos = tmp;
spaces(&pos);
- ast_t *tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here");
- spaces(&pos);
- ast_list_t *args = NULL;
- if (match(&pos, "(")) {
- for (;;) {
- whitespace(&pos);
- ast_t *arg = optional(ctx, &pos, parse_var);
- if (!arg) break;
- args = new(ast_list_t, .ast=arg, .next=args);
+ ast_t *tag_name;
+ ast_list_t *args;
+ if (match(&pos, "@")) {
+ tag_name = NewAST(ctx->file, pos-1, pos, Var, .name="@");
+ spaces(&pos);
+ ast_t *arg = optional(ctx, &pos, parse_var);
+ args = arg ? new(ast_list_t, .ast=arg) : NULL;
+ } else {
+ tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here");
+ spaces(&pos);
+ args = NULL;
+ if (match(&pos, "(")) {
+ for (;;) {
+ whitespace(&pos);
+ ast_t *arg = optional(ctx, &pos, parse_var);
+ if (!arg) break;
+ args = new(ast_list_t, .ast=arg, .next=args);
+ whitespace(&pos);
+ if (!match(&pos, ",")) break;
+ }
whitespace(&pos);
- if (!match(&pos, ",")) break;
+ expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments");
+ REVERSE_LIST(args);
}
- whitespace(&pos);
- expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments");
- REVERSE_LIST(args);
}
ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'when' clause");
diff --git a/typecheck.c b/typecheck.c
index 9964c9d9..16841e8f 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -845,22 +845,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case If: {
auto if_ = Match(ast, If);
- type_t *true_t;
- if (if_->condition->tag == Declare) {
- auto decl = Match(if_->condition, Declare);
- env_t *scope = fresh_scope(env);
- type_t *var_t = get_type(env, decl->value);
- if (var_t->tag == PointerType) {
- auto ptr = Match(var_t, PointerType);
- var_t = Type(PointerType, .pointed=ptr->pointed, .is_optional=false, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
- }
- CORD var = Match(decl->var, Var)->name;
- set_binding(scope, CORD_to_const_char_star(var), new(binding_t, .type=var_t));
- true_t = get_type(scope, if_->body);
- } else {
- true_t = get_type(env, if_->body);
- }
-
+ type_t *true_t = get_type(env, if_->body);
if (if_->else_body) {
type_t *false_t = get_type(env, if_->else_body);
type_t *t_either = type_or_type(true_t, false_t);
@@ -877,8 +862,28 @@ type_t *get_type(env_t *env, ast_t *ast)
case When: {
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
+ if (subject_t->tag == PointerType) {
+ if (!Match(subject_t, PointerType)->is_optional)
+ code_err(when->subject, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t);
+
+ bool handled_at = false;
+ for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
+ const char *tag_name = Match(clause->tag_name, Var)->name;
+ if (!streq(tag_name, "@"))
+ code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name);
+ if (handled_at)
+ code_err(clause->tag_name, "This 'when' statement has already handled the case of non-null pointers!");
+ handled_at = true;
+ }
+ if (!handled_at)
+ code_err(ast, "This 'when' statement doesn't handle non-null pointers");
+ if (!when->else_body)
+ code_err(ast, "This 'when' statement doesn't handle null pointers");
+ return Type(VoidType);
+ }
+
if (subject_t->tag != EnumType)
- code_err(when->subject, "'when' statements are only for enum types, not %T", subject_t);
+ code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
tag_t * const tags = Match(subject_t, EnumType)->tags;