diff --git a/compile.c b/compile.c index 9b3d775..d52ef97 100644 --- a/compile.c +++ b/compile.c @@ -2160,6 +2160,9 @@ module_code_t compile_file(ast_t *ast) env->file_prefix = heap_strf("%s$", name); Table$str_set(env->imports, name, env); + for (ast_list_t *stmt = Match(ast, Block)->statements; stmt; stmt = stmt->next) { + prebind_statement(env, stmt->ast); + } for (ast_list_t *stmt = Match(ast, Block)->statements; stmt; stmt = stmt->next) { // Hack: make sure global is bound as foo$var: if (stmt->ast->tag == Declare && Match(Match(stmt->ast, Declare)->var, Var)->name[0] != '_') diff --git a/test/structs.tm b/test/structs.tm index ad8e7ad..4fdec6a 100644 --- a/test/structs.tm +++ b/test/structs.tm @@ -4,6 +4,9 @@ struct Mixed(x:Int, text:Text) struct LinkedList(x:Int, next=!@LinkedList) struct Password(text:Text; secret) +struct CorecursiveA(other:@CorecursiveB?) +struct CorecursiveB(other=!@CorecursiveA) + func test_literals(): >> x := Pair(10, 20) = Pair(x=10, y=20) @@ -63,3 +66,5 @@ func main(): >> users_by_password[my_pass] = "User1" + >> CorecursiveA(@CorecursiveB()) + diff --git a/typecheck.c b/typecheck.c index 87f7544..139da0b 100644 --- a/typecheck.c +++ b/typecheck.c @@ -126,6 +126,9 @@ static env_t *load_module(env_t *env, ast_t *use_ast) ast_t *ast = parse_file(f, NULL); if (!ast) errx(1, "Could not compile!"); + for (ast_list_t *stmt = Match(ast, Block)->statements; stmt; stmt = stmt->next) { + prebind_statement(module_env, stmt->ast); + } for (ast_list_t *stmt = Match(ast, Block)->statements; stmt; stmt = stmt->next) { bind_statement(module_env, stmt->ast); } @@ -133,6 +136,55 @@ static env_t *load_module(env_t *env, ast_t *use_ast) return module_env; } +void prebind_statement(env_t *env, ast_t *statement) +{ + switch (statement->tag) { + case DocTest: { + prebind_statement(env, Match(statement, DocTest)->expr); + break; + } + case StructDef: { + auto def = Match(statement, StructDef); + if (get_binding(env, def->name)) + code_err(statement, "A %T called '%s' has already been defined", get_binding(env, def->name)->type, def->name); + + env_t *ns_env = namespace_env(env, def->name); + type_t *type = Type(StructType, .name=def->name, .opaque=true, .env=ns_env); // placeholder + Table$str_set(env->types, def->name, type); + for (ast_list_t *stmt = def->namespace ? Match(def->namespace, Block)->statements : NULL; stmt; stmt = stmt->next) { + prebind_statement(ns_env, stmt->ast); + } + break; + } + case EnumDef: { + auto def = Match(statement, EnumDef); + if (get_binding(env, def->name)) + code_err(statement, "A %T called '%s' has already been defined", get_binding(env, def->name)->type, def->name); + + env_t *ns_env = namespace_env(env, def->name); + type_t *type = Type(EnumType, .name=def->name, .opaque=true, .env=ns_env); // placeholder + Table$str_set(env->types, def->name, type); + for (ast_list_t *stmt = def->namespace ? Match(def->namespace, Block)->statements : NULL; stmt; stmt = stmt->next) { + prebind_statement(ns_env, stmt->ast); + } + break; + } + case LangDef: { + auto def = Match(statement, LangDef); + if (get_binding(env, def->name)) + code_err(statement, "A %T called '%s' has already been defined", get_binding(env, def->name)->type, def->name); + + env_t *ns_env = namespace_env(env, def->name); + type_t *type = Type(TextType, .lang=def->name, .env=ns_env); + Table$str_set(env->types, def->name, type); + for (ast_list_t *stmt = def->namespace ? Match(def->namespace, Block)->statements : NULL; stmt; stmt = stmt->next) + prebind_statement(ns_env, stmt->ast); + break; + } + default: break; + } +} + void bind_statement(env_t *env, ast_t *statement) { switch (statement->tag) { diff --git a/typecheck.h b/typecheck.h index fd9f1f7..ea4fe0e 100644 --- a/typecheck.h +++ b/typecheck.h @@ -13,6 +13,7 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast); type_t *get_type(env_t *env, ast_t *ast); +void prebind_statement(env_t *env, ast_t *statement); void bind_statement(env_t *env, ast_t *statement); type_t *get_math_type(env_t *env, ast_t *ast, type_t *lhs_t, type_t *rhs_t); bool is_discardable(env_t *env, ast_t *ast);