aboutsummaryrefslogtreecommitdiff
path: root/typecheck.c
diff options
context:
space:
mode:
Diffstat (limited to 'typecheck.c')
-rw-r--r--typecheck.c44
1 files changed, 25 insertions, 19 deletions
diff --git a/typecheck.c b/typecheck.c
index e61873d1..0dc7cdc0 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -1,9 +1,9 @@
// Logic for getting a type from an AST node
-#include <gc.h>
#include <ctype.h>
+#include <gc.h>
+#include <signal.h>
#include <stdarg.h>
#include <stdlib.h>
-#include <signal.h>
#include <string.h>
#include <sys/stat.h>
@@ -678,11 +678,7 @@ type_t *get_type(env_t *env, ast_t *ast)
else if (streq(call->name, "heapify")) return Type(VoidType);
else if (streq(call->name, "heap_push")) return Type(VoidType);
else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type;
- else if (streq(call->name, "pairs")) {
- type_t *ref_t = Type(PointerType, .pointed=Match(self_value_t, ArrayType)->item_type, .is_stack=true);
- arg_t *args = new(arg_t, .name="x", .type=ref_t, .next=new(arg_t, .name="y", .type=ref_t));
- return Type(ClosureType, .fn=Type(FunctionType, .args=args, .ret=Type(BoolType)));
- } else code_err(ast, "There is no '%s' method for arrays", call->name);
+ else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case TableType: {
auto table = Match(self_value_t, TableType);
@@ -740,7 +736,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case Return: {
ast_t *val = Match(ast, Return)->value;
// Support unqualified enum return values:
- if (env->fn_ctx && env->fn_ctx->return_type->tag == EnumType) {
+ if (env->fn_ctx && env->fn_ctx->return_type && env->fn_ctx->return_type->tag == EnumType) {
env = fresh_scope(env);
auto enum_ = Match(env->fn_ctx->return_type, EnumType);
env_t *ns_env = enum_->env;
@@ -903,25 +899,35 @@ type_t *get_type(env_t *env, ast_t *ast)
case ArrayType: value_t = Match(iter_value_t, ArrayType)->item_type; break;
case TableType: value_t = Match(iter_value_t, TableType)->key_type; break;
case FunctionType: case ClosureType: {
+ // Iterator function
auto fn = iter_value_t->tag == ClosureType ?
Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType);
- if (!fn->args || fn->args->next)
- code_err(reduction->iter, "I expected this iterable to have exactly one argument, not %T", iter_value_t);
- type_t *arg_type = get_arg_type(env, fn->args);
- if (arg_type->tag != PointerType)
- code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type);
- auto ptr = Match(arg_type, PointerType);
- if (!ptr->is_stack || ptr->is_optional || ptr->is_readonly)
- code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type);
- value_t = ptr->pointed;
+ if (fn->args)
+ code_err(reduction->iter, "I expected this iterator function to not take any arguments, but it's %T", iter_value_t);
+ if (fn->ret->tag != EnumType)
+ code_err(reduction->iter, "I expected this iterator function to return an enum, but it's %T", iter_value_t);
+ value_t = NULL;
+ for (tag_t *tag = Match(fn->ret, EnumType)->tags; tag; tag = tag->next) {
+ if (streq(tag->name, "Next")) {
+ arg_t *fields = Match(tag->type, StructType)->fields;
+ if (!fields || fields->next)
+ code_err(reduction->iter,
+ "I expected this iterator function to return an enum with a Next() that has exactly one value, not %T",
+ tag->type);
+ value_t = fields->type;
+ break;
+ }
+ }
+ if (!value_t)
+ code_err(reduction->iter, "This iterator function doesn't return an enum with a Next() value");
break;
}
default: code_err(reduction->iter, "I don't know how to do a reduction over %T values", iter_t);
}
env_t *scope = fresh_scope(env);
- set_binding(scope, "$reduction", new(binding_t, .type=value_t, .code="reduction"));
- set_binding(scope, "$iter_value", new(binding_t, .type=value_t, .code="iter_value"));
+ set_binding(scope, "$reduction", new(binding_t, .type=value_t));
+ set_binding(scope, "$iter_value", new(binding_t, .type=value_t));
type_t *t = get_type(scope, reduction->combination);
if (!reduction->fallback)
return t;