diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2025-09-06 14:46:15 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2025-09-06 14:46:15 -0400 |
| commit | d8a48f64111f542f3afeb5d6e47ff092f9278d9f (patch) | |
| tree | 07c364503025bb2a26edd7b26f1ba9e8d25340f6 /src/typecheck.c | |
| parent | 12345a85d9c7d7a56ddf323247a4bdf347021b73 (diff) | |
| parent | 73246764f88f6f652316ee0c138a990d836698a7 (diff) | |
Merge branch 'main' into optional-list-indexing
Diffstat (limited to 'src/typecheck.c')
| -rw-r--r-- | src/typecheck.c | 102 |
1 files changed, 59 insertions, 43 deletions
diff --git a/src/typecheck.c b/src/typecheck.c index 695f7cbc..d7e87e65 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -336,12 +336,12 @@ void bind_statement(env_t *env, ast_t *statement) { case FunctionDef: { DeclareMatch(def, statement, FunctionDef); const char *name = Match(def->name, Var)->name; - type_t *type = get_function_def_type(env, statement); + type_t *type = get_function_type(env, statement); set_binding(env, name, type, namespace_name(env, env->namespace, Text$from_str(name))); break; } case ConvertDef: { - type_t *type = get_function_def_type(env, statement); + type_t *type = get_function_type(env, statement); type_t *ret_t = Match(type, FunctionType)->ret; const char *name = get_type_name(ret_t); if (!name) @@ -577,10 +577,24 @@ void bind_statement(env_t *env, ast_t *statement) { } } -type_t *get_function_def_type(env_t *env, ast_t *ast) { - arg_ast_t *arg_asts = ast->tag == FunctionDef ? Match(ast, FunctionDef)->args : Match(ast, ConvertDef)->args; - type_ast_t *ret_type = - ast->tag == FunctionDef ? Match(ast, FunctionDef)->ret_type : Match(ast, ConvertDef)->ret_type; +type_t *get_function_type(env_t *env, ast_t *ast) { + arg_ast_t *arg_asts; + type_ast_t *ret_ast; + switch (ast->tag) { + case FunctionDef: + arg_asts = Match(ast, FunctionDef)->args; + ret_ast = Match(ast, FunctionDef)->ret_type; + break; + case ConvertDef: + arg_asts = Match(ast, ConvertDef)->args; + ret_ast = Match(ast, ConvertDef)->ret_type; + break; + case Lambda: + arg_asts = Match(ast, Lambda)->args; + ret_ast = Match(ast, Lambda)->ret_type; + break; + default: code_err(ast, "This was expected to be a function definition of some sort"); + } arg_t *args = NULL; env_t *scope = fresh_scope(env); for (arg_ast_t *arg = arg_asts; arg; arg = arg->next) { @@ -590,10 +604,40 @@ type_t *get_function_def_type(env_t *env, ast_t *ast) { } REVERSE_LIST(args); - type_t *ret = ret_type ? parse_type_ast(scope, ret_type) : Type(VoidType); - if (has_stack_memory(ret)) - code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame."); - return Type(FunctionType, .args = args, .ret = ret); + if (ast->tag == Lambda) { + ast_t *body = Match(ast, Lambda)->body; + + scope->fn = NULL; + type_t *ret_t = get_type(scope, body); + if (ret_t->tag == ReturnType) ret_t = Match(ret_t, ReturnType)->ret; + if (ret_t->tag == AbortType) ret_t = Type(VoidType); + + if (ret_t->tag == OptionalType && !Match(ret_t, OptionalType)->type) + code_err(body, "This function doesn't return a specific optional type"); + + if (ret_ast) { + type_t *declared = parse_type_ast(env, ret_ast); + if (can_promote(ret_t, declared)) ret_t = declared; + else + code_err(ast, "This function was declared to return a value of type ", type_to_str(declared), + ", but actually returns a value of type ", type_to_str(ret_t)); + } + + if (has_stack_memory(ret_t)) + code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame."); + return Type(ClosureType, Type(FunctionType, .args = args, .ret = ret_t)); + } else { + type_t *ret_t = ret_ast ? parse_type_ast(scope, ret_ast) : Type(VoidType); + if (has_stack_memory(ret_t)) + code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame."); + return Type(FunctionType, .args = args, .ret = ret_t); + } +} + +type_t *get_function_return_type(env_t *env, ast_t *ast) { + type_t *fn_t = get_function_type(env, ast); + if (fn_t->tag == ClosureType) fn_t = Match(fn_t, ClosureType)->fn; + return Match(fn_t, FunctionType)->ret; } type_t *get_method_type(env_t *env, ast_t *self, const char *name) { @@ -1085,7 +1129,7 @@ type_t *get_type(env_t *env, ast_t *ast) { } case Return: { ast_t *val = Match(ast, Return)->value; - if (env->fn_ret) env = with_enum_scope(env, env->fn_ret); + if (env->fn) env = with_enum_scope(env, get_function_return_type(env, env->fn)); return Type(ReturnType, .ret = (val ? get_type(env, val) : Type(VoidType))); } case Stop: @@ -1314,7 +1358,7 @@ type_t *get_type(env_t *env, ast_t *ast) { } } } else if ((ast->tag == Divide || ast->tag == Mod || ast->tag == Mod1) && is_numeric_type(rhs_t)) { - binding_t *b = get_namespace_binding(env, binop.lhs, binop_method_name(ast->tag)); + binding_t *b = get_namespace_binding(env, binop.lhs, binop_info[ast->tag].method_name); if (b && b->type->tag == FunctionType) { DeclareMatch(fn, b->type, FunctionType); if (type_eq(fn->ret, lhs_t)) { @@ -1373,7 +1417,8 @@ type_t *get_type(env_t *env, ast_t *ast) { code_err(reduction->iter, "I don't know how to do a reduction over ", type_to_str(iter_t), " values"); if (reduction->key && !(reduction->op == Min || reduction->op == Max)) { env_t *item_scope = fresh_scope(env); - set_binding(item_scope, "$", iterated, EMPTY_TEXT); + const char *op_str = binop_info[reduction->op].operator; + set_binding(item_scope, op_str, iterated, EMPTY_TEXT); iterated = get_type(item_scope, reduction->key); } return iterated->tag == OptionalType ? iterated : Type(OptionalType, .type = iterated); @@ -1393,36 +1438,7 @@ type_t *get_type(env_t *env, ast_t *ast) { return t; } - case Lambda: { - DeclareMatch(lambda, ast, Lambda); - arg_t *args = NULL; - env_t *scope = fresh_scope(env); // For now, just use closed variables in scope normally - for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { - type_t *t = get_arg_ast_type(env, arg); - args = new (arg_t, .name = arg->name, .alias = arg->alias, .type = t, .next = args); - set_binding(scope, arg->name, t, EMPTY_TEXT); - } - REVERSE_LIST(args); - - type_t *ret = get_type(scope, lambda->body); - if (ret->tag == ReturnType) ret = Match(ret, ReturnType)->ret; - if (ret->tag == AbortType) ret = Type(VoidType); - - if (ret->tag == OptionalType && !Match(ret, OptionalType)->type) - code_err(lambda->body, "This function doesn't return a specific optional type"); - - if (lambda->ret_type) { - type_t *declared = parse_type_ast(env, lambda->ret_type); - if (can_promote(ret, declared)) ret = declared; - else - code_err(ast, "This function was declared to return a value of type ", type_to_str(declared), - ", but actually returns a value of type ", type_to_str(ret)); - } - - if (has_stack_memory(ret)) - code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame."); - return Type(ClosureType, Type(FunctionType, .args = args, .ret = ret)); - } + case Lambda: return get_function_type(env, ast); case FunctionDef: case ConvertDef: |
