Switch optional detection to use 'when .. is @..' instead of 'if .. :=

..'; also fixed a bug with stack memory in doctests
This commit is contained in:
Bruce Hill 2024-05-24 00:03:46 -04:00
parent bf3cdc3dfa
commit 308946e794
3 changed files with 94 additions and 111 deletions

126
compile.c
View File

@ -173,13 +173,31 @@ CORD compile_statement(env_t *env, ast_t *ast)
{
switch (ast->tag) {
case When: {
// Typecheck to verify exhaustiveness:
type_t *result_t = get_type(env, ast);
(void)result_t;
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
if (subject_t->tag == PointerType) {
ast_t *var = when->clauses->args->ast;
CORD var_code = compile(env, var);
env_t *non_null_scope = fresh_scope(env);
auto ptr = Match(subject_t, PointerType);
type_t *non_optional_t = Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack,
.is_readonly=ptr->is_readonly, .is_optional=false);
set_binding(non_null_scope, Match(var, Var)->name, new(binding_t, .type=non_optional_t, .code=var_code));
return CORD_all(
"{\n",
compile_declaration(env, subject_t, var_code), " = ", compile(env, when->subject), ";\n"
"if (", var_code, ")\n", compile_statement(non_null_scope, when->clauses->body),
"\nelse\n", compile_statement(env, when->else_body), "\n}");
}
auto enum_t = Match(subject_t, EnumType);
CORD code = CORD_all("{ ", compile_type(env, subject_t), " subject = ", compile(env, when->subject), ";\n"
"switch (subject.$tag) {");
type_t *result_t = get_type(env, ast);
(void)result_t;
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
const char *clause_tag_name = Match(clause->tag_name, Var)->name;
code = CORD_all(code, "case ", enum_t->env->file_prefix, enum_t->name, "$tag$", clause_tag_name, ": {\n");
@ -252,7 +270,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
CORD var = CORD_all("$", Match(decl->var, Var)->name);
return CORD_asprintf(
"%r;\n"
"test(({ %s = %r; &%r;}), %r, %r, %r, %ld, %ld);\n",
"test(({ %r = %r; &%r;}), %r, %r, %r, %ld, %ld);\n",
compile_declaration(env, get_type(env, decl->value), var),
var,
compile(env, decl->value),
@ -301,7 +319,8 @@ CORD compile_statement(env_t *env, ast_t *ast)
for (ast_list_t *target = assign->targets; target; target = target->next)
code = CORD_all(code, compile_assignment(env, target->ast, CORD_asprintf("$%ld", i++)));
CORD_appendf(&code, "&$1; }), %r, %r, %r, %ld, %ld);",
CORD_appendf(&code, "(%r[1]){$1}; }), %r, %r, %r, %ld, %ld);",
compile_type(env, get_type(env, assign->targets->ast)),
compile_type_info(env, get_type(env, assign->targets->ast)),
compile(env, WrapAST(test->expr, TextLiteral, .cord=test->output)),
compile(env, WrapAST(test->expr, TextLiteral, .cord=test->expr->file->filename)),
@ -318,8 +337,8 @@ CORD compile_statement(env_t *env, ast_t *ast)
(int64_t)(test->expr->end - test->expr->file->text));
} else {
return CORD_asprintf(
"test(({ %r = %r; &expr; }), %r, %r, %r, %ld, %ld);",
compile_declaration(env, expr_t, "expr"),
"test((%r[1]){%r}, %r, %r, %r, %ld, %ld);",
compile_type(env, expr_t),
compile(env, test->expr),
compile_type_info(env, expr_t),
compile(env, WrapAST(test->expr, TextLiteral, .cord=output)),
@ -737,48 +756,21 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
case If: {
auto if_ = Match(ast, If);
if (if_->condition->tag == Declare) {
auto decl = Match(if_->condition, Declare);
env_t *true_scope = fresh_scope(env);
const char *name = Match(decl->var, Var)->name;
CORD var_code = CORD_cat(env->scope_prefix ? env->scope_prefix : "$", name);
type_t *var_t = get_type(env, decl->value);
if (var_t->tag == PointerType) {
auto ptr = Match(var_t, PointerType);
if (!ptr->is_optional)
code_err(if_->condition, "This pointer will always be non-null, so it should not be used in a conditional.");
var_t = Type(PointerType, .pointed=ptr->pointed, .is_optional=false, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
} else {
code_err(if_->condition, "Only optional pointer types can be used in 'if var := ...' statements (this is a %T)", var_t);
}
set_binding(true_scope, name, new(binding_t, .type=var_t, .code=var_code));
CORD code = CORD_all("{\n",
compile_type(env, var_t), " ", var_code, " = ", compile(env, decl->value), ";\n"
"if (", var_code, ") ", compile_statement(true_scope, if_->body));
if (if_->else_body)
code = CORD_all(code, "\nelse ", compile_statement(env, if_->else_body));
code = CORD_cat(code, "\n}");
return code;
} else {
type_t *cond_t = get_type(env, if_->condition);
if (cond_t->tag == PointerType) {
if (!Match(cond_t, PointerType)->is_optional)
code_err(if_->condition, "This pointer will always be non-null, so it should not be used in a conditional.");
} else if (cond_t->tag != BoolType) {
code_err(if_->condition, "Only boolean values and optional pointers can be used in conditionals (this is a %T)", cond_t);
}
CORD code;
CORD_sprintf(&code, "if (%r) %r", compile(env, if_->condition), compile_statement(env, if_->body));
if (if_->else_body)
code = CORD_all(code, "\nelse ", compile_statement(env, if_->else_body));
return code;
type_t *cond_t = get_type(env, if_->condition);
if (cond_t->tag == PointerType) {
if (!Match(cond_t, PointerType)->is_optional)
code_err(if_->condition, "This pointer will always be non-null, so it should not be used in a conditional.");
} else if (cond_t->tag != BoolType) {
code_err(if_->condition, "Only boolean values and optional pointers can be used in conditionals (this is a %T)", cond_t);
}
CORD code;
CORD_sprintf(&code, "if (%r) %r", compile(env, if_->condition), compile_statement(env, if_->body));
if (if_->else_body)
code = CORD_all(code, "\nelse ", compile_statement(env, if_->else_body));
return code;
}
case Block: {
ast_list_t *stmts = Match(ast, Block)->statements;
if (stmts && !stmts->next)
return compile_statement(env, stmts->ast);
CORD code = "{\n";
env = fresh_scope(env);
for (ast_list_t *stmt = stmts; stmt; stmt = stmt->next)
@ -1019,7 +1011,7 @@ CORD compile(env_t *env, ast_t *ast)
switch (ast->tag) {
case Nil: {
type_t *t = parse_type_ast(env, Match(ast, Nil)->type);
return CORD_all("((", compile_type(env, t), "*)NULL)");
return CORD_all("((", compile_type(env, t), ")NULL)");
}
case Bool: return Match(ast, Bool)->b ? "yes" : "no";
case Var: {
@ -1771,39 +1763,17 @@ CORD compile(env_t *env, ast_t *ast)
if (t->tag == VoidType || t->tag == AbortType)
code_err(ast, "This expression has a %T type, but it needs to have a real value", t);
if (if_->condition->tag == Declare) {
CORD condition = Match(Match(if_->condition, Declare)->var, Var)->name;
CORD decl = compile_statement(env, if_->condition);
env_t *true_scope = fresh_scope(env);
prebind_statement(true_scope, if_->condition);
bind_statement(true_scope, if_->condition);
type_t *true_type = get_type(true_scope, if_->body);
type_t *false_type = get_type(env, if_->else_body);
if (true_type->tag == AbortType) {
return CORD_all("({ ", decl, "\nif (", condition, ") ", compile_statement(true_scope, if_->body), " ",
compile(env, if_->else_body), "; })");
} else if (false_type->tag == AbortType) {
return CORD_all("({ ", decl, "\nif (!(", condition, ")) ", compile_statement(env, if_->else_body), " ",
compile(true_scope, if_->body), "; })");
} else {
return CORD_all("({ ", decl, "\n(", condition, ") ? ",
compile(true_scope, if_->body), " : ",
compile(env, if_->else_body), "; })");
}
} else {
type_t *true_type = get_type(env, if_->body);
type_t *false_type = get_type(env, if_->else_body);
if (true_type->tag == AbortType)
return CORD_all("({ if (", compile(env, if_->condition), ") ", compile_statement(env, if_->body),
"\n", compile(env, if_->else_body), "; })");
else if (false_type->tag == AbortType)
return CORD_all("({ if (!(", compile(env, if_->condition), ")) ", compile_statement(env, if_->else_body),
"\n", compile(env, if_->body), "; })");
else
return CORD_all("((", compile(env, if_->condition), ") ? ",
compile(env, if_->body), " : ", compile(env, if_->else_body), ")");
}
type_t *true_type = get_type(env, if_->body);
type_t *false_type = get_type(env, if_->else_body);
if (true_type->tag == AbortType)
return CORD_all("({ if (", compile(env, if_->condition), ") ", compile_statement(env, if_->body),
"\n", compile(env, if_->else_body), "; })");
else if (false_type->tag == AbortType)
return CORD_all("({ if (!(", compile(env, if_->condition), ")) ", compile_statement(env, if_->else_body),
"\n", compile(env, if_->body), "; })");
else
return CORD_all("((", compile(env, if_->condition), ") ? ",
compile(env, if_->body), " : ", compile(env, if_->else_body), ")");
}
case Reduction: {
auto reduction = Match(ast, Reduction);

40
parse.c
View File

@ -875,9 +875,8 @@ PARSER(parse_if) {
if (!match_word(&pos, "if"))
return NULL;
ast_t *condition = optional(ctx, &pos, parse_declaration);
if (!condition) condition = expect(ctx, start, &pos, parse_expr,
"I expected to find an expression for this 'if'");
ast_t *condition = expect(ctx, start, &pos, parse_expr,
"I expected to find a condition for this 'if'");
ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'if' statement");
@ -913,21 +912,30 @@ PARSER(parse_when) {
while (get_indent(ctx, tmp) == starting_indent && match_word(&tmp, "is")) {
pos = tmp;
spaces(&pos);
ast_t *tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here");
spaces(&pos);
ast_list_t *args = NULL;
if (match(&pos, "(")) {
for (;;) {
ast_t *tag_name;
ast_list_t *args;
if (match(&pos, "@")) {
tag_name = NewAST(ctx->file, pos-1, pos, Var, .name="@");
spaces(&pos);
ast_t *arg = optional(ctx, &pos, parse_var);
args = arg ? new(ast_list_t, .ast=arg) : NULL;
} else {
tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here");
spaces(&pos);
args = NULL;
if (match(&pos, "(")) {
for (;;) {
whitespace(&pos);
ast_t *arg = optional(ctx, &pos, parse_var);
if (!arg) break;
args = new(ast_list_t, .ast=arg, .next=args);
whitespace(&pos);
if (!match(&pos, ",")) break;
}
whitespace(&pos);
ast_t *arg = optional(ctx, &pos, parse_var);
if (!arg) break;
args = new(ast_list_t, .ast=arg, .next=args);
whitespace(&pos);
if (!match(&pos, ",")) break;
expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments");
REVERSE_LIST(args);
}
whitespace(&pos);
expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments");
REVERSE_LIST(args);
}
ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'when' clause");

View File

@ -845,22 +845,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case If: {
auto if_ = Match(ast, If);
type_t *true_t;
if (if_->condition->tag == Declare) {
auto decl = Match(if_->condition, Declare);
env_t *scope = fresh_scope(env);
type_t *var_t = get_type(env, decl->value);
if (var_t->tag == PointerType) {
auto ptr = Match(var_t, PointerType);
var_t = Type(PointerType, .pointed=ptr->pointed, .is_optional=false, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
}
CORD var = Match(decl->var, Var)->name;
set_binding(scope, CORD_to_const_char_star(var), new(binding_t, .type=var_t));
true_t = get_type(scope, if_->body);
} else {
true_t = get_type(env, if_->body);
}
type_t *true_t = get_type(env, if_->body);
if (if_->else_body) {
type_t *false_t = get_type(env, if_->else_body);
type_t *t_either = type_or_type(true_t, false_t);
@ -877,8 +862,28 @@ type_t *get_type(env_t *env, ast_t *ast)
case When: {
auto when = Match(ast, When);
type_t *subject_t = get_type(env, when->subject);
if (subject_t->tag == PointerType) {
if (!Match(subject_t, PointerType)->is_optional)
code_err(when->subject, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t);
bool handled_at = false;
for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
const char *tag_name = Match(clause->tag_name, Var)->name;
if (!streq(tag_name, "@"))
code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name);
if (handled_at)
code_err(clause->tag_name, "This 'when' statement has already handled the case of non-null pointers!");
handled_at = true;
}
if (!handled_at)
code_err(ast, "This 'when' statement doesn't handle non-null pointers");
if (!when->else_body)
code_err(ast, "This 'when' statement doesn't handle null pointers");
return Type(VoidType);
}
if (subject_t->tag != EnumType)
code_err(when->subject, "'when' statements are only for enum types, not %T", subject_t);
code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t);
tag_t * const tags = Match(subject_t, EnumType)->tags;