First pass at lambdas/closures

This commit is contained in:
Bruce Hill 2024-03-09 14:02:19 -05:00
parent 0921f3723b
commit 955f047e06
7 changed files with 100 additions and 40 deletions

View File

@ -30,4 +30,8 @@ typedef struct table_s {
void *default_value;
} table_t;
typedef struct {
void *fn, *userdata;
} closure_t;
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0

View File

@ -52,6 +52,8 @@ typedef struct TypeInfo {
.tag=TableInfo, .TableInfo.key=key_expr, .TableInfo.value=value_expr})
#define $FunctionInfo(typestr) &((TypeInfo){.size=sizeof(void*), .align=__alignof__(void*), \
.tag=FunctionInfo, .FunctionInfo.type_str=typestr})
#define $ClosureInfo(typestr) &((TypeInfo){.size=2*sizeof(void*), .align=__alignof__(void*), \
.tag=FunctionInfo, .FunctionInfo.type_str=typestr})
#define $TypeInfoInfo(typestr) &((TypeInfo){.size=sizeof(TypeInfo), .align=__alignof__(TypeInfo), \
.tag=TypeInfoInfo, .TypeInfoInfo.type_str=typestr})

122
compile.c
View File

@ -26,6 +26,21 @@ CORD compile_type_ast(type_ast_t *t)
}
}
CORD compile_declaration(type_t *t, const char *name)
{
if (t->tag == FunctionType) {
auto fn = Match(t, FunctionType);
CORD code = CORD_all(compile_type(fn->ret), " (*", name, ")(");
for (arg_t *arg = fn->args; arg; arg = arg->next) {
code = CORD_all(code, compile_type(arg->type));
if (arg->next) code = CORD_cat(code, ", ");
}
return CORD_all(code, ")");
} else {
return CORD_all(compile_type(t), " ", name);
}
}
CORD compile_type(type_t *t)
{
switch (t->tag) {
@ -41,8 +56,16 @@ CORD compile_type(type_t *t)
}
case ArrayType: return "array_t";
case TableType: return "table_t";
case FunctionType: return "const void*";
case ClosureType: compiler_err(NULL, NULL, NULL, "Not implemented");
case FunctionType: {
auto fn = Match(t, FunctionType);
CORD code = CORD_all(compile_type(fn->ret), " (*)(");
for (arg_t *arg = fn->args; arg; arg = arg->next) {
code = CORD_all(code, compile_type(arg->type));
if (arg->next) code = CORD_cat(code, ", ");
}
return CORD_all(code, ")");
}
case ClosureType: return "closure_t";
case PointerType: return CORD_cat(compile_type(Match(t, PointerType)->pointed), "*");
case StructType: return CORD_cat(Match(t, StructType)->name, "_t");
case EnumType: return CORD_cat(Match(t, EnumType)->name, "_t");
@ -560,8 +583,7 @@ CORD compile(env_t *env, ast_t *ast)
case Declare: {
auto decl = Match(ast, Declare);
type_t *t = get_type(env, decl->value);
// return CORD_asprintf("auto %r = %r;", compile(env, decl->var), compile(env, decl->value));
return CORD_asprintf("%r %r = %r;", compile_type(t), compile(env, decl->var), compile(env, decl->value));
return CORD_all(compile_declaration(t, Match(decl->var, Var)->name), " = ", compile(env, decl->value), ";");
}
case Assign: {
auto assign = Match(ast, Assign);
@ -673,7 +695,7 @@ CORD compile(env_t *env, ast_t *ast)
CORD signature = CORD_all(fndef->ret_type ? compile_type_ast(fndef->ret_type) : "void", " ", name, "(");
for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
CORD_appendf(&signature, "%r %s", compile_type(arg_type), arg->name);
signature = CORD_cat(signature, compile_declaration(arg_type, arg->name));
if (arg->next) signature = CORD_cat(signature, ", ");
}
signature = CORD_cat(signature, ")");
@ -701,6 +723,32 @@ CORD compile(env_t *env, ast_t *ast)
env->code->funcs = CORD_all(env->code->funcs, code, " ", body);
return CORD_EMPTY;
}
case Lambda: {
auto lambda = Match(ast, Lambda);
static int64_t lambda_number = 1;
CORD name = CORD_asprintf("lambda$%ld", lambda_number++);
env_t *body_scope = fresh_scope(env);
body_scope->locals->fallback = env->globals;
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name));
}
type_t *ret_t = get_type(body_scope, lambda->body);
CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "(");
for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) {
type_t *arg_type = get_arg_ast_type(env, arg);
code = CORD_all(code, compile_type(arg_type), " ", arg->name, ", ");
}
code = CORD_cat(code, "void *$userdata)");
CORD body = compile(body_scope, lambda->body);
if (CORD_fetch(body, 0) != '{')
body = CORD_all("{\n", body, "\n}");
env->code->funcs = CORD_all(env->code->funcs, code, " ", body);
return CORD_all("(closure_t){", name, ", NULL}");
}
case MethodCall: {
auto call = Match(ast, MethodCall);
type_t *self_t = get_type(env, call->self);
@ -771,38 +819,43 @@ CORD compile(env_t *env, ast_t *ast)
return CORD_all("Table_clear(", self, ")");
} else code_err(ast, "There is no '%s' method for tables", call->name);
}
default: goto fncall;
default: {
auto call = Match(ast, MethodCall);
type_t *fn_t = get_method_type(env, call->self, call->name);
arg_ast_t *args = new(arg_ast_t, .value=call->self, .next=call->args);
binding_t *b = get_namespace_binding(env, call->self, call->name);
if (!b) code_err(ast, "No such method");
return CORD_all(b->code, "(", compile_arguments(env, ast, Match(fn_t, FunctionType)->args, args), ")");
}
}
}
case FunctionCall: {
fncall:;
type_t *fn_t;
arg_ast_t *args;
CORD fn;
if (ast->tag == FunctionCall) {
auto call = Match(ast, FunctionCall);
fn_t = get_type(env, call->fn);
if (fn_t->tag == TypeInfoType) {
type_t *t = Match(fn_t, TypeInfoType)->type;
if (!(t->tag == StructType))
code_err(call->fn, "This is not a type that has a constructor");
fn_t = Type(FunctionType, .args=Match(t, StructType)->fields, .ret=t);
} else if (fn_t->tag != FunctionType) {
code_err(call->fn, "This is not a function, it's a %T", fn_t);
auto call = Match(ast, FunctionCall);
type_t *fn_t = get_type(env, call->fn);
if (fn_t->tag == FunctionType) {
CORD fn = compile(env, call->fn);
return CORD_all(fn, "(", compile_arguments(env, ast, Match(fn_t, FunctionType)->args, call->args), ")");
} else if (fn_t->tag == TypeInfoType) {
type_t *t = Match(fn_t, TypeInfoType)->type;
if (!(t->tag == StructType))
code_err(call->fn, "This is not a type that has a constructor");
fn_t = Type(FunctionType, .args=Match(t, StructType)->fields, .ret=t);
CORD fn = compile(env, call->fn);
return CORD_all(fn, "(", compile_arguments(env, ast, Match(fn_t, FunctionType)->args, call->args), ")");
} else if (fn_t->tag == ClosureType) {
fn_t = Match(fn_t, ClosureType)->fn;
arg_t *type_args = Match(fn_t, FunctionType)->args;
CORD fn_type_code = compile_type(fn_t);
CORD closure = compile(env, call->fn);
if (call->fn->tag == Var) {
return CORD_all("((", fn_type_code, ")", closure, ".fn)(", compile_arguments(env, ast, type_args, call->args), ")");
} else {
return CORD_all("({ closure_t $closure = ", closure, "; ((", fn_type_code, ")$closure.fn)(",
compile_arguments(env, ast, type_args, call->args), "); })");
}
args = call->args;
fn = compile(env, call->fn);
} else {
auto method = Match(ast, MethodCall);
fn_t = get_method_type(env, method->self, method->name);
args = new(arg_ast_t, .value=method->self, .next=method->args);
binding_t *b = get_namespace_binding(env, method->self, method->name);
if (!b) code_err(ast, "No such method");
fn = b->code;
code_err(call->fn, "This is not a function, it's a %T", fn_t);
}
CORD code = CORD_all(fn, "(", compile_arguments(env, ast, Match(fn_t, FunctionType)->args, args), ")");
return code;
}
case If: {
auto if_ = Match(ast, If);
@ -1212,15 +1265,12 @@ CORD compile(env_t *env, ast_t *ast)
default: code_err(ast, "Indexing is not supported for type: %T", container_t);
}
}
// Use,
// LinkerDirective,
case InlineCCode: return Match(ast, InlineCCode)->code;
case Unknown: code_err(ast, "Unknown AST");
case Lambda: code_err(ast, "Lambdas are not supported yet");
case Use: code_err(ast, "Uses are not supported yet");
case LinkerDirective: code_err(ast, "Linker directives are not supported yet");
case Extern: code_err(ast, "Externs are not supported yet");
case TableEntry: code_err(ast, "Table entries should not be compiled directly");
case Unknown: code_err(ast, "Unknown AST");
}
code_err(ast, "Unknown AST: %W", ast);
return NULL;
@ -1284,7 +1334,7 @@ CORD compile_type_info(env_t *env, type_t *t)
return CORD_asprintf("$FunctionInfo(%r)", Text__quoted(type_to_cord(t), false));
}
case ClosureType: {
errx(1, "No typeinfo for closures yet");
return CORD_asprintf("$ClosureInfo(%r)", Text__quoted(type_to_cord(t), false));
}
case TypeInfoType: return "&TypeInfo_info";
default:

