From 308946e794f05da9f7010797f5911bcf4e131c3e Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Fri, 24 May 2024 00:03:46 -0400 Subject: Switch optional detection to use 'when .. is @..' instead of 'if .. := ..'; also fixed a bug with stack memory in doctests --- compile.c | 126 +++++++++++++++++++++++------------------------------------- parse.c | 40 +++++++++++-------- typecheck.c | 39 +++++++++++-------- 3 files changed, 94 insertions(+), 111 deletions(-) diff --git a/compile.c b/compile.c index f3a13309..8a42527b 100644 --- a/compile.c +++ b/compile.c @@ -173,13 +173,31 @@ CORD compile_statement(env_t *env, ast_t *ast) { switch (ast->tag) { case When: { + // Typecheck to verify exhaustiveness: + type_t *result_t = get_type(env, ast); + (void)result_t; + auto when = Match(ast, When); type_t *subject_t = get_type(env, when->subject); + + if (subject_t->tag == PointerType) { + ast_t *var = when->clauses->args->ast; + CORD var_code = compile(env, var); + env_t *non_null_scope = fresh_scope(env); + auto ptr = Match(subject_t, PointerType); + type_t *non_optional_t = Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, + .is_readonly=ptr->is_readonly, .is_optional=false); + set_binding(non_null_scope, Match(var, Var)->name, new(binding_t, .type=non_optional_t, .code=var_code)); + return CORD_all( + "{\n", + compile_declaration(env, subject_t, var_code), " = ", compile(env, when->subject), ";\n" + "if (", var_code, ")\n", compile_statement(non_null_scope, when->clauses->body), + "\nelse\n", compile_statement(env, when->else_body), "\n}"); + } + auto enum_t = Match(subject_t, EnumType); CORD code = CORD_all("{ ", compile_type(env, subject_t), " subject = ", compile(env, when->subject), ";\n" "switch (subject.$tag) {"); - type_t *result_t = get_type(env, ast); - (void)result_t; for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { const char *clause_tag_name = Match(clause->tag_name, Var)->name; code = CORD_all(code, "case ", enum_t->env->file_prefix, enum_t->name, "$tag$", clause_tag_name, ": {\n"); @@ -252,7 +270,7 @@ CORD compile_statement(env_t *env, ast_t *ast) CORD var = CORD_all("$", Match(decl->var, Var)->name); return CORD_asprintf( "%r;\n" - "test(({ %s = %r; &%r;}), %r, %r, %r, %ld, %ld);\n", + "test(({ %r = %r; &%r;}), %r, %r, %r, %ld, %ld);\n", compile_declaration(env, get_type(env, decl->value), var), var, compile(env, decl->value), @@ -301,7 +319,8 @@ CORD compile_statement(env_t *env, ast_t *ast) for (ast_list_t *target = assign->targets; target; target = target->next) code = CORD_all(code, compile_assignment(env, target->ast, CORD_asprintf("$%ld", i++))); - CORD_appendf(&code, "&$1; }), %r, %r, %r, %ld, %ld);", + CORD_appendf(&code, "(%r[1]){$1}; }), %r, %r, %r, %ld, %ld);", + compile_type(env, get_type(env, assign->targets->ast)), compile_type_info(env, get_type(env, assign->targets->ast)), compile(env, WrapAST(test->expr, TextLiteral, .cord=test->output)), compile(env, WrapAST(test->expr, TextLiteral, .cord=test->expr->file->filename)), @@ -318,8 +337,8 @@ CORD compile_statement(env_t *env, ast_t *ast) (int64_t)(test->expr->end - test->expr->file->text)); } else { return CORD_asprintf( - "test(({ %r = %r; &expr; }), %r, %r, %r, %ld, %ld);", - compile_declaration(env, expr_t, "expr"), + "test((%r[1]){%r}, %r, %r, %r, %ld, %ld);", + compile_type(env, expr_t), compile(env, test->expr), compile_type_info(env, expr_t), compile(env, WrapAST(test->expr, TextLiteral, .cord=output)), @@ -737,48 +756,21 @@ CORD compile_statement(env_t *env, ast_t *ast) } case If: { auto if_ = Match(ast, If); - if (if_->condition->tag == Declare) { - auto decl = Match(if_->condition, Declare); - env_t *true_scope = fresh_scope(env); - const char *name = Match(decl->var, Var)->name; - CORD var_code = CORD_cat(env->scope_prefix ? env->scope_prefix : "$", name); - type_t *var_t = get_type(env, decl->value); - if (var_t->tag == PointerType) { - auto ptr = Match(var_t, PointerType); - if (!ptr->is_optional) - code_err(if_->condition, "This pointer will always be non-null, so it should not be used in a conditional."); - var_t = Type(PointerType, .pointed=ptr->pointed, .is_optional=false, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly); - } else { - code_err(if_->condition, "Only optional pointer types can be used in 'if var := ...' statements (this is a %T)", var_t); - } - set_binding(true_scope, name, new(binding_t, .type=var_t, .code=var_code)); - CORD code = CORD_all("{\n", - compile_type(env, var_t), " ", var_code, " = ", compile(env, decl->value), ";\n" - "if (", var_code, ") ", compile_statement(true_scope, if_->body)); - if (if_->else_body) - code = CORD_all(code, "\nelse ", compile_statement(env, if_->else_body)); - code = CORD_cat(code, "\n}"); - return code; - } else { - type_t *cond_t = get_type(env, if_->condition); - if (cond_t->tag == PointerType) { - if (!Match(cond_t, PointerType)->is_optional) - code_err(if_->condition, "This pointer will always be non-null, so it should not be used in a conditional."); - } else if (cond_t->tag != BoolType) { - code_err(if_->condition, "Only boolean values and optional pointers can be used in conditionals (this is a %T)", cond_t); - } - CORD code; - CORD_sprintf(&code, "if (%r) %r", compile(env, if_->condition), compile_statement(env, if_->body)); - if (if_->else_body) - code = CORD_all(code, "\nelse ", compile_statement(env, if_->else_body)); - return code; + type_t *cond_t = get_type(env, if_->condition); + if (cond_t->tag == PointerType) { + if (!Match(cond_t, PointerType)->is_optional) + code_err(if_->condition, "This pointer will always be non-null, so it should not be used in a conditional."); + } else if (cond_t->tag != BoolType) { + code_err(if_->condition, "Only boolean values and optional pointers can be used in conditionals (this is a %T)", cond_t); } + CORD code; + CORD_sprintf(&code, "if (%r) %r", compile(env, if_->condition), compile_statement(env, if_->body)); + if (if_->else_body) + code = CORD_all(code, "\nelse ", compile_statement(env, if_->else_body)); + return code; } case Block: { ast_list_t *stmts = Match(ast, Block)->statements; - if (stmts && !stmts->next) - return compile_statement(env, stmts->ast); - CORD code = "{\n"; env = fresh_scope(env); for (ast_list_t *stmt = stmts; stmt; stmt = stmt->next) @@ -1019,7 +1011,7 @@ CORD compile(env_t *env, ast_t *ast) switch (ast->tag) { case Nil: { type_t *t = parse_type_ast(env, Match(ast, Nil)->type); - return CORD_all("((", compile_type(env, t), "*)NULL)"); + return CORD_all("((", compile_type(env, t), ")NULL)"); } case Bool: return Match(ast, Bool)->b ? "yes" : "no"; case Var: { @@ -1771,39 +1763,17 @@ CORD compile(env_t *env, ast_t *ast) if (t->tag == VoidType || t->tag == AbortType) code_err(ast, "This expression has a %T type, but it needs to have a real value", t); - if (if_->condition->tag == Declare) { - CORD condition = Match(Match(if_->condition, Declare)->var, Var)->name; - CORD decl = compile_statement(env, if_->condition); - env_t *true_scope = fresh_scope(env); - prebind_statement(true_scope, if_->condition); - bind_statement(true_scope, if_->condition); - type_t *true_type = get_type(true_scope, if_->body); - type_t *false_type = get_type(env, if_->else_body); - if (true_type->tag == AbortType) { - return CORD_all("({ ", decl, "\nif (", condition, ") ", compile_statement(true_scope, if_->body), " ", - compile(env, if_->else_body), "; })"); - } else if (false_type->tag == AbortType) { - return CORD_all("({ ", decl, "\nif (!(", condition, ")) ", compile_statement(env, if_->else_body), " ", - compile(true_scope, if_->body), "; })"); - - } else { - return CORD_all("({ ", decl, "\n(", condition, ") ? ", - compile(true_scope, if_->body), " : ", - compile(env, if_->else_body), "; })"); - } - } else { - type_t *true_type = get_type(env, if_->body); - type_t *false_type = get_type(env, if_->else_body); - if (true_type->tag == AbortType) - return CORD_all("({ if (", compile(env, if_->condition), ") ", compile_statement(env, if_->body), - "\n", compile(env, if_->else_body), "; })"); - else if (false_type->tag == AbortType) - return CORD_all("({ if (!(", compile(env, if_->condition), ")) ", compile_statement(env, if_->else_body), - "\n", compile(env, if_->body), "; })"); - else - return CORD_all("((", compile(env, if_->condition), ") ? ", - compile(env, if_->body), " : ", compile(env, if_->else_body), ")"); - } + type_t *true_type = get_type(env, if_->body); + type_t *false_type = get_type(env, if_->else_body); + if (true_type->tag == AbortType) + return CORD_all("({ if (", compile(env, if_->condition), ") ", compile_statement(env, if_->body), + "\n", compile(env, if_->else_body), "; })"); + else if (false_type->tag == AbortType) + return CORD_all("({ if (!(", compile(env, if_->condition), ")) ", compile_statement(env, if_->else_body), + "\n", compile(env, if_->body), "; })"); + else + return CORD_all("((", compile(env, if_->condition), ") ? ", + compile(env, if_->body), " : ", compile(env, if_->else_body), ")"); } case Reduction: { auto reduction = Match(ast, Reduction); diff --git a/parse.c b/parse.c index fbb82871..701cdc1c 100644 --- a/parse.c +++ b/parse.c @@ -875,9 +875,8 @@ PARSER(parse_if) { if (!match_word(&pos, "if")) return NULL; - ast_t *condition = optional(ctx, &pos, parse_declaration); - if (!condition) condition = expect(ctx, start, &pos, parse_expr, - "I expected to find an expression for this 'if'"); + ast_t *condition = expect(ctx, start, &pos, parse_expr, + "I expected to find a condition for this 'if'"); ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'if' statement"); @@ -913,21 +912,30 @@ PARSER(parse_when) { while (get_indent(ctx, tmp) == starting_indent && match_word(&tmp, "is")) { pos = tmp; spaces(&pos); - ast_t *tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here"); - spaces(&pos); - ast_list_t *args = NULL; - if (match(&pos, "(")) { - for (;;) { - whitespace(&pos); - ast_t *arg = optional(ctx, &pos, parse_var); - if (!arg) break; - args = new(ast_list_t, .ast=arg, .next=args); + ast_t *tag_name; + ast_list_t *args; + if (match(&pos, "@")) { + tag_name = NewAST(ctx->file, pos-1, pos, Var, .name="@"); + spaces(&pos); + ast_t *arg = optional(ctx, &pos, parse_var); + args = arg ? new(ast_list_t, .ast=arg) : NULL; + } else { + tag_name = expect(ctx, start, &pos, parse_var, "I expected a tag name here"); + spaces(&pos); + args = NULL; + if (match(&pos, "(")) { + for (;;) { + whitespace(&pos); + ast_t *arg = optional(ctx, &pos, parse_var); + if (!arg) break; + args = new(ast_list_t, .ast=arg, .next=args); + whitespace(&pos); + if (!match(&pos, ",")) break; + } whitespace(&pos); - if (!match(&pos, ",")) break; + expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments"); + REVERSE_LIST(args); } - whitespace(&pos); - expect_closing(ctx, &pos, ")", "I was expecting a ')' to finish this pattern's arguments"); - REVERSE_LIST(args); } ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a body for this 'when' clause"); diff --git a/typecheck.c b/typecheck.c index 9964c9d9..16841e8f 100644 --- a/typecheck.c +++ b/typecheck.c @@ -845,22 +845,7 @@ type_t *get_type(env_t *env, ast_t *ast) case If: { auto if_ = Match(ast, If); - type_t *true_t; - if (if_->condition->tag == Declare) { - auto decl = Match(if_->condition, Declare); - env_t *scope = fresh_scope(env); - type_t *var_t = get_type(env, decl->value); - if (var_t->tag == PointerType) { - auto ptr = Match(var_t, PointerType); - var_t = Type(PointerType, .pointed=ptr->pointed, .is_optional=false, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly); - } - CORD var = Match(decl->var, Var)->name; - set_binding(scope, CORD_to_const_char_star(var), new(binding_t, .type=var_t)); - true_t = get_type(scope, if_->body); - } else { - true_t = get_type(env, if_->body); - } - + type_t *true_t = get_type(env, if_->body); if (if_->else_body) { type_t *false_t = get_type(env, if_->else_body); type_t *t_either = type_or_type(true_t, false_t); @@ -877,8 +862,28 @@ type_t *get_type(env_t *env, ast_t *ast) case When: { auto when = Match(ast, When); type_t *subject_t = get_type(env, when->subject); + if (subject_t->tag == PointerType) { + if (!Match(subject_t, PointerType)->is_optional) + code_err(when->subject, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t); + + bool handled_at = false; + for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { + const char *tag_name = Match(clause->tag_name, Var)->name; + if (!streq(tag_name, "@")) + code_err(clause->tag_name, "'when' clauses on optional pointers only support @var, not tags like '%s'", tag_name); + if (handled_at) + code_err(clause->tag_name, "This 'when' statement has already handled the case of non-null pointers!"); + handled_at = true; + } + if (!handled_at) + code_err(ast, "This 'when' statement doesn't handle non-null pointers"); + if (!when->else_body) + code_err(ast, "This 'when' statement doesn't handle null pointers"); + return Type(VoidType); + } + if (subject_t->tag != EnumType) - code_err(when->subject, "'when' statements are only for enum types, not %T", subject_t); + code_err(when->subject, "'when' statements are only for enum types and optional pointers, not %T", subject_t); tag_t * const tags = Match(subject_t, EnumType)->tags; -- cgit v1.2.3