From 955f047e069497be4cbeffa3e0309360aeb1efa7 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sat, 9 Mar 2024 14:02:19 -0500 Subject: [PATCH] First pass at lambdas/closures --- builtins/datatypes.h | 4 ++ builtins/types.h | 2 + compile.c | 122 ++++++++++++++++++++++++++++++------------- compile.h | 1 + enums.c | 4 +- typecheck.c | 2 + types.c | 5 +- 7 files changed, 100 insertions(+), 40 deletions(-) diff --git a/builtins/datatypes.h b/builtins/datatypes.h index a9c9ac2..20dede3 100644 --- a/builtins/datatypes.h +++ b/builtins/datatypes.h @@ -30,4 +30,8 @@ typedef struct table_s { void *default_value; } table_t; +typedef struct { + void *fn, *userdata; +} closure_t; + // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/builtins/types.h b/builtins/types.h index 397e14c..533ffb8 100644 --- a/builtins/types.h +++ b/builtins/types.h @@ -52,6 +52,8 @@ typedef struct TypeInfo { .tag=TableInfo, .TableInfo.key=key_expr, .TableInfo.value=value_expr}) #define $FunctionInfo(typestr) &((TypeInfo){.size=sizeof(void*), .align=__alignof__(void*), \ .tag=FunctionInfo, .FunctionInfo.type_str=typestr}) +#define $ClosureInfo(typestr) &((TypeInfo){.size=2*sizeof(void*), .align=__alignof__(void*), \ + .tag=FunctionInfo, .FunctionInfo.type_str=typestr}) #define $TypeInfoInfo(typestr) &((TypeInfo){.size=sizeof(TypeInfo), .align=__alignof__(TypeInfo), \ .tag=TypeInfoInfo, .TypeInfoInfo.type_str=typestr}) diff --git a/compile.c b/compile.c index 29feffa..3e49cf5 100644 --- a/compile.c +++ b/compile.c @@ -26,6 +26,21 @@ CORD compile_type_ast(type_ast_t *t) } } +CORD compile_declaration(type_t *t, const char *name) +{ + if (t->tag == FunctionType) { + auto fn = Match(t, FunctionType); + CORD code = CORD_all(compile_type(fn->ret), " (*", name, ")("); + for (arg_t *arg = fn->args; arg; arg = arg->next) { + code = CORD_all(code, compile_type(arg->type)); + if (arg->next) code = CORD_cat(code, ", "); + } + return CORD_all(code, ")"); + } else { + return CORD_all(compile_type(t), " ", name); + } +} + CORD compile_type(type_t *t) { switch (t->tag) { @@ -41,8 +56,16 @@ CORD compile_type(type_t *t) } case ArrayType: return "array_t"; case TableType: return "table_t"; - case FunctionType: return "const void*"; - case ClosureType: compiler_err(NULL, NULL, NULL, "Not implemented"); + case FunctionType: { + auto fn = Match(t, FunctionType); + CORD code = CORD_all(compile_type(fn->ret), " (*)("); + for (arg_t *arg = fn->args; arg; arg = arg->next) { + code = CORD_all(code, compile_type(arg->type)); + if (arg->next) code = CORD_cat(code, ", "); + } + return CORD_all(code, ")"); + } + case ClosureType: return "closure_t"; case PointerType: return CORD_cat(compile_type(Match(t, PointerType)->pointed), "*"); case StructType: return CORD_cat(Match(t, StructType)->name, "_t"); case EnumType: return CORD_cat(Match(t, EnumType)->name, "_t"); @@ -560,8 +583,7 @@ CORD compile(env_t *env, ast_t *ast) case Declare: { auto decl = Match(ast, Declare); type_t *t = get_type(env, decl->value); - // return CORD_asprintf("auto %r = %r;", compile(env, decl->var), compile(env, decl->value)); - return CORD_asprintf("%r %r = %r;", compile_type(t), compile(env, decl->var), compile(env, decl->value)); + return CORD_all(compile_declaration(t, Match(decl->var, Var)->name), " = ", compile(env, decl->value), ";"); } case Assign: { auto assign = Match(ast, Assign); @@ -673,7 +695,7 @@ CORD compile(env_t *env, ast_t *ast) CORD signature = CORD_all(fndef->ret_type ? compile_type_ast(fndef->ret_type) : "void", " ", name, "("); for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); - CORD_appendf(&signature, "%r %s", compile_type(arg_type), arg->name); + signature = CORD_cat(signature, compile_declaration(arg_type, arg->name)); if (arg->next) signature = CORD_cat(signature, ", "); } signature = CORD_cat(signature, ")"); @@ -701,6 +723,32 @@ CORD compile(env_t *env, ast_t *ast) env->code->funcs = CORD_all(env->code->funcs, code, " ", body); return CORD_EMPTY; } + case Lambda: { + auto lambda = Match(ast, Lambda); + static int64_t lambda_number = 1; + CORD name = CORD_asprintf("lambda$%ld", lambda_number++); + + env_t *body_scope = fresh_scope(env); + body_scope->locals->fallback = env->globals; + for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { + type_t *arg_type = get_arg_ast_type(env, arg); + set_binding(body_scope, arg->name, new(binding_t, .type=arg_type, .code=arg->name)); + } + type_t *ret_t = get_type(body_scope, lambda->body); + + CORD code = CORD_all("static ", compile_type(ret_t), " ", name, "("); + for (arg_ast_t *arg = lambda->args; arg; arg = arg->next) { + type_t *arg_type = get_arg_ast_type(env, arg); + code = CORD_all(code, compile_type(arg_type), " ", arg->name, ", "); + } + code = CORD_cat(code, "void *$userdata)"); + + CORD body = compile(body_scope, lambda->body); + if (CORD_fetch(body, 0) != '{') + body = CORD_all("{\n", body, "\n}"); + env->code->funcs = CORD_all(env->code->funcs, code, " ", body); + return CORD_all("(closure_t){", name, ", NULL}"); + } case MethodCall: { auto call = Match(ast, MethodCall); type_t *self_t = get_type(env, call->self); @@ -771,38 +819,43 @@ CORD compile(env_t *env, ast_t *ast) return CORD_all("Table_clear(", self, ")"); } else code_err(ast, "There is no '%s' method for tables", call->name); } - default: goto fncall; + default: { + auto call = Match(ast, MethodCall); + type_t *fn_t = get_method_type(env, call->self, call->name); + arg_ast_t *args = new(arg_ast_t, .value=call->self, .next=call->args); + binding_t *b = get_namespace_binding(env, call->self, call->name); + if (!b) code_err(ast, "No such method"); + return CORD_all(b->code, "(", compile_arguments(env, ast, Match(fn_t, FunctionType)->args, args), ")"); + } } } case FunctionCall: { - fncall:; - type_t *fn_t; - arg_ast_t *args; - CORD fn; - if (ast->tag == FunctionCall) { - auto call = Match(ast, FunctionCall); - fn_t = get_type(env, call->fn); - if (fn_t->tag == TypeInfoType) { - type_t *t = Match(fn_t, TypeInfoType)->type; - if (!(t->tag == StructType)) - code_err(call->fn, "This is not a type that has a constructor"); - fn_t = Type(FunctionType, .args=Match(t, StructType)->fields, .ret=t); - } else if (fn_t->tag != FunctionType) { - code_err(call->fn, "This is not a function, it's a %T", fn_t); + auto call = Match(ast, FunctionCall); + type_t *fn_t = get_type(env, call->fn); + if (fn_t->tag == FunctionType) { + CORD fn = compile(env, call->fn); + return CORD_all(fn, "(", compile_arguments(env, ast, Match(fn_t, FunctionType)->args, call->args), ")"); + } else if (fn_t->tag == TypeInfoType) { + type_t *t = Match(fn_t, TypeInfoType)->type; + if (!(t->tag == StructType)) + code_err(call->fn, "This is not a type that has a constructor"); + fn_t = Type(FunctionType, .args=Match(t, StructType)->fields, .ret=t); + CORD fn = compile(env, call->fn); + return CORD_all(fn, "(", compile_arguments(env, ast, Match(fn_t, FunctionType)->args, call->args), ")"); + } else if (fn_t->tag == ClosureType) { + fn_t = Match(fn_t, ClosureType)->fn; + arg_t *type_args = Match(fn_t, FunctionType)->args; + CORD fn_type_code = compile_type(fn_t); + CORD closure = compile(env, call->fn); + if (call->fn->tag == Var) { + return CORD_all("((", fn_type_code, ")", closure, ".fn)(", compile_arguments(env, ast, type_args, call->args), ")"); + } else { + return CORD_all("({ closure_t $closure = ", closure, "; ((", fn_type_code, ")$closure.fn)(", + compile_arguments(env, ast, type_args, call->args), "); })"); } - args = call->args; - fn = compile(env, call->fn); } else { - auto method = Match(ast, MethodCall); - fn_t = get_method_type(env, method->self, method->name); - args = new(arg_ast_t, .value=method->self, .next=method->args); - binding_t *b = get_namespace_binding(env, method->self, method->name); - if (!b) code_err(ast, "No such method"); - fn = b->code; + code_err(call->fn, "This is not a function, it's a %T", fn_t); } - - CORD code = CORD_all(fn, "(", compile_arguments(env, ast, Match(fn_t, FunctionType)->args, args), ")"); - return code; } case If: { auto if_ = Match(ast, If); @@ -1212,15 +1265,12 @@ CORD compile(env_t *env, ast_t *ast) default: code_err(ast, "Indexing is not supported for type: %T", container_t); } } - // Use, - // LinkerDirective, case InlineCCode: return Match(ast, InlineCCode)->code; - case Unknown: code_err(ast, "Unknown AST"); - case Lambda: code_err(ast, "Lambdas are not supported yet"); case Use: code_err(ast, "Uses are not supported yet"); case LinkerDirective: code_err(ast, "Linker directives are not supported yet"); case Extern: code_err(ast, "Externs are not supported yet"); case TableEntry: code_err(ast, "Table entries should not be compiled directly"); + case Unknown: code_err(ast, "Unknown AST"); } code_err(ast, "Unknown AST: %W", ast); return NULL; @@ -1284,7 +1334,7 @@ CORD compile_type_info(env_t *env, type_t *t) return CORD_asprintf("$FunctionInfo(%r)", Text__quoted(type_to_cord(t), false)); } case ClosureType: { - errx(1, "No typeinfo for closures yet"); + return CORD_asprintf("$ClosureInfo(%r)", Text__quoted(type_to_cord(t), false)); } case TypeInfoType: return "&TypeInfo_info"; default: diff --git a/compile.h b/compile.h index 6ae5cb0..1f385bb 100644 --- a/compile.h +++ b/compile.h @@ -15,6 +15,7 @@ typedef struct { CORD expr_as_texting(env_t *env, CORD expr, type_t *t, CORD color); module_code_t compile_file(ast_t *ast); CORD compile_type_ast(type_ast_t *t); +CORD compile_declaration(type_t *t, const char *name); CORD compile_type(type_t *t); CORD compile(env_t *env, ast_t *ast); void compile_namespace(env_t *env, const char *ns_name, ast_t *block); diff --git a/enums.c b/enums.c index 3e4a586..cb55f54 100644 --- a/enums.c +++ b/enums.c @@ -114,9 +114,7 @@ void compile_enum_def(env_t *env, ast_t *ast) CORD arg_sig = CORD_EMPTY; for (arg_ast_t *field = tag->fields; field; field = field->next) { type_t *field_t = get_arg_ast_type(env, field); - CORD type_code = compile_type(field_t); - arg_sig = CORD_all(arg_sig, type_code, " ", field->name); - if (CORD_cmp(type_code, "Bool_t") == 0) arg_sig = CORD_cat(arg_sig, ":1"); + arg_sig = CORD_all(arg_sig, compile_declaration(field_t, field->name)); if (field->next) arg_sig = CORD_cat(arg_sig, ", "); } if (arg_sig == CORD_EMPTY) arg_sig = "void"; diff --git a/typecheck.c b/typecheck.c index 10f2f1d..88b9d89 100644 --- a/typecheck.c +++ b/typecheck.c @@ -396,6 +396,8 @@ type_t *get_type(env_t *env, ast_t *ast) return t; // Constructor code_err(call->fn, "This is not a type that has a constructor"); } + if (fn_type_t->tag == ClosureType) + fn_type_t = Match(fn_type_t, ClosureType)->fn; if (fn_type_t->tag != FunctionType) code_err(call->fn, "This isn't a function, it's a %T", fn_type_t); auto fn_type = Match(fn_type_t, FunctionType); diff --git a/types.c b/types.c index e1f3026..0b08cf9 100644 --- a/types.c +++ b/types.c @@ -27,6 +27,9 @@ CORD type_to_cord(type_t *t) { auto table = Match(t, TableType); return CORD_asprintf("{%r=>%r}", type_to_cord(table->key_type), type_to_cord(table->value_type)); } + case ClosureType: { + return CORD_all("~", type_to_cord(Match(t, ClosureType)->fn)); + } case FunctionType: { CORD c = "func("; auto fn = Match(t, FunctionType); @@ -440,7 +443,7 @@ size_t type_align(type_t *t) case ArrayType: return __alignof__(array_t); case TableType: return __alignof__(table_t); case FunctionType: return __alignof__(void*); - case ClosureType: return __alignof__(void*); + case ClosureType: return __alignof__(struct {void *fn, *userdata;}); case PointerType: return __alignof__(void*); case StructType: { size_t align = 0;