View File

@ -15,6 +15,7 @@ typedef struct {
CORD expr_as_texting(env_t *env, CORD expr, type_t *t, CORD color);
module_code_t compile_file(ast_t *ast);
CORD compile_type_ast(type_ast_t *t);
CORD compile_declaration(type_t *t, const char *name);
CORD compile_type(type_t *t);
CORD compile(env_t *env, ast_t *ast);
void compile_namespace(env_t *env, const char *ns_name, ast_t *block);

View File

@ -114,9 +114,7 @@ void compile_enum_def(env_t *env, ast_t *ast)
CORD arg_sig = CORD_EMPTY;
for (arg_ast_t *field = tag->fields; field; field = field->next) {
type_t *field_t = get_arg_ast_type(env, field);
CORD type_code = compile_type(field_t);
arg_sig = CORD_all(arg_sig, type_code, " ", field->name);
if (CORD_cmp(type_code, "Bool_t") == 0) arg_sig = CORD_cat(arg_sig, ":1");
arg_sig = CORD_all(arg_sig, compile_declaration(field_t, field->name));
if (field->next) arg_sig = CORD_cat(arg_sig, ", ");
}
if (arg_sig == CORD_EMPTY) arg_sig = "void";

View File

@ -396,6 +396,8 @@ type_t *get_type(env_t *env, ast_t *ast)
return t; // Constructor
code_err(call->fn, "This is not a type that has a constructor");
}
if (fn_type_t->tag == ClosureType)
fn_type_t = Match(fn_type_t, ClosureType)->fn;
if (fn_type_t->tag != FunctionType)
code_err(call->fn, "This isn't a function, it's a %T", fn_type_t);
auto fn_type = Match(fn_type_t, FunctionType);

View File

@ -27,6 +27,9 @@ CORD type_to_cord(type_t *t) {
auto table = Match(t, TableType);
return CORD_asprintf("{%r=>%r}", type_to_cord(table->key_type), type_to_cord(table->value_type));
}
case ClosureType: {
return CORD_all("~", type_to_cord(Match(t, ClosureType)->fn));
}
case FunctionType: {
CORD c = "func(";
auto fn = Match(t, FunctionType);
@ -440,7 +443,7 @@ size_t type_align(type_t *t)
case ArrayType: return __alignof__(array_t);
case TableType: return __alignof__(table_t);
case FunctionType: return __alignof__(void*);
case ClosureType: return __alignof__(void*);
case ClosureType: return __alignof__(struct {void *fn, *userdata;});
case PointerType: return __alignof__(void*);
case StructType: {
size_t align = 0;