aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compile.c137
-rw-r--r--environment.c19
-rw-r--r--environment.h8
-rw-r--r--test/lambdas.tm18
-rw-r--r--typecheck.c4
-rw-r--r--types.c21
6 files changed, 159 insertions, 48 deletions
diff --git a/compile.c b/compile.c
index 5d43bd24..7f70d366 100644
--- a/compile.c
+++ b/compile.c
@@ -26,6 +26,37 @@ CORD compile_type_ast(type_ast_t *t)
}
}
+static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed)
+{
+ if (type_eq(actual, needed))
+ return true;
+
+ if (!can_promote(actual, needed))
+ return false;
+
+ if (actual->tag == IntType || actual->tag == NumType)
+ return true;
+
+ // Automatic dereferencing:
+ if (actual->tag == PointerType && !Match(actual, PointerType)->is_optional
+ && can_promote(Match(actual, PointerType)->pointed, needed)) {
+ *code = CORD_all("*(", *code, ")");
+ return promote(env, code, Match(actual, PointerType)->pointed, needed);
+ }
+
+ // Optional promotion:
+ if (actual->tag == PointerType && needed->tag == PointerType)
+ return true;
+
+ if (needed->tag == ClosureType && actual->tag == FunctionType && type_eq(actual, Match(needed, ClosureType)->fn)) {
+ *code = CORD_all("(closure_t){", *code, ", NULL}");
+ return true;
+ }
+
+ return false;
+}
+
+
CORD compile_declaration(type_t *t, const char *name)
{
if (t->tag == FunctionType) {
@@ -175,19 +206,19 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg
{
table_t used_args = {};
CORD code = CORD_EMPTY;
- env_t *default_scope = fresh_scope(env);
- default_scope->locals->fallback = env->globals;
+ env_t *default_scope = global_scope(env);
for (arg_t *spec_arg = spec_args; spec_arg; spec_arg = spec_arg->next) {
// Find keyword:
if (spec_arg->name) {
for (arg_ast_t *call_arg = call_args; call_arg; call_arg = call_arg->next) {
if (call_arg->name && streq(call_arg->name, spec_arg->name)) {
type_t *actual_t = get_type(env, call_arg->value);
- if (!can_promote(actual_t, spec_arg->type))
+ CORD value = compile(env, call_arg->value);
+ if (!promote(env, &value, actual_t, spec_arg->type))
code_err(call_arg->value, "This argument is supposed to be a %T, but this value is a %T", spec_arg->type, actual_t);
Table_str_set(&used_args, call_arg->name, call_arg);
if (code) code = CORD_cat(code, ", ");
- code = CORD_cat(code, compile(env, call_arg->value));
+ code = CORD_cat(code, value);
goto found_it;
}
}
@@ -199,11 +230,12 @@ static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg
const char *pseudoname = heap_strf("%ld", i++);
if (!Table_str_get(used_args, pseudoname)) {
type_t *actual_t = get_type(env, call_arg->value);
- if (!can_promote(actual_t, spec_arg->type))
+ CORD value = compile(env, call_arg->value);
+ if (!promote(env, &value, actual_t, spec_arg->type))
code_err(call_arg->value, "This argument is supposed to be a %T, but this value is a %T", spec_arg->type, actual_t);
Table_str_set(&used_args, pseudoname, call_arg);
if (code) code = CORD_cat(code, ", ");
- code = CORD_cat(code, compile(env, call_arg->value));
+ code = CORD_cat(code, value);
goto found_it;
}
}
@@ -299,9 +331,9 @@ CORD compile(env_t *env, ast_t *ast)
type_t *lhs_t = get_type(env, binop->lhs);
type_t *rhs_t = get_type(env, binop->rhs);
type_t *operand_t;
- if (can_promote(rhs_t, lhs_t))
+ if (promote(env, &rhs, rhs_t, lhs_t))
operand_t = lhs_t;
- else if (can_promote(lhs_t, rhs_t))
+ else if (promote(env, &lhs, lhs_t, rhs_t))
operand_t = rhs_t;
else
code_err(ast, "I can't do operations between %T and %T", lhs_t, rhs_t);
@@ -450,11 +482,11 @@ CORD compile(env_t *env, ast_t *ast)
type_t *lhs_t = get_type(env, update->lhs);
type_t *rhs_t = get_type(env, update->rhs);
type_t *operand_t;
- if (can_promote(rhs_t, lhs_t))
+ if (promote(env, &rhs, rhs_t, lhs_t))
operand_t = lhs_t;
- else if (can_promote(lhs_t, rhs_t))
+ else if (promote(env, &lhs, lhs_t, rhs_t))
operand_t = rhs_t;
- else if (lhs_t->tag == ArrayType && can_promote(rhs_t, Match(lhs_t, ArrayType)->item_type))
+ else if (lhs_t->tag == ArrayType && promote(env, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type))
operand_t = lhs_t;
else
code_err(ast, "I can't do operations between %T and %T", lhs_t, rhs_t);
@@ -497,7 +529,7 @@ CORD compile(env_t *env, ast_t *ast)
if (operand_t->tag == TextType) {
return CORD_asprintf("%r = CORD_cat(%r, %r);", lhs, lhs, rhs);
} else if (operand_t->tag == ArrayType) {
- if (can_promote(rhs_t, Match(lhs_t, ArrayType)->item_type)) {
+ if (promote(env, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type)) {
// arr ++= item
if (update->lhs->tag == Var)
return CORD_all("Array__insert(&", lhs, ", $stack(", rhs, "), 0, ", compile_type_info(env, operand_t), ")");
@@ -691,7 +723,8 @@ CORD compile(env_t *env, ast_t *ast)
case FunctionDef: {
auto fndef = Match(ast, FunctionDef);
CORD name = compile(env, fndef->name);
- CORD signature = CORD_all(fndef->ret_type ? compile_type_ast(fndef->ret_type) : "void", " ", name, "(");
+ type_t *ret_t = fndef->ret_type ? parse_type_ast(env, fndef->ret_type) : Type(VoidType);
+ CORD signature = CORD_all(compile_type(ret_t), " ", name, "(");
for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
signature = CORD_cat(signature, compile_declaration(arg_type, arg->name));
@@ -709,13 +742,19 @@ CORD compile(env_t *env, ast_t *ast)
if (!fndef->is_private)
code = CORD_cat("public ", code);
- env_t *body_scope = fresh_scope(env);
- body_scope->locals->fallback = env->globals;
+ env_t *body_scope = global_scope(env);
for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name));
}
+ fn_context_t fn_ctx = (fn_context_t){
+ .return_type=ret_t,
+ .closure_scope=NULL,
+ .closed_vars=NULL,
+ };
+ body_scope->fn_ctx = &fn_ctx;
+
CORD body = compile(body_scope, fndef->body);
if (CORD_fetch(body, 0) != '{')
body = CORD_asprintf("{\n%r\n}", body);
@@ -728,19 +767,47 @@ CORD compile(env_t *env, ast_t *ast)
CORD name = CORD_asprintf("lambda$%ld", lambda_number++);
env_t *body_scope = fresh_scope(env);
- body_scope->locals->fallback = env->globals;
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name));
}
+
type_t *ret_t = get_type(body_scope, lambda->body);
+ fn_context_t fn_ctx = (fn_context_t){
+ .return_type=ret_t,
+ .closure_scope=body_scope->locals->fallback,
+ .closed_vars=new(table_t),
+ };
+ body_scope->fn_ctx = &fn_ctx;
+ body_scope->locals->fallback = env->globals;
+
CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "(");
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
code = CORD_all(code, compile_type(arg_type), " ", arg->name, ", ");
}
- code = CORD_cat(code, "void *$userdata)");
+
+ for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next)
+ (void)compile_statement(body_scope, stmt->ast);
+
+ CORD userdata;
+ if (Table_length(*fn_ctx.closed_vars) == 0) {
+ code = CORD_cat(code, "void *$userdata)");
+ userdata = "NULL";
+ } else {
+ CORD def = "typedef struct {";
+ userdata = CORD_all("new(", name, "$userdata_t");
+ for (int64_t i = 1; i <= Table_length(*fn_ctx.closed_vars); i++) {
+ struct { const char *name; binding_t *b; } *entry = Table_entry(*fn_ctx.closed_vars, i);
+ def = CORD_all(def, compile_declaration(entry->b->type, entry->name), "; ");
+ userdata = CORD_all(userdata, ", ", entry->b->code);
+ }
+ userdata = CORD_all(userdata, ")");
+ def = CORD_all(def, "} ", name, "$userdata_t;");
+ env->code->typedefs = CORD_cat(env->code->typedefs, def);
+ code = CORD_all(code, name, "$userdata_t *$userdata)");
+ }
CORD body = CORD_EMPTY;
for (ast_list_t *stmt = Match(lambda->body, Block)->statements; stmt; stmt = stmt->next) {
@@ -750,7 +817,7 @@ CORD compile(env_t *env, ast_t *ast)
body = CORD_all(body, compile_statement(body_scope, FakeAST(Return, stmt->ast)), "\n");
}
env->code->funcs = CORD_all(env->code->funcs, code, " {\n", body, "\n}");
- return CORD_all("(closure_t){", name, ", NULL}");
+ return CORD_all("(closure_t){", name, ", ", userdata, "}");
}
case MethodCall: {
auto call = Match(ast, MethodCall);
@@ -848,13 +915,22 @@ CORD compile(env_t *env, ast_t *ast)
} else if (fn_t->tag == ClosureType) {
fn_t = Match(fn_t, ClosureType)->fn;
arg_t *type_args = Match(fn_t, FunctionType)->args;
- CORD fn_type_code = compile_type(fn_t);
+
+ arg_t *closure_fn_args = NULL;
+ for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next)
+ closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args);
+ closure_fn_args = new(arg_t, .name="$userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args);
+ REVERSE_LIST(closure_fn_args);
+ CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret));
+
CORD closure = compile(env, call->fn);
+ CORD arg_code = compile_arguments(env, ast, type_args, call->args);
+ if (arg_code) arg_code = CORD_cat(arg_code, ", ");
if (call->fn->tag == Var) {
- return CORD_all("((", fn_type_code, ")", closure, ".fn)(", compile_arguments(env, ast, type_args, call->args), ")");
+ return CORD_all("((", fn_type_code, ")", closure, ".fn)(", arg_code, closure, ".userdata)");
} else {
return CORD_all("({ closure_t $closure = ", closure, "; ((", fn_type_code, ")$closure.fn)(",
- compile_arguments(env, ast, type_args, call->args), "); })");
+ arg_code, "$closure.userdata); })");
}
} else {
code_err(call->fn, "This is not a function, it's a %T", fn_t);
@@ -1062,8 +1138,21 @@ CORD compile(env_t *env, ast_t *ast)
}
case Pass: return ";";
case Return: {
+ if (!env->fn_ctx) code_err(ast, "This return statement is not inside any function");
auto ret = Match(ast, Return)->value;
- return ret ? CORD_asprintf("return %r;", compile(env, ret)) : "return;";
+ assert(env->fn_ctx->return_type);
+ if (ret) {
+ type_t *ret_t = get_type(env, ret);
+ CORD value = compile(env, ret);
+ if (!promote(env, &value, ret_t, env->fn_ctx->return_type))
+ code_err(ast, "This function expects a return value of type %T, but this return has type %T",
+ env->fn_ctx->return_type, ret_t);
+ return CORD_all("return ", value, ";");
+ } else {
+ if (env->fn_ctx->return_type->tag != VoidType)
+ code_err(ast, "This function expects a return value of type %T", env->fn_ctx->return_type->tag);
+ return "return;";
+ }
}
// Extern,
case StructDef: {
@@ -1251,10 +1340,10 @@ CORD compile(env_t *env, ast_t *ast)
case TableType: {
type_t *key_t = Match(container_t, TableType)->key_type;
type_t *value_t = Match(container_t, TableType)->value_type;
- if (!can_promote(index_t, key_t))
- code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t);
CORD table = compile_to_pointer_depth(env, indexing->indexed, 0, false);
CORD key = compile(env, indexing->index);
+ if (!promote(env, &key, index_t, key_t))
+ code_err(indexing->index, "This value has type %T, but this table can only be index with keys of type %T", index_t, key_t);
file_t *f = indexing->index->file;
return CORD_all("$Table_get(", table, ", ", compile_type(key_t), ", ", compile_type(value_t), ", ",
key, ", ", compile_type_info(env, container_t), ", ",
diff --git a/environment.c b/environment.c
index 54b25dbd..d110a7ef 100644
--- a/environment.c
+++ b/environment.c
@@ -176,6 +176,7 @@ env_t *new_compilation_unit(void)
env_t *ns_env = namespace_env(env, global_types[i].name);
$ARRAY_FOREACH(global_types[i].namespace, j, ns_entry_t, entry, {
type_t *type = parse_type_string(ns_env, entry.type_str);
+ if (type->tag == ClosureType) type = Match(type, ClosureType)->fn;
binding_t *b = new(binding_t, .code=entry.code, .type=type);
Table_str_set(namespace, entry.name, b);
}, {})
@@ -184,6 +185,14 @@ env_t *new_compilation_unit(void)
return env;
}
+env_t *global_scope(env_t *env)
+{
+ env_t *scope = new(env_t);
+ *scope = *env;
+ scope->locals = new(table_t, .fallback=env->globals);
+ return scope;
+}
+
env_t *fresh_scope(env_t *env)
{
env_t *scope = new(env_t);
@@ -207,7 +216,15 @@ env_t *namespace_env(env_t *env, const char *namespace_name)
binding_t *get_binding(env_t *env, const char *name)
{
- return Table_str_get(*env->locals, name);
+ binding_t *b = Table_str_get(*env->locals, name);
+ if (!b && env->fn_ctx && env->fn_ctx->closure_scope) {
+ b = Table_str_get(*env->fn_ctx->closure_scope, name);
+ if (b) {
+ Table_str_set(env->fn_ctx->closed_vars, name, b);
+ return new(binding_t, .type=b->type, .code=CORD_all("$userdata->", name));
+ }
+ }
+ return b;
}
binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name)
diff --git a/environment.h b/environment.h
index 8660fe68..2ab7a9d2 100644
--- a/environment.h
+++ b/environment.h
@@ -17,9 +17,16 @@ typedef struct {
} compilation_unit_t;
typedef struct {
+ type_t *return_type;
+ table_t *closure_scope;
+ table_t *closed_vars;
+} fn_context_t;
+
+typedef struct {
table_t *types, *globals, *locals;
table_t *type_namespaces; // Map of type name -> namespace table
compilation_unit_t *code;
+ fn_context_t *fn_ctx;
CORD scope_prefix;
} env_t;
@@ -29,6 +36,7 @@ typedef struct {
} binding_t;
env_t *new_compilation_unit(void);
+env_t *global_scope(env_t *env);
env_t *fresh_scope(env_t *env);
env_t *namespace_env(env_t *env, const char *namespace_name);
__attribute__((noreturn))
diff --git a/test/lambdas.tm b/test/lambdas.tm
index 86b4b263..d896ea4b 100644
--- a/test/lambdas.tm
+++ b/test/lambdas.tm
@@ -1,5 +1,3 @@
-
-
>> add_one := func(x:Int) x + 1
>> add_one(10)
= 11
@@ -10,3 +8,19 @@
>> asdf := add_one
>> asdf(99)
= 100
+
+
+func make_adder(x:Int)-> func(y:Int)->Int
+ return func(y:Int) x + y
+
+>> add_100 := make_adder(100)
+>> add_100(5)
+= 105
+
+
+func suffix_fn(fn:func(t:Text)->Text, suffix:Text)->func(t:Text)->Text
+ return func(t:Text) fn(t)++suffix
+
+>> shout2 := suffix_fn(Text.upper, "!")
+>> shout2("hello")
+= "HELLO!"
diff --git a/typecheck.c b/typecheck.c
index 88b9d892..671326d0 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -67,7 +67,7 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast)
}
}
REVERSE_LIST(type_args);
- return Type(FunctionType, .args=type_args, .ret=ret_t);
+ return Type(ClosureType, Type(FunctionType, .args=type_args, .ret=ret_t));
}
case UnknownTypeAST: code_err(ast, "I don't know how to get this type");
}
@@ -610,7 +610,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case Lambda: {
auto lambda = Match(ast, Lambda);
arg_t *args = NULL;
- env_t *scope = fresh_scope(env);
+ env_t *scope = fresh_scope(env); // For now, just use closed variables in scope normally
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *t = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->value);
args = new(arg_t, .name=arg->name, .type=t, .next=args);
diff --git a/types.c b/types.c
index 0b08cf95..10a5a7a8 100644
--- a/types.c
+++ b/types.c
@@ -28,7 +28,7 @@ CORD type_to_cord(type_t *t) {
return CORD_asprintf("{%r=>%r}", type_to_cord(table->key_type), type_to_cord(table->value_type));
}
case ClosureType: {
- return CORD_all("~", type_to_cord(Match(t, ClosureType)->fn));
+ return type_to_cord(Match(t, ClosureType)->fn);
}
case FunctionType: {
CORD c = "func(";
@@ -261,24 +261,7 @@ bool can_promote(type_t *actual, type_t *needed)
}
if (needed->tag == ClosureType && actual->tag == FunctionType)
- return can_promote(actual, Match(needed, ClosureType)->fn);
-
- // Function promotion:
- if (needed->tag == FunctionType && actual->tag == FunctionType) {
- auto needed_fn = Match(needed, FunctionType);
- auto actual_fn = Match(actual, FunctionType);
- for (arg_t *needed_arg = needed_fn->args, *actual_arg = actual_fn->args;
- needed_arg || actual_arg;
- needed_arg = needed_arg->next, actual_arg = actual_arg->next) {
-
- if (!needed_arg || !actual_arg)
- return false;
-
- if (!type_eq(needed_arg->type, actual_arg->type))
- return false;
- }
- return true;
- }
+ return type_eq(actual, Match(needed, ClosureType)->fn);
return false;
}