aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-05-14 13:30:46 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-05-14 13:30:46 -0400
commit1924f75647389003873c4532449484c65c83e79b (patch)
tree601018eb6a80e57e2dfb6623dd844812dd928394
parent31814db0a6e698e218121b12838f411358bf78a5 (diff)
Support corecursive structs
-rw-r--r--compile.c3
-rw-r--r--test/structs.tm5
-rw-r--r--typecheck.c52
-rw-r--r--typecheck.h1
4 files changed, 61 insertions, 0 deletions
diff --git a/compile.c b/compile.c
index 9b3d7753..d52ef978 100644
--- a/compile.c
+++ b/compile.c
@@ -2161,6 +2161,9 @@ module_code_t compile_file(ast_t *ast)
Table$str_set(env->imports, name, env);
for (ast_list_t *stmt = Match(ast, Block)->statements; stmt; stmt = stmt->next) {
+ prebind_statement(env, stmt->ast);
+ }
+ for (ast_list_t *stmt = Match(ast, Block)->statements; stmt; stmt = stmt->next) {
// Hack: make sure global is bound as foo$var:
if (stmt->ast->tag == Declare && Match(Match(stmt->ast, Declare)->var, Var)->name[0] != '_')
env->scope_prefix = heap_strf("%s$", name);
diff --git a/test/structs.tm b/test/structs.tm
index ad8e7ad6..4fdec6ae 100644
--- a/test/structs.tm
+++ b/test/structs.tm
@@ -4,6 +4,9 @@ struct Mixed(x:Int, text:Text)
struct LinkedList(x:Int, next=!@LinkedList)
struct Password(text:Text; secret)
+struct CorecursiveA(other:@CorecursiveB?)
+struct CorecursiveB(other=!@CorecursiveA)
+
func test_literals():
>> x := Pair(10, 20)
= Pair(x=10, y=20)
@@ -63,3 +66,5 @@ func main():
>> users_by_password[my_pass]
= "User1"
+ >> CorecursiveA(@CorecursiveB())
+
diff --git a/typecheck.c b/typecheck.c
index 87f7544c..139da0b6 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -127,12 +127,64 @@ static env_t *load_module(env_t *env, ast_t *use_ast)
if (!ast) errx(1, "Could not compile!");
for (ast_list_t *stmt = Match(ast, Block)->statements; stmt; stmt = stmt->next) {
+ prebind_statement(module_env, stmt->ast);
+ }
+ for (ast_list_t *stmt = Match(ast, Block)->statements; stmt; stmt = stmt->next) {
bind_statement(module_env, stmt->ast);
}
Table$str_set(env->imports, name, module_env);
return module_env;
}
+void prebind_statement(env_t *env, ast_t *statement)
+{
+ switch (statement->tag) {
+ case DocTest: {
+ prebind_statement(env, Match(statement, DocTest)->expr);
+ break;
+ }
+ case StructDef: {
+ auto def = Match(statement, StructDef);
+ if (get_binding(env, def->name))
+ code_err(statement, "A %T called '%s' has already been defined", get_binding(env, def->name)->type, def->name);
+
+ env_t *ns_env = namespace_env(env, def->name);
+ type_t *type = Type(StructType, .name=def->name, .opaque=true, .env=ns_env); // placeholder
+ Table$str_set(env->types, def->name, type);
+ for (ast_list_t *stmt = def->namespace ? Match(def->namespace, Block)->statements : NULL; stmt; stmt = stmt->next) {
+ prebind_statement(ns_env, stmt->ast);
+ }
+ break;
+ }
+ case EnumDef: {
+ auto def = Match(statement, EnumDef);
+ if (get_binding(env, def->name))
+ code_err(statement, "A %T called '%s' has already been defined", get_binding(env, def->name)->type, def->name);
+
+ env_t *ns_env = namespace_env(env, def->name);
+ type_t *type = Type(EnumType, .name=def->name, .opaque=true, .env=ns_env); // placeholder
+ Table$str_set(env->types, def->name, type);
+ for (ast_list_t *stmt = def->namespace ? Match(def->namespace, Block)->statements : NULL; stmt; stmt = stmt->next) {
+ prebind_statement(ns_env, stmt->ast);
+ }
+ break;
+ }
+ case LangDef: {
+ auto def = Match(statement, LangDef);
+ if (get_binding(env, def->name))
+ code_err(statement, "A %T called '%s' has already been defined", get_binding(env, def->name)->type, def->name);
+
+ env_t *ns_env = namespace_env(env, def->name);
+ type_t *type = Type(TextType, .lang=def->name, .env=ns_env);
+ Table$str_set(env->types, def->name, type);
+ for (ast_list_t *stmt = def->namespace ? Match(def->namespace, Block)->statements : NULL; stmt; stmt = stmt->next)
+ prebind_statement(ns_env, stmt->ast);
+ break;
+ }
+ default: break;
+ }
+}
+
void bind_statement(env_t *env, ast_t *statement)
{
switch (statement->tag) {
diff --git a/typecheck.h b/typecheck.h
index fd9f1f73..ea4fe0e3 100644
--- a/typecheck.h
+++ b/typecheck.h
@@ -13,6 +13,7 @@
type_t *parse_type_ast(env_t *env, type_ast_t *ast);
type_t *get_type(env_t *env, ast_t *ast);
+void prebind_statement(env_t *env, ast_t *statement);
void bind_statement(env_t *env, ast_t *statement);
type_t *get_math_type(env_t *env, ast_t *ast, type_t *lhs_t, type_t *rhs_t);
bool is_discardable(env_t *env, ast_t *ast);