From f4dee58f03774d033c55d890356cd93c3e2462fb Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Thu, 4 Jul 2024 18:27:08 -0400 Subject: [PATCH] Check for functions that don't return when they need to, as well as a fix for 'when' statement typing --- compile.c | 9 ++++++++- typecheck.c | 23 +++++++++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/compile.c b/compile.c index e9648c3..09314c8 100644 --- a/compile.c +++ b/compile.c @@ -597,6 +597,10 @@ CORD compile_statement(env_t *env, ast_t *ast) }; body_scope->fn_ctx = &fn_ctx; + + if (ret_t->tag != VoidType && ret_t->tag != AbortType && get_type(body_scope, fndef->body)->tag != AbortType) + code_err(ast, "This function can reach the end without returning a %T value!", ret_t); + CORD body = compile_statement(body_scope, fndef->body); if (CORD_fetch(body, 0) != '{') body = CORD_asprintf("{\n%r\n}", body); @@ -733,6 +737,9 @@ CORD compile_statement(env_t *env, ast_t *ast) } if (ret) { + if (env->fn_ctx->return_type->tag == VoidType || env->fn_ctx->return_type->tag == AbortType) + code_err(ast, "This function is not supposed to return any values, according to its type signature"); + env = with_enum_scope(env, env->fn_ctx->return_type); type_t *ret_t = get_type(env, ret); CORD value = compile(env, ret); @@ -742,7 +749,7 @@ CORD compile_statement(env_t *env, ast_t *ast) return CORD_all(code, "return ", value, ";"); } else { if (env->fn_ctx->return_type->tag != VoidType) - code_err(ast, "This function expects a return value of type %T", env->fn_ctx->return_type->tag); + code_err(ast, "This function expects you to return a %T value", env->fn_ctx->return_type); return CORD_all(code, "return;"); } } diff --git a/typecheck.c b/typecheck.c index 86d6625..53946dc 100644 --- a/typecheck.c +++ b/typecheck.c @@ -949,6 +949,7 @@ type_t *get_type(env_t *env, ast_t *ast) case When: { auto when = Match(ast, When); type_t *subject_t = get_type(env, when->subject); + type_t *overall_t = NULL; if (subject_t->tag == PointerType) { if (!Match(subject_t, PointerType)->is_optional) code_err(when->subject, "This %T pointer type is not optional, so this 'when' statement is tautological", subject_t); @@ -961,12 +962,25 @@ type_t *get_type(env_t *env, ast_t *ast) if (handled_at) code_err(clause->tag_name, "This 'when' statement has already handled the case of non-null pointers!"); handled_at = true; + + assert(clause->args); + env_t *scope = fresh_scope(env); + auto ptr = Match(subject_t, PointerType); + set_binding(scope, Match(clause->args->ast, Var)->name, + new(binding_t, .type=Type(PointerType, .pointed=ptr->pointed, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly))); + + type_t *clause_type = get_type(scope, clause->body); + type_t *merged = type_or_type(overall_t, clause_type); + if (!merged) + code_err(clause->body, "The type of this branch is %T, which conflicts with the earlier branch type of %T", + clause_type, overall_t); + overall_t = merged; } if (!handled_at) code_err(ast, "This 'when' statement doesn't handle non-null pointers"); if (!when->else_body) code_err(ast, "This 'when' statement doesn't handle null pointers"); - return Type(VoidType); + return overall_t; } if (subject_t->tag != EnumType) @@ -984,7 +998,6 @@ type_t *get_type(env_t *env, ast_t *ast) for (tag_t *tag = tags; tag; tag = tag->next) matches = new(match_t, .name=tag->name, .type=tag->type, .next=matches); - type_t *overall_t = NULL; for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { const char *tag_name = Match(clause->tag_name, Var)->name; type_t *tag_type = NULL; @@ -1051,8 +1064,7 @@ type_t *get_type(env_t *env, ast_t *ast) code_err(when->else_body, "I was expecting this block to have a %T value (based on earlier clauses), but it actually has a %T value.", overall_t, else_t); - // return merged; - return Type(VoidType); + return merged; } else { CORD unhandled = CORD_EMPTY; for (match_t *m = matches; m; m = m->next) { @@ -1061,8 +1073,7 @@ type_t *get_type(env_t *env, ast_t *ast) } if (unhandled) code_err(ast, "This 'when' statement doesn't handle the tag(s): %s", CORD_to_const_char_star(unhandled)); - // return overall_t; - return Type(VoidType); + return overall_t; } }