aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compile.c6
-rw-r--r--typecheck.c618
-rw-r--r--typecheck.h19
3 files changed, 643 insertions, 0 deletions
diff --git a/compile.c b/compile.c
index 66efd1d6..c3dfb723 100644
--- a/compile.c
+++ b/compile.c
@@ -7,6 +7,7 @@
#include "ast.h"
#include "compile.h"
#include "environment.h"
+#include "typecheck.h"
#include "util.h"
CORD compile_type(env_t *env, type_ast_t *t)
@@ -159,9 +160,11 @@ CORD compile(env_t *env, ast_t *ast)
return compile_statement(env, stmts->ast);
CORD code = "{\n";
+ env = fresh_scope(env);
for (ast_list_t *stmt = stmts; stmt; stmt = stmt->next) {
code = CORD_cat(code, compile_statement(env, stmt->ast));
code = CORD_cat(code, "\n");
+ bind_statement(env, stmt->ast);
}
return CORD_cat(code, "}");
}
@@ -217,6 +220,8 @@ CORD compile(env_t *env, ast_t *ast)
CORD kwargs = CORD_asprintf("#define %r(...) ({ struct {", name);
CORD passed_args = CORD_EMPTY;
CORD_appendf(&env->code->funcs, "%r %r_(", fndef->ret_type ? compile_type(env, fndef->ret_type) : "void", name);
+ env_t *body_scope = fresh_scope(env);
+ body_scope->locals->fallback = env->globals;
for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) {
CORD arg_type = compile_type(env, arg->type);
CORD_appendf(&env->code->funcs, "%r %s", arg_type, arg->name);
@@ -224,6 +229,7 @@ CORD compile(env_t *env, ast_t *ast)
CORD_appendf(&kwargs, "%r %s; ", arg_type, arg->name);
CORD_appendf(&passed_args, "$args.%s", arg->name);
if (arg->next) passed_args = CORD_cat(passed_args, ", ");
+ set_binding(body_scope, arg->name, new(binding_t, .type=parse_type_ast(env, arg->type)));
}
CORD_appendf(&kwargs, "} $args = {__VA_ARGS__}; %r_(%r); })\n", name, passed_args);
CORD_appendf(&env->code->staticdefs, "%r", kwargs);
diff --git a/typecheck.c b/typecheck.c
new file mode 100644
index 00000000..4234e566
--- /dev/null
+++ b/typecheck.c
@@ -0,0 +1,618 @@
+// Logic for getting a type from an AST node
+#include <gc.h>
+#include <ctype.h>
+#include <stdarg.h>
+#include <stdlib.h>
+#include <signal.h>
+#include <string.h>
+#include <sys/stat.h>
+
+#include "ast.h"
+#include "environment.h"
+#include "parse.h"
+#include "typecheck.h"
+#include "types.h"
+#include "util.h"
+
+type_t *parse_type_ast(env_t *env, type_ast_t *ast)
+{
+ switch (ast->tag) {
+ case VarTypeAST: {
+ const char *name = Match(ast, VarTypeAST)->name;
+ type_t *t = Table_str_get(env->types, name);
+ if (t) return t;
+ code_err(ast, "I don't know a type with the name '%s'", name);
+ }
+ case PointerTypeAST: {
+ auto ptr = Match(ast, PointerTypeAST);
+ type_t *pointed_t = parse_type_ast(env, ptr->pointed);
+ if (pointed_t->tag == VoidType)
+ code_err(ast, "Void pointers are not supported. You probably meant 'Memory' instead of 'Void'");
+ return Type(PointerType, .is_optional=ptr->is_optional, .pointed=pointed_t, .is_stack=ptr->is_stack, .is_readonly=ptr->is_readonly);
+ }
+ case ArrayTypeAST: {
+ type_ast_t *item_type = Match(ast, ArrayTypeAST)->item;
+ type_t *item_t = parse_type_ast(env, item_type);
+ if (!item_t) code_err(item_type, "I can't figure out what this type is.");
+ if (has_stack_memory(item_t))
+ code_err(item_type, "Arrays can't have stack references because the array may outlive the stack frame.");
+ return Type(ArrayType, .item_type=item_t);
+ }
+ case TableTypeAST: {
+ type_ast_t *key_type_ast = Match(ast, TableTypeAST)->key;
+ type_t *key_type = parse_type_ast(env, key_type_ast);
+ if (!key_type) code_err(key_type_ast, "I can't figure out what type this is.");
+ if (has_stack_memory(key_type))
+ code_err(key_type_ast, "Tables can't have stack references because the array may outlive the stack frame.");
+ type_ast_t *val_type_ast = Match(ast, TableTypeAST)->value;
+ type_t *val_type = parse_type_ast(env, val_type_ast);
+ if (!val_type) code_err(val_type_ast, "I can't figure out what type this is.");
+ if (has_stack_memory(val_type))
+ code_err(val_type_ast, "Tables can't have stack references because the array may outlive the stack frame.");
+ return Type(TableType, .key_type=key_type, .value_type=val_type);
+ }
+ case FunctionTypeAST: {
+ auto fn = Match(ast, FunctionTypeAST);
+ type_t *ret_t = parse_type_ast(env, fn->ret);
+ if (has_stack_memory(ret_t))
+ code_err(fn->ret, "Functions are not allowed to return stack references, because the reference may no longer exist on the stack.");
+ arg_t *type_args = NULL;
+ for (arg_ast_t *arg = fn->args; arg; arg = arg->next) {
+ type_args = new(arg_t, .name=arg->name, .next=type_args);
+ if (arg->type) {
+ type_args->type = parse_type_ast(env, arg->type);
+ } else {
+ type_args->default_val = arg->default_val;
+ type_args->type = get_type(env, arg->default_val);
+ }
+ }
+ REVERSE_LIST(type_args);
+ return Type(FunctionType, .args=type_args, .ret=ret_t);
+ }
+ case UnknownTypeAST: code_err(ast, "I don't know how to get this type");
+ }
+ code_err(ast, "I don't know how to get this type");
+}
+
+type_t *get_iter_type(env_t *env, ast_t *iter)
+{
+ type_t *iter_t = get_type(env, iter);
+ for (;;) {
+ if (iter_t->tag == PointerType) iter_t = Match(iter_t, PointerType)->pointed;
+ // else if (iter_t->tag == VariantType) iter_t = Match(iter_t, VariantType)->variant_of;
+ else break;
+ }
+ type_t *value_t = value_type(iter_t);
+ switch (value_t->tag) {
+ case ArrayType: return Match(value_t, ArrayType)->item_type;
+ case TableType: return table_entry_type(value_t);
+ default:
+ code_err(iter, "I don't know how to iterate over %T values like this", iter_t);
+ break;
+ }
+}
+
+type_t *get_math_type(env_t *env, ast_t *ast, type_t *lhs_t, type_t *rhs_t)
+{
+ (void)env;
+ if (lhs_t->tag == IntType && rhs_t->tag == IntType)
+ return Match(lhs_t, IntType)->bits >= Match(rhs_t, IntType)->bits ? lhs_t : rhs_t;
+ else if (lhs_t->tag == NumType && rhs_t->tag == NumType)
+ return Match(lhs_t, NumType)->bits >= Match(rhs_t, NumType)->bits ? lhs_t : rhs_t;
+ else if (lhs_t->tag == NumType && rhs_t->tag == IntType)
+ return lhs_t;
+ else if (rhs_t->tag == NumType && lhs_t->tag == IntType)
+ return rhs_t;
+
+ code_err(ast, "Math operations between %T and %T are not supported", lhs_t, rhs_t);
+}
+
+void bind_statement(env_t *env, ast_t *statement)
+{
+ switch (statement->tag) {
+ case Declare: {
+ auto decl = Match(statement, Declare);
+ type_t *type = get_type(env, decl->value);
+ set_binding(env, Match(decl->var, Var)->name, new(binding_t, .type=type));
+ break;
+ }
+ default: break;
+ }
+}
+
+static type_t *get_function_def_type(env_t *env, ast_t *ast)
+{
+ auto fn = Match(ast, FunctionDef);
+ arg_t *args = NULL;
+ env_t *scope = fresh_scope(env);
+ for (arg_ast_t *arg = fn->args; arg; arg = arg->next) {
+ type_t *t = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->default_val);
+ args = new(arg_t, .name=arg->name, .type=t, .next=args);
+ set_binding(scope, arg->name, new(binding_t, .type=t));
+ }
+ REVERSE_LIST(args);
+
+ type_t *ret = parse_type_ast(scope, fn->ret_type);
+ if (has_stack_memory(ret))
+ code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame.");
+ return Type(FunctionType, .args=args, .ret=ret);
+}
+
+type_t *get_type(env_t *env, ast_t *ast)
+{
+ if (!ast) return NULL;
+ switch (ast->tag) {
+ case Nil: {
+ type_t *pointed = parse_type_ast(env, Match(ast, Nil)->type);
+ return Type(PointerType, .is_optional=true, .pointed=pointed);
+ }
+ case Bool: {
+ return Type(BoolType);
+ }
+ case Int: {
+ auto i = Match(ast, Int);
+ switch (i->precision) {
+ case INT_64BIT: return Type(IntType, .bits=64);
+ case INT_32BIT: return Type(IntType, .bits=32);
+ case INT_16BIT: return Type(IntType, .bits=16);
+ case INT_8BIT: return Type(IntType, .bits=8);
+ default: code_err(ast, "Unsupported precision");
+ }
+ }
+ case Num: {
+ auto n = Match(ast, Num);
+ switch (n->precision) {
+ case NUM_64BIT: return Type(NumType, .bits=64);
+ case NUM_32BIT: return Type(NumType, .bits=32);
+ default: code_err(ast, "Unsupported precision");
+ }
+ }
+ case HeapAllocate: {
+ type_t *pointed = get_type(env, Match(ast, HeapAllocate)->value);
+ if (has_stack_memory(pointed))
+ code_err(ast, "Stack references cannot be moved to the heap because they may outlive the stack frame they were created in.");
+ return Type(PointerType, .is_optional=false, .pointed=pointed);
+ }
+ case StackReference: {
+ ast_t *value = Match(ast, StackReference)->value;
+ type_t *pointed_t = get_type(env, Match(ast, StackReference)->value);
+ bool is_stack = true;
+ // References to heap members/indexes are heap pointers, e.g. v := @Vec{1,2}; &v.x
+ switch (value->tag) {
+ case FieldAccess: {
+ type_t *fielded_t = get_type(env, Match(value, FieldAccess)->fielded);
+ is_stack = fielded_t->tag == PointerType ? Match(fielded_t, PointerType)->is_stack : true;
+ break;
+ }
+ case Index: {
+ type_t *indexed_t = get_type(env, Match(value, Index)->indexed);
+ is_stack = indexed_t->tag == PointerType ? Match(indexed_t, PointerType)->is_stack : true;
+ break;
+ }
+ default: break;
+ }
+ return Type(PointerType, .pointed=pointed_t, .is_stack=is_stack);
+ }
+ case StringJoin: case StringLiteral: {
+ return Type(StringType);
+ }
+ case Var: {
+ auto var = Match(ast, Var);
+ binding_t *b = get_binding(env, var->name);
+ if (b) return b->type;
+ // const char *suggestion = spellcheck(env->bindings, name);
+ // if (suggestion)
+ // code_err(ast, "I don't know what this variable is referring to. Did you mean '%s'?", suggestion);
+ // else
+ // code_err(ast, "I don't know what this variable is referring to.");
+ code_err(ast, "I don't know what \"%s\" refers to", var->name);
+ }
+ case Array: {
+ auto array = Match(ast, Array);
+ type_t *item_type = NULL;
+ if (array->type) {
+ item_type = parse_type_ast(env, array->type);
+ } else if (array->items) {
+ for (ast_list_t *item = array->items; item; item = item->next) {
+ type_t *t2 = get_type(env, item->ast);
+ type_t *merged = item_type ? type_or_type(item_type, t2) : t2;
+ if (!merged)
+ code_err(item->ast,
+ "This array item has type %s, which is different from earlier array items which have type %s",
+ type_to_string(t2), type_to_string(item_type));
+ item_type = merged;
+ }
+ } else {
+ code_err(ast, "I can't figure out what type this array has because it has no members or explicit type");
+ }
+ if (has_stack_memory(item_type))
+ code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in.");
+ return Type(ArrayType, .item_type=item_type);
+ }
+ case Table: {
+ auto table = Match(ast, Table);
+ type_t *key_type = NULL, *value_type = NULL;
+ if (table->key_type && table->value_type) {
+ key_type = parse_type_ast(env, table->key_type);
+ value_type = parse_type_ast(env, table->value_type);
+ } else {
+ if (table->default_value)
+ value_type = get_type(env, table->default_value);
+ for (ast_list_t *entry = table->entries; entry; entry = entry->next) {
+ auto table_entry = Match(entry->ast, TableEntry);
+ type_t *key_t = get_type(env, table_entry->key);
+ type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t;
+ if (!key_merged)
+ code_err(table_entry->key,
+ "This table entry has type %s, which is different from earlier table entries which have type %s",
+ type_to_string(key_t), type_to_string(key_type));
+ key_type = key_merged;
+
+ type_t *value_t = get_type(env, table_entry->value);
+ type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t;
+ if (!val_merged)
+ code_err(table_entry->value,
+ "This table entry has type %s, which is different from earlier table entries which have type %s",
+ type_to_string(value_t), type_to_string(value_type));
+ value_type = val_merged;
+ }
+ }
+ if (has_stack_memory(key_type) || has_stack_memory(value_type))
+ code_err(ast, "Tables cannot hold stack references because the table may outlive the reference's stack frame.");
+ return Type(TableType, .key_type=key_type, .value_type=value_type);
+ }
+ case TableEntry: {
+ code_err(ast, "This should not be typechecked directly");
+ }
+ case FieldAccess: {
+ auto access = Match(ast, FieldAccess);
+ type_t *fielded_t = get_type(env, access->fielded);
+ type_t *field_t = get_field_type(fielded_t, access->field);
+ if (!field_t)
+ code_err(ast, "I can't find anything called %s on the type %T", access->field, fielded_t);
+ return field_t;
+ }
+ case Index: {
+ auto indexing = Match(ast, Index);
+ type_t *indexed_t = get_type(env, indexing->indexed);
+ if (indexed_t->tag == PointerType && !indexing->index) {
+ auto ptr = Match(indexed_t, PointerType);
+ if (ptr->is_optional)
+ code_err(ast, "You're attempting to dereference a pointer whose type indicates it could be nil");
+ return ptr->pointed;
+ }
+
+ type_t *value_t = value_type(indexed_t);
+ switch (value_t->tag) {
+ case ArrayType: {
+ if (!indexing->index) return indexed_t;
+ type_t *index_t = get_type(env, indexing->index);
+ switch (index_t->tag) {
+ case IntType:
+ return Match(value_t, ArrayType)->item_type;
+ default: code_err(indexing->index, "I only know how to index lists using integers, not %T", index_t);
+ }
+ }
+ case TableType: {
+ return Match(value_t, TableType)->value_type;
+ }
+ // TODO: support ranges like (99..123)[5]
+ // TODO: support slicing arrays like ([1,2,3,4])[2..10]
+ default: {
+ code_err(ast, "I don't know how to index %T values", indexed_t);
+ }
+ }
+ }
+ case KeywordArg: {
+ return get_type(env, Match(ast, KeywordArg)->arg);
+ }
+ case FunctionCall: {
+ auto call = Match(ast, FunctionCall);
+ if (call->extern_return_type)
+ return parse_type_ast(env, call->extern_return_type);
+ type_t *fn_type_t = get_type(env, call->fn);
+ auto fn_type = Match(fn_type_t, FunctionType);
+ return fn_type->ret;
+ }
+ case Block: {
+ auto block = Match(ast, Block);
+ if (LENGTH(block->statements) == 0)
+ return Type(VoidType);
+ ast_list_t *last = block->statements;
+ while (last->next)
+ last = last->next;
+
+ // Early out if the type is knowable without any context from the block:
+ switch (last->ast->tag) {
+ case UpdateAssign: case Assign: case Declare: case FunctionDef: case StructDef: case EnumDef:
+ return Type(VoidType);
+ default: break;
+ }
+
+ env_t *block_env = fresh_scope(env);
+ for (ast_list_t *stmt = block->statements; stmt; stmt = stmt->next) {
+ bind_statement(block_env, stmt->ast);
+ }
+ return get_type(block_env, last->ast);
+ }
+ case Extern: {
+ type_t *t = parse_type_ast(env, Match(ast, Extern)->type);
+ return Match(ast, Extern)->address ? Type(PointerType, .pointed=t, .is_optional=false) : t;
+ }
+ case Declare: case Assign: case DocTest: case LinkerDirective: {
+ return Type(VoidType);
+ }
+ case Use: {
+ const char *path = Match(ast, Use)->path;
+ return Type(PointerType, .pointed=get_file_type(env, path));
+ }
+ case Return: case Stop: case Skip: {
+ return Type(AbortType);
+ }
+ case Pass: return Type(VoidType);
+ case Negative: {
+ type_t *t = get_type(env, Match(ast, Negative)->value);
+ if (t->tag == IntType || t->tag == NumType)
+ return t;
+ code_err(ast, "I only know how to get negatives of numeric types, not %T", t);
+ }
+ case Not: {
+ type_t *t = get_type(env, Match(ast, Not)->value);
+ if (t->tag == IntType || t->tag == NumType || t->tag == BoolType)
+ return t;
+ if (t->tag == PointerType && Match(t, PointerType)->is_optional)
+ return Type(BoolType);
+ code_err(ast, "I only know how to get 'not' of boolean, numeric, and optional pointer types, not %T", t);
+ }
+ case BinaryOp: {
+ auto binop = Match(ast, BinaryOp);
+ type_t *lhs_t = get_type(env, binop->lhs),
+ *rhs_t = get_type(env, binop->rhs);
+
+ switch (binop->op) {
+ case BINOP_AND: {
+ if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
+ return lhs_t;
+ } else if (lhs_t->tag == BoolType && rhs_t->tag == AbortType) {
+ return lhs_t;
+ } else if (rhs_t->tag == AbortType) {
+ return lhs_t;
+ } else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) {
+ auto lhs_ptr = Match(lhs_t, PointerType);
+ auto rhs_ptr = Match(rhs_t, PointerType);
+ if (type_eq(lhs_ptr->pointed, rhs_ptr->pointed))
+ return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=lhs_ptr->is_optional || rhs_ptr->is_optional,
+ .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly);
+ } else if (lhs_t->tag == IntType && rhs_t->tag == IntType) {
+ return get_math_type(env, ast, lhs_t, rhs_t);
+ }
+ code_err(ast, "I can't figure out the type of this `and` expression because the left side is a %T, but the right side is a %T",
+ lhs_t, rhs_t);
+ }
+ case BINOP_OR: {
+ if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
+ return lhs_t;
+ } else if (lhs_t->tag == BoolType && rhs_t->tag == AbortType) {
+ return lhs_t;
+ } else if (lhs_t->tag == IntType && rhs_t->tag == IntType) {
+ return get_math_type(env, ast, lhs_t, rhs_t);
+ } else if (lhs_t->tag == PointerType) {
+ auto lhs_ptr = Match(lhs_t, PointerType);
+ if (rhs_t->tag == AbortType) {
+ return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=false, .is_readonly=lhs_ptr->is_readonly);
+ } else if (rhs_t->tag == PointerType) {
+ auto rhs_ptr = Match(rhs_t, PointerType);
+ if (type_eq(rhs_ptr->pointed, lhs_ptr->pointed))
+ return Type(PointerType, .pointed=lhs_ptr->pointed, .is_optional=lhs_ptr->is_optional && rhs_ptr->is_optional,
+ .is_readonly=lhs_ptr->is_readonly || rhs_ptr->is_readonly);
+ }
+ }
+ code_err(ast, "I can't figure out the type of this `or` expression because the left side is a %T, but the right side is a %T",
+ lhs_t, rhs_t);
+ }
+ case BINOP_XOR: {
+ if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) {
+ return lhs_t;
+ } else if (lhs_t->tag == IntType && rhs_t->tag == IntType) {
+ return get_math_type(env, ast, lhs_t, rhs_t);
+ }
+
+ code_err(ast, "I can't figure out the type of this `xor` expression because the left side is a %T, but the right side is a %T",
+ lhs_t, rhs_t);
+ }
+ case BINOP_CONCAT: {
+ if (!type_eq(lhs_t, rhs_t))
+ code_err(ast, "The type on the left side of this concatenation doesn't match the right side: %T vs. %T",
+ lhs_t, rhs_t);
+ if (lhs_t->tag == ArrayType && lhs_t->tag == StringType)
+ return lhs_t;
+
+ code_err(ast, "Only array/string value types support concatenation, not %T", lhs_t);
+ }
+ case BINOP_EQ: case BINOP_NE: case BINOP_LT: case BINOP_LE: case BINOP_GT: case BINOP_GE: {
+ if (!can_promote(lhs_t, rhs_t) && !can_promote(rhs_t, lhs_t))
+ code_err(ast, "I can't compare these two different types: %T vs %T", lhs_t, rhs_t);
+ return Type(BoolType);
+ }
+ default: {
+ return get_math_type(env, ast, lhs_t, rhs_t);
+ }
+ }
+ }
+
+ case Reduction: {
+ auto reduction = Match(ast, Reduction);
+ type_t *iter_t = get_type(env, reduction->iter);
+ type_t *value_t = iteration_value_type(iter_t);
+ env_t *scope = fresh_scope(env);
+ set_binding(scope, "$lhs", new(binding_t, .type=value_t));
+ set_binding(scope, "$rhs", new(binding_t, .type=value_t));
+ type_t *t = get_type(scope, reduction->combination);
+ type_t *fallback_t = get_type(env, reduction->fallback);
+ if (fallback_t->tag == AbortType)
+ return t;
+ else if (can_promote(fallback_t, t))
+ return t;
+ else if (can_promote(t, fallback_t))
+ return fallback_t;
+ else
+ return NULL;
+ }
+
+ case UpdateAssign:
+ return Type(VoidType);
+
+ case Min: case Max: {
+ // Unsafe! These types *should* have the same fields and this saves a lot of duplicate code:
+ ast_t *lhs = ast->__data.Min.lhs, *rhs = ast->__data.Min.rhs;
+ // Okay safe again
+
+ type_t *lhs_t = get_type(env, lhs), *rhs_t = get_type(env, rhs);
+ type_t *t = type_or_type(lhs_t, rhs_t);
+ if (!t)
+ code_err(ast, "The two sides of this operation are not compatible: %T vs %T", lhs_t, rhs_t);
+ return t;
+ }
+
+ case Lambda: {
+ auto lambda = Match(ast, Lambda);
+ arg_t *args = NULL;
+ env_t *scope = fresh_scope(env);
+ for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
+ type_t *t = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->default_val);
+ args = new(arg_t, .name=arg->name, .type=t, .next=args);
+ set_binding(scope, arg->name, new(binding_t, .type=t));
+ }
+ REVERSE_LIST(args);
+
+ type_t *ret = get_type(scope, lambda->body);
+ if (has_stack_memory(ret))
+ code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame.");
+ return Type(ClosureType, Type(FunctionType, .args=args, .ret=ret));
+ }
+
+ case FunctionDef: case StructDef: case EnumDef: {
+ return Type(VoidType);
+ }
+
+ case If: {
+ auto if_ = Match(ast, If);
+ type_t *true_t = get_type(env, if_->body);
+ if (if_->else_body) {
+ type_t *false_t = get_type(env, if_->else_body);
+ type_t *t_either = type_or_type(true_t, false_t);
+ if (!t_either)
+ code_err(if_->else_body,
+ "I was expecting this block to have a %s value (based on earlier clauses), but it actually has a %s value.",
+ type_to_string(true_t), type_to_string(false_t));
+ return t_either;
+ } else {
+ return Type(VoidType);
+ }
+ }
+
+ case While: case For: return Type(VoidType);
+
+ case Unknown: code_err(ast, "I can't figure out the type of: %W", ast);
+ }
+ code_err(ast, "I can't figure out the type of: %s", ast_to_str(ast));
+}
+
+bool is_discardable(env_t *env, ast_t *ast)
+{
+ switch (ast->tag) {
+ case UpdateAssign: case Assign: case Declare: case FunctionDef: case StructDef: case EnumDef: case Use:
+ return true;
+ default: break;
+ }
+ type_t *t = get_type(env, ast);
+ return (t->tag == VoidType || t->tag == AbortType);
+}
+
+
+type_t *get_namespace_type(env_t *env, ast_t *namespace_ast, type_t *type)
+{
+ arg_t *ns_fields = NULL;
+ if (type) {
+ ns_fields = new(arg_t, .name="type", .type=Type(TypeInfoType), .next=ns_fields);
+ if (type->tag == EnumType) {
+ // Add enum constructors:
+ auto enum_ = Match(type, EnumType);
+ for (tag_t *tag = enum_->tags; tag; tag = tag->next) {
+ type_t *constructor_t = Type(FunctionType, .args=Match(tag->type, StructType)->fields,
+ .ret=type);
+ ns_fields = new(arg_t, .name=tag->name, .type=constructor_t, .next=ns_fields);
+ }
+ }
+ }
+
+ for (ast_list_t *stmts = Match(namespace_ast, Block)->statements; stmts; stmts = stmts->next) {
+ ast_t *stmt = stmts->ast;
+ doctest_inner:
+ switch (stmt->tag) {
+ case Declare: {
+ auto decl = Match(stmt, Declare);
+ const char *name = Match(decl->var, Var)->name;
+ type_t *t = get_type(env, decl->value);
+ ns_fields = new(arg_t, .name=name, .type=t, .next=ns_fields);
+ break;
+ }
+ case FunctionDef: {
+ type_t *t = get_function_def_type(env, stmt);
+ const char *name = Match(Match(stmt, FunctionDef)->name, Var)->name;
+ ns_fields = new(arg_t, .name=name, .type=t, .next=ns_fields);
+ break;
+ }
+ case DocTest: {
+ stmt = Match(stmt, DocTest)->expr;
+ goto doctest_inner;
+ }
+ default: break;
+ }
+ }
+ return Type(StructType, .fields=ns_fields);
+}
+
+// typedef struct {
+// file_t *file;
+// env_t *env;
+// ast_t *ast;
+// } parsed_file_info_t;
+
+// static parsed_file_info_t *get_file_info(env_t *env, const char *path)
+// {
+// static table_t cache = {0};
+
+// struct stat file_stat;
+// const char *sss_path = strlen(path) > 4 && streq(path + strlen(path) - 4, ".sss") ? path : heap_strf("%s.sss", path);
+// if (stat(sss_path, &file_stat) == -1)
+// compiler_err(NULL, NULL, NULL, "I can't find the file %s", sss_path);
+
+// parsed_file_info_t *file_info = Table_str_get(&cache, path);
+// if (file_info) return file_info;
+
+// file_t *f = load_file(sss_path);
+// file_info = new(parsed_file_info_t, .file=f);
+// Table_str_set(&cache, path, file_info);
+// file_info->env = new(env_t);
+// *file_info->env = *env;
+
+// file_info->ast = parse_file(f, NULL);
+// table_t *type_bindings = new(table_t, .fallback=&env->global->types);
+// bind_types(env, type_bindings, file_info->ast);
+// populate_types(env, type_bindings, file_info->ast);
+// bind_variables(env, new(table_t, .fallback=&env->global->bindings), file_info->ast);
+// // type_t *ns_t = get_namespace_type(env, ast, NULL);
+// return file_info;
+// }
+
+type_t *get_file_type(env_t *env, const char *path)
+{
+ // auto info = get_file_info(env, path);
+ file_t *f = load_file(path);
+ ast_t *ast = parse_file(f, NULL);
+ if (!ast) compiler_err(NULL, NULL, NULL, "Couldn't parse file: %s", path);
+ return get_namespace_type(env, ast, NULL);
+}
+
+// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0
diff --git a/typecheck.h b/typecheck.h
new file mode 100644
index 00000000..39d5b01c
--- /dev/null
+++ b/typecheck.h
@@ -0,0 +1,19 @@
+#pragma once
+#include <gc.h>
+#include <stdarg.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "ast.h"
+#include "types.h"
+
+type_t *parse_type_ast(env_t *env, type_ast_t *ast);
+type_t *get_type(env_t *env, ast_t *ast);
+void bind_statement(env_t *env, ast_t *statement);
+type_t *get_math_type(env_t *env, ast_t *ast, type_t *lhs_t, type_t *rhs_t);
+bool is_discardable(env_t *env, ast_t *ast);
+type_t *get_namespace_type(env_t *env, ast_t *namespace_ast, type_t *type);
+type_t *get_file_type(env_t *env, const char *path);
+type_t *get_iter_type(env_t *env, ast_t *iter);
+
+// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0