From 39dd1ca27da9e9d88ee59565df99ee281e1b3632 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Mon, 10 Mar 2025 12:42:45 -0400 Subject: [PATCH] Add `convert` keyword for defining conversions --- ast.c | 2 + ast.h | 9 ++- compile.c | 180 ++++++++++++++++++++++++++++---------------------- docs/langs.md | 23 ++++++- environment.c | 15 +++-- environment.h | 3 +- parse.c | 39 +++++++++++ test/lang.tm | 6 +- typecheck.c | 29 ++++++-- types.c | 10 +++ types.h | 1 + 11 files changed, 218 insertions(+), 99 deletions(-) diff --git a/ast.c b/ast.c index aaf9e1a..367b118 100644 --- a/ast.c +++ b/ast.c @@ -134,6 +134,8 @@ CORD ast_to_xml(ast_t *ast) optional_tagged("filter", data.filter)) T(FunctionDef, "%r%r%r", ast_to_xml(data.name), arg_list_to_xml(data.args), optional_tagged_type("return-type", data.ret_type), ast_to_xml(data.body)) + T(ConvertDef, "%r%r%r", + arg_list_to_xml(data.args), optional_tagged_type("return-type", data.ret_type), ast_to_xml(data.body)) T(Lambda, "%r%r%r)", arg_list_to_xml(data.args), optional_tagged_type("return-type", data.ret_type), ast_to_xml(data.body)) T(FunctionCall, "%r%r", ast_to_xml(data.fn), arg_list_to_xml(data.args)) diff --git a/ast.h b/ast.h index 2e39812..5ae03bb 100644 --- a/ast.h +++ b/ast.h @@ -129,7 +129,7 @@ typedef enum { Not, Negative, HeapAllocate, StackReference, Mutexed, Holding, Min, Max, Array, Set, Table, TableEntry, Comprehension, - FunctionDef, Lambda, + FunctionDef, Lambda, ConvertDef, FunctionCall, MethodCall, Block, For, While, If, When, Repeat, @@ -228,6 +228,13 @@ struct ast_s { ast_t *cache; bool is_inline; } FunctionDef; + struct { + arg_ast_t *args; + type_ast_t *ret_type; + ast_t *body; + ast_t *cache; + bool is_inline; + } ConvertDef; struct { arg_ast_t *args; type_ast_t *ret_type; diff --git a/compile.c b/compile.c index 1938800..2cd2869 100644 --- a/compile.c +++ b/compile.c @@ -106,9 +106,8 @@ static bool promote(env_t *env, ast_t *ast, CORD *code, type_t *actual, type_t * // Numeric promotions/demotions if ((is_numeric_type(actual) || actual->tag == BoolType) && (is_numeric_type(needed) || needed->tag == BoolType)) { arg_ast_t *args = new(arg_ast_t, .value=FakeAST(InlineCCode, .code=*code, .type=actual)); - binding_t *constructor = NULL; - if ((constructor=get_constructor(env, needed, args, needed)) - || (constructor=get_constructor(env, actual, args, needed))) { + binding_t *constructor = get_constructor(env, needed, args); + if (constructor) { auto fn = Match(constructor->type, FunctionType); if (fn->args->next == NULL) { *code = CORD_all(constructor->code, "(", compile_arguments(env, ast, fn->args, args), ")"); @@ -129,7 +128,7 @@ static bool promote(env_t *env, ast_t *ast, CORD *code, type_t *actual, type_t * } // Text to C String - if (actual->tag == TextType && !Match(actual, TextType)->lang && needed->tag == CStringType) { + if (actual->tag == TextType && type_eq(actual, TEXT_TYPE) && needed->tag == CStringType) { *code = CORD_all("Text$as_c_string(", *code, ")"); return true; } @@ -443,7 +442,7 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value); break; } - case Use: case FunctionDef: case StructDef: case EnumDef: case LangDef: { + case Use: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: { errx(1, "Definitions should not be reachable in a closure."); } default: @@ -1186,16 +1185,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) default: code_err(ast, "Update assignments are not implemented for this operation"); } } - case StructDef: { - return CORD_EMPTY; - } - case EnumDef: { - return CORD_EMPTY; - } - case LangDef: { - return CORD_EMPTY; - } - case FunctionDef: { + case StructDef: case EnumDef: case LangDef: case FunctionDef: case ConvertDef: { return CORD_EMPTY; } case Skip: { @@ -2713,29 +2703,19 @@ CORD compile(env_t *env, ast_t *ast) for (ast_list_t *chunk = chunks; chunk; chunk = chunk->next) { CORD chunk_code; type_t *chunk_t = get_type(env, chunk->ast); - if (chunk->ast->tag == TextLiteral) { + if (chunk->ast->tag == TextLiteral || type_eq(chunk_t, text_t)) { chunk_code = compile(env, chunk->ast); - } else if (chunk_t->tag == TextType && streq(Match(chunk_t, TextType)->lang, lang)) { // Interp is same type as text literal - binding_t *constructor = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast), text_t); + } else { + binding_t *constructor = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast)); if (constructor) { arg_t *arg_spec = Match(constructor->type, FunctionType)->args; arg_ast_t *args = new(arg_ast_t, .value=chunk->ast); chunk_code = CORD_all(constructor->code, "(", compile_arguments(env, ast, arg_spec, args), ")"); + } else if (type_eq(text_t, TEXT_TYPE)) { + chunk_code = compile_string(env, chunk->ast, "no"); } else { - chunk_code = compile(env, chunk->ast); - } - } else if (lang) { // Interp is different type from text literal (which is a DSL) - binding_t *constructor = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast), text_t); - if (!constructor) - constructor = get_constructor(env, chunk_t, new(arg_ast_t, .value=chunk->ast), text_t); - if (!constructor) code_err(chunk->ast, "I don't know how to convert %T to %T", chunk_t, text_t); - - arg_t *arg_spec = Match(constructor->type, FunctionType)->args; - arg_ast_t *args = new(arg_ast_t, .value=chunk->ast); - chunk_code = CORD_all(constructor->code, "(", compile_arguments(env, ast, arg_spec, args), ")"); - } else { - chunk_code = compile_string(env, chunk->ast, "no"); + } } code = CORD_cat(code, chunk_code); if (chunk->next) code = CORD_cat(code, ", "); @@ -3411,7 +3391,7 @@ CORD compile(env_t *env, ast_t *ast) else if (t->tag == NumType && call->args && !call->args->next && call->args->value->tag == Num) return compile_to_type(env, call->args->value, t); - binding_t *constructor = get_constructor(env, t, call->args, t); + binding_t *constructor = get_constructor(env, t, call->args); if (constructor) { arg_t *arg_spec = Match(constructor->type, FunctionType)->args; return CORD_all(constructor->code, "(", compile_arguments(env, ast, arg_spec, call->args), ")"); @@ -3420,8 +3400,7 @@ CORD compile(env_t *env, ast_t *ast) type_t *actual = call->args ? get_type(env, call->args->value) : NULL; if (t->tag == TextType) { if (!call->args) code_err(ast, "This constructor needs a value"); - const char *lang = Match(t, TextType)->lang; - if (lang) + if (!type_eq(t, TEXT_TYPE)) code_err(call->fn, "I don't have a constructor defined for these arguments"); // Text constructor: if (!call->args || call->args->next) @@ -3441,7 +3420,7 @@ CORD compile(env_t *env, ast_t *ast) return compile_string_literal(Match(Match(call->args->value, TextJoin)->children->ast, TextLiteral)->cord); return CORD_all("Text$as_c_string(", expr_as_text(env, compile(env, call->args->value), actual, "no"), ")"); } else { - code_err(call->fn, "This is not a type that has a constructor"); + code_err(ast, "I could not find a constructor matching these arguments for %T", t); } } else if (fn_t->tag == ClosureType) { fn_t = Match(fn_t, ClosureType)->fn; @@ -3866,7 +3845,7 @@ CORD compile(env_t *env, ast_t *ast) case Extern: code_err(ast, "Externs are not supported as expressions"); case TableEntry: code_err(ast, "Table entries should not be compiled directly"); case Declare: case Assign: case UpdateAssign: case For: case While: case Repeat: case StructDef: case LangDef: - case EnumDef: case FunctionDef: case Skip: case Stop: case Pass: case Return: case DocTest: case PrintStatement: + case EnumDef: case FunctionDef: case ConvertDef: case Skip: case Stop: case Pass: case Return: case DocTest: case PrintStatement: code_err(ast, "This is not a valid expression"); default: case Unknown: code_err(ast, "Unknown AST"); } @@ -3883,7 +3862,7 @@ CORD compile_type_info(env_t *env, type_t *t) return CORD_all("&", type_to_cord(t), "$info"); case TextType: { auto text = Match(t, TextType); - if (!text->lang) + if (!text->lang || streq(text->lang, "Text")) return "&Text$info"; else if (streq(text->lang, "Pattern")) return "&Pattern$info"; @@ -4056,22 +4035,40 @@ CORD compile_cli_arg_call(env_t *env, CORD fn_name, type_t *fn_type) return code; } -CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) +CORD compile_function(env_t *env, CORD name_code, ast_t *ast, CORD *staticdefs) { - auto fndef = Match(ast, FunctionDef); - const char *raw_name = Match(fndef->name, Var)->name; - bool is_private = raw_name[0] == '_'; - CORD name_code = CORD_all(namespace_prefix(env, env->namespace), raw_name); - binding_t *clobbered = get_binding(env, raw_name); - type_t *ret_t = fndef->ret_type ? parse_type_ast(env, fndef->ret_type) : Type(VoidType); - // Check for a constructor: - if (clobbered && clobbered->type->tag == TypeInfoType && type_eq(ret_t, Match(clobbered->type, TypeInfoType)->type)) { - name_code = CORD_asprintf("%r$%ld", name_code, get_line_number(ast->file, ast->start)); + bool is_private = false, is_main = false; + const char *function_name; + arg_ast_t *args; + type_t *ret_t; + ast_t *body; + ast_t *cache; + bool is_inline; + if (ast->tag == FunctionDef) { + auto fndef = Match(ast, FunctionDef); + function_name = Match(fndef->name, Var)->name; + is_private = function_name[0] == '_'; + is_main = streq(function_name, "main"); + args = fndef->args; + ret_t = fndef->ret_type ? parse_type_ast(env, fndef->ret_type) : Type(VoidType); + body = fndef->body; + cache = fndef->cache; + is_inline = fndef->is_inline; + } else { + auto convertdef = Match(ast, ConvertDef); + args = convertdef->args; + ret_t = convertdef->ret_type ? parse_type_ast(env, convertdef->ret_type) : Type(VoidType); + function_name = get_type_name(ret_t); + if (!function_name) + code_err(ast, "Conversions are only supported for text, struct, and enum types, not %T", ret_t); + body = convertdef->body; + cache = convertdef->cache; + is_inline = convertdef->is_inline; } CORD arg_signature = "("; Table_t used_names = {}; - for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) { + for (arg_ast_t *arg = args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); arg_signature = CORD_cat(arg_signature, compile_declaration(arg_type, CORD_cat("_$", arg->name))); if (arg->next) arg_signature = CORD_cat(arg_signature, ", "); @@ -4089,11 +4086,11 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) *staticdefs = CORD_all(*staticdefs, "static ", ret_type_code, " ", name_code, arg_signature, ";\n"); CORD code; - if (fndef->cache) { + if (cache) { code = CORD_all("static ", ret_type_code, " ", name_code, "$uncached", arg_signature); } else { code = CORD_all(ret_type_code, " ", name_code, arg_signature); - if (fndef->is_inline) + if (is_inline) code = CORD_cat("INLINE ", code); if (!is_private) code = CORD_cat("public ", code); @@ -4102,14 +4099,14 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) env_t *body_scope = fresh_scope(env); body_scope->deferred = NULL; body_scope->namespace = NULL; - for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) { + for (arg_ast_t *arg = args; arg; arg = arg->next) { type_t *arg_type = get_arg_ast_type(env, arg); set_binding(body_scope, arg->name, arg_type, CORD_cat("_$", arg->name)); } body_scope->fn_ret = ret_t; - type_t *body_type = get_type(body_scope, fndef->body); + type_t *body_type = get_type(body_scope, body); if (ret_t->tag == AbortType) { if (body_type->tag != AbortType) code_err(ast, "This function can reach the end without aborting!"); @@ -4121,15 +4118,15 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) code_err(ast, "This function can reach the end without returning a %T value!", ret_t); } - CORD body = compile_statement(body_scope, fndef->body); - if (streq(raw_name, "main")) - body = CORD_all("_$", env->namespace->name, "$$initialize();\n", body); - if (CORD_fetch(body, 0) != '{') - body = CORD_asprintf("{\n%r\n}", body); + CORD body_code = compile_statement(body_scope, body); + if (is_main) + body_code = CORD_all("_$", env->namespace->name, "$$initialize();\n", body_code); + if (CORD_fetch(body_code, 0) != '{') + body_code = CORD_asprintf("{\n%r\n}", body_code); - CORD definition = with_source_info(ast, CORD_all(code, " ", body, "\n")); + CORD definition = with_source_info(ast, CORD_all(code, " ", body_code, "\n")); - if (fndef->cache && fndef->args == NULL) { // no-args cache just uses a static var + if (cache && args == NULL) { // no-args cache just uses a static var CORD wrapper = CORD_all( is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name_code, "(void) {\n" "static ", compile_declaration(ret_t, "cached_result"), ";\n", @@ -4141,45 +4138,44 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) "return cached_result;\n" "}\n"); definition = CORD_cat(definition, wrapper); - } else if (fndef->cache && fndef->cache->tag == Int) { - assert(fndef->args); - OptionalInt64_t cache_size = Int64$parse(Text$from_str(Match(fndef->cache, Int)->str)); + } else if (cache && cache->tag == Int) { + assert(args); + OptionalInt64_t cache_size = Int64$parse(Text$from_str(Match(cache, Int)->str)); CORD pop_code = CORD_EMPTY; - if (fndef->cache->tag == Int && !cache_size.is_none && cache_size.i > 0) { + if (cache->tag == Int && !cache_size.is_none && cache_size.i > 0) { pop_code = CORD_all("if (cache.entries.length > ", CORD_asprintf("%ld", cache_size.i), ") Table$remove(&cache, cache.entries.data + cache.entries.stride*RNG$int64(default_rng, 0, cache.entries.length-1), table_type);\n"); } - if (!fndef->args->next) { + if (!args->next) { // Single-argument functions have simplified caching logic - type_t *arg_type = get_arg_ast_type(env, fndef->args); + type_t *arg_type = get_arg_ast_type(env, args); CORD wrapper = CORD_all( is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name_code, arg_signature, "{\n" "static Table_t cache = {};\n", "const TypeInfo_t *table_type = Table$info(", compile_type_info(env, arg_type), ", ", compile_type_info(env, ret_t), ");\n", - compile_declaration(Type(PointerType, .pointed=ret_t), "cached"), " = Table$get_raw(cache, &_$", fndef->args->name, ", table_type);\n" + compile_declaration(Type(PointerType, .pointed=ret_t), "cached"), " = Table$get_raw(cache, &_$", args->name, ", table_type);\n" "if (cached) return *cached;\n", - compile_declaration(ret_t, "ret"), " = ", name_code, "$uncached(_$", fndef->args->name, ");\n", + compile_declaration(ret_t, "ret"), " = ", name_code, "$uncached(_$", args->name, ");\n", pop_code, - "Table$set(&cache, &_$", fndef->args->name, ", &ret, table_type);\n" + "Table$set(&cache, &_$", args->name, ", &ret, table_type);\n" "return ret;\n" "}\n"); definition = CORD_cat(definition, wrapper); } else { // Multi-argument functions use a custom struct type (only defined internally) as a cache key: arg_t *fields = NULL; - for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) + for (arg_ast_t *arg = args; arg; arg = arg->next) fields = new(arg_t, .name=arg->name, .type=get_arg_ast_type(env, arg), .next=fields); REVERSE_LIST(fields); - type_t *t = Type(StructType, .name=heap_strf("%s$args", fndef->name), .fields=fields, .env=env); + type_t *t = Type(StructType, .name=heap_strf("func$%ld$args", get_line_number(ast->file, ast->start)), .fields=fields, .env=env); int64_t num_fields = used_names.entries.length; - const char *short_name = raw_name; const char *metamethods = is_packed_data(t) ? "PackedData$metamethods" : "Struct$metamethods"; CORD args_typeinfo = CORD_asprintf("((TypeInfo_t[1]){{.size=%zu, .align=%zu, .metamethods=%s, " - ".tag=StructInfo, .StructInfo.name=\"%s\", " + ".tag=StructInfo, .StructInfo.name=\"FunctionArguments\", " ".StructInfo.num_fields=%ld, .StructInfo.fields=(NamedType_t[%ld]){", - type_size(t), type_align(t), metamethods, short_name, + type_size(t), type_align(t), metamethods, num_fields, num_fields); CORD args_type = "struct { "; for (arg_t *f = fields; f; f = f->next) { @@ -4191,7 +4187,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) args_typeinfo = CORD_all(args_typeinfo, "}}})"); CORD all_args = CORD_EMPTY; - for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) + for (arg_ast_t *arg = args; arg; arg = arg->next) all_args = CORD_all(all_args, "_$", arg->name, arg->next ? ", " : CORD_EMPTY); CORD wrapper = CORD_all( @@ -4210,11 +4206,11 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) } } - CORD qualified_name = raw_name; + CORD qualified_name = function_name; if (env->namespace && env->namespace->parent && env->namespace->name) qualified_name = CORD_all(env->namespace->name, ".", qualified_name); CORD text = CORD_all("func ", qualified_name, "("); - for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) { + for (arg_ast_t *arg = args; arg; arg = arg->next) { text = CORD_cat(text, type_to_cord(get_arg_ast_type(env, arg))); if (arg->next) text = CORD_cat(text, ", "); } @@ -4222,7 +4218,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) text = CORD_all(text, "->", type_to_cord(ret_t)); text = CORD_all(text, ")"); - if (!fndef->is_inline) { + if (!is_inline) { env->code->function_naming = CORD_all( env->code->function_naming, CORD_asprintf("register_function(%r, Text(\"%s.tm\"), %ld, Text(%r));\n", @@ -4267,7 +4263,16 @@ CORD compile_top_level_code(env_t *env, ast_t *ast) } } case FunctionDef: { - return compile_function(env, ast, &env->code->staticdefs); + CORD name_code = CORD_all(namespace_prefix(env, env->namespace), Match(Match(ast, FunctionDef)->name, Var)->name); + return compile_function(env, name_code, ast, &env->code->staticdefs); + } + case ConvertDef: { + type_t *type = get_function_def_type(env, ast); + const char *name = get_type_name(Match(type, FunctionType)->ret); + if (!name) + code_err(ast, "Conversions are only supported for text, struct, and enum types, not %T", Match(type, FunctionType)->ret); + CORD name_code = CORD_asprintf("%r%s$%ld", namespace_prefix(env, env->namespace), name, get_line_number(ast->file, ast->start)); + return compile_function(env, name_code, ast, &env->code->staticdefs); } case StructDef: { auto def = Match(ast, StructDef); @@ -4510,6 +4515,25 @@ CORD compile_statement_namespace_header(env_t *env, ast_t *ast) name = CORD_asprintf("%r%ld", namespace_prefix(env, env->namespace), get_line_number(ast->file, ast->start)); return CORD_all(ret_type_code, " ", name, arg_signature, ";\n"); } + case ConvertDef: { + auto def = Match(ast, ConvertDef); + + CORD arg_signature = "("; + for (arg_ast_t *arg = def->args; arg; arg = arg->next) { + type_t *arg_type = get_arg_ast_type(env, arg); + arg_signature = CORD_cat(arg_signature, compile_declaration(arg_type, CORD_cat("_$", arg->name))); + if (arg->next) arg_signature = CORD_cat(arg_signature, ", "); + } + arg_signature = CORD_cat(arg_signature, ")"); + + type_t *ret_t = def->ret_type ? parse_type_ast(env, def->ret_type) : Type(VoidType); + CORD ret_type_code = compile_type(ret_t); + const char *name = get_type_name(ret_t); + if (!name) + code_err(ast, "Conversions are only supported for text, struct, and enum types, not %T", ret_t); + CORD name_code = CORD_asprintf("%s$%ld", name, get_line_number(ast->file, ast->start)); + return CORD_all(ret_type_code, " ", name_code, arg_signature, ";\n"); + } default: return CORD_EMPTY; } env_t *ns_env = namespace_env(env, ns_name); diff --git a/docs/langs.md b/docs/langs.md index ed5e349..f087b31 100644 --- a/docs/langs.md +++ b/docs/langs.md @@ -10,7 +10,7 @@ where a different type of string is needed. ```tomo lang HTML: - func HTML(t:Text -> HTML): + convert(t:Text -> HTML): t = t:replace_all({ $/&/ = "&", $/ Sh): + convert(text:Text -> Sh): return Sh.without_escaping("'" ++ text:replace($/'/, "''") ++ "'") func execute(sh:Sh -> Text): @@ -84,3 +84,22 @@ dir := ask("List which dir? ") cmd := $Sh@(ls -l @dir) result := cmd:execute() ``` + +## Conversions + +You can define your own rules for converting between types using the `convert` +keyword. Conversions can be defined either inside of the language's block, +another type's block or at the top level. + +```tomo +lang Sh: + convert(text:Text -> Sh): + return Sh.without_escaping("'" ++ text:replace($/'/, "''") ++ "'") + +struct Foo(x,y:Int): + convert(f:Foo -> Sh): + return Sh.without_escaping("$(f.x),$(f.y)") + +convert(texts:[Text] -> Sh): + return $Sh" ":join([Sh(t) for t in texts]) +``` diff --git a/environment.c b/environment.c index 1c8df4c..b871bff 100644 --- a/environment.c +++ b/environment.c @@ -26,7 +26,7 @@ env_t *new_compilation_unit(CORD libname) env->imports = new(Table_t); if (!TEXT_TYPE) - TEXT_TYPE = Type(TextType, .env=namespace_env(env, "Text")); + TEXT_TYPE = Type(TextType, .lang="Text", .env=namespace_env(env, "Text")); struct { const char *name; @@ -579,6 +579,7 @@ env_t *new_compilation_unit(CORD libname) {"Shell$escape_text_array", "func(texts:[Text] -> Shell)"}, {"Shell$escape_text_array", "func(paths:[Path] -> Shell)"}, {"Int$value_as_text", "func(i:Int -> Shell)"}); + ADD_CONSTRUCTORS("CString", {"Text$as_c_string", "func(text:Text -> CString)"}); ADD_CONSTRUCTORS("Moment", {"Moment$now", "func(-> Moment)"}, {"Moment$new", "func(year,month,day:Int,hour,minute=0,second=0.0,timezone=none:Text -> Moment)"}, @@ -736,7 +737,7 @@ env_t *for_scope(env_t *env, ast_t *ast) } } -static env_t *get_namespace_by_type(env_t *env, type_t *t) +env_t *get_namespace_by_type(env_t *env, type_t *t) { t = value_type(t); switch (t->tag) { @@ -794,23 +795,23 @@ binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name) return ns_env ? get_binding(ns_env, name) : NULL; } -PUREFUNC binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args, type_t *constructed_type) +PUREFUNC binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args) { env_t *type_env = get_namespace_by_type(env, t); if (!type_env) return NULL; Array_t constructors = type_env->namespace->constructors; // Prioritize exact matches: - for (int64_t i = 0; i < constructors.length; i++) { + for (int64_t i = constructors.length-1; i >= 0; i--) { binding_t *b = constructors.data + i*constructors.stride; auto fn = Match(b->type, FunctionType); - if (type_eq(fn->ret, constructed_type) && is_valid_call(env, fn->args, args, false)) + if (type_eq(fn->ret, t) && is_valid_call(env, fn->args, args, false)) return b; } // Fall back to promotion: - for (int64_t i = 0; i < constructors.length; i++) { + for (int64_t i = constructors.length-1; i >= 0; i--) { binding_t *b = constructors.data + i*constructors.stride; auto fn = Match(b->type, FunctionType); - if (type_eq(fn->ret, constructed_type) && is_valid_call(env, fn->args, args, true)) + if (type_eq(fn->ret, t) && is_valid_call(env, fn->args, args, true)) return b; } return NULL; diff --git a/environment.h b/environment.h index b5da550..4566d37 100644 --- a/environment.h +++ b/environment.h @@ -58,6 +58,7 @@ typedef struct { env_t *new_compilation_unit(CORD libname); env_t *load_module_env(env_t *env, ast_t *ast); CORD namespace_prefix(env_t *env, namespace_t *ns); +env_t *get_namespace_by_type(env_t *env, type_t *t); env_t *namespace_scope(env_t *env); env_t *fresh_scope(env_t *env); env_t *for_scope(env_t *env, ast_t *ast); @@ -65,7 +66,7 @@ env_t *namespace_env(env_t *env, const char *namespace_name); __attribute__((format(printf, 4, 5))) _Noreturn void compiler_err(file_t *f, const char *start, const char *end, const char *fmt, ...); binding_t *get_binding(env_t *env, const char *name); -binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args, type_t *constructed_type); +binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args); void set_binding(env_t *env, const char *name, type_t *type, CORD code); binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name); #define code_err(ast, ...) compiler_err((ast)->file, (ast)->start, (ast)->end, __VA_ARGS__) diff --git a/parse.c b/parse.c index a3778d9..ed8d041 100644 --- a/parse.c +++ b/parse.c @@ -94,6 +94,7 @@ static PARSER(parse_array); static PARSER(parse_assignment); static PARSER(parse_block); static PARSER(parse_bool); +static PARSER(parse_convert_def); static PARSER(parse_declaration); static PARSER(parse_defer); static PARSER(parse_do); @@ -2035,6 +2036,7 @@ PARSER(parse_namespace) { ||(stmt=optional(ctx, &pos, parse_enum_def)) ||(stmt=optional(ctx, &pos, parse_lang_def)) ||(stmt=optional(ctx, &pos, parse_func_def)) + ||(stmt=optional(ctx, &pos, parse_convert_def)) ||(stmt=optional(ctx, &pos, parse_use)) ||(stmt=optional(ctx, &pos, parse_extern)) ||(stmt=optional(ctx, &pos, parse_inline_c)) @@ -2070,6 +2072,7 @@ PARSER(parse_file_body) { ||(stmt=optional(ctx, &pos, parse_enum_def)) ||(stmt=optional(ctx, &pos, parse_lang_def)) ||(stmt=optional(ctx, &pos, parse_func_def)) + ||(stmt=optional(ctx, &pos, parse_convert_def)) ||(stmt=optional(ctx, &pos, parse_use)) ||(stmt=optional(ctx, &pos, parse_extern)) ||(stmt=optional(ctx, &pos, parse_inline_c)) @@ -2321,6 +2324,42 @@ PARSER(parse_func_def) { .is_inline=is_inline); } +PARSER(parse_convert_def) { + const char *start = pos; + if (!match_word(&pos, "convert")) return NULL; + + spaces(&pos); + + if (!match(&pos, "(")) return NULL; + + arg_ast_t *args = parse_args(ctx, &pos); + spaces(&pos); + type_ast_t *ret_type = match(&pos, "->") ? optional(ctx, &pos, parse_type) : NULL; + whitespace(&pos); + bool is_inline = false; + ast_t *cache_ast = NULL; + for (bool specials = match(&pos, ";"); specials; specials = match_separator(&pos)) { + const char *flag_start = pos; + if (match_word(&pos, "inline")) { + is_inline = true; + } else if (match_word(&pos, "cached")) { + if (!cache_ast) cache_ast = NewAST(ctx->file, pos, pos, Int, .str="-1"); + } else if (match_word(&pos, "cache_size")) { + whitespace(&pos); + if (!match(&pos, "=")) + parser_err(ctx, flag_start, pos, "I expected a value for 'cache_size'"); + whitespace(&pos); + cache_ast = expect(ctx, start, &pos, parse_expr, "I expected a maximum size for the cache"); + } + } + expect_closing(ctx, &pos, ")", "I wasn't able to parse the rest of this function definition"); + + ast_t *body = expect(ctx, start, &pos, parse_block, + "This function needs a body block"); + return NewAST(ctx->file, start, pos, ConvertDef, + .args=args, .ret_type=ret_type, .body=body, .cache=cache_ast, .is_inline=is_inline); +} + PARSER(parse_extern) { const char *start = pos; if (!match_word(&pos, "extern")) return NULL; diff --git a/test/lang.tm b/test/lang.tm index 777a05b..29936a6 100644 --- a/test/lang.tm +++ b/test/lang.tm @@ -1,6 +1,6 @@ lang HTML: HEADER := $HTML"" - func HTML(t:Text->HTML): + convert(t:Text->HTML): t = t:replace_all({ $/&/="&", $/HTML): + convert(i:Int->HTML): return HTML.without_escaping("$i") func paragraph(content:HTML->HTML): return $HTML"

$content

" struct Bold(text:Text): - func HTML(b:Bold -> HTML): + convert(b:Bold -> HTML): return $HTML"$(b.text)" func main(): diff --git a/typecheck.c b/typecheck.c index 0530ff7..6038d70 100644 --- a/typecheck.c +++ b/typecheck.c @@ -315,6 +315,20 @@ void bind_statement(env_t *env, ast_t *statement) set_binding(env, name, type, code); break; } + case ConvertDef: { + type_t *type = get_function_def_type(env, statement); + type_t *ret_t = Match(type, FunctionType)->ret; + const char *name = get_type_name(ret_t); + if (!name) + code_err(statement, "Conversions are only supported for text, struct, and enum types, not %T", ret_t); + + CORD code = CORD_asprintf("%r%r$%ld", namespace_prefix(env, env->namespace), name, + get_line_number(statement->file, statement->start)); + binding_t binding = {.type=type, .code=code}; + env_t *type_ns = get_namespace_by_type(env, ret_t); + Array$insert(&type_ns->namespace->constructors, &binding, I(0), sizeof(binding)); + break; + } case StructDef: { auto def = Match(statement, StructDef); env_t *ns_env = namespace_env(env, def->name); @@ -460,17 +474,18 @@ void bind_statement(env_t *env, ast_t *statement) type_t *get_function_def_type(env_t *env, ast_t *ast) { - auto fn = Match(ast, FunctionDef); + arg_ast_t *arg_asts = ast->tag == FunctionDef ? Match(ast, FunctionDef)->args : Match(ast, ConvertDef)->args; + type_ast_t *ret_type = ast->tag == FunctionDef ? Match(ast, FunctionDef)->ret_type : Match(ast, ConvertDef)->ret_type; arg_t *args = NULL; env_t *scope = fresh_scope(env); - for (arg_ast_t *arg = fn->args; arg; arg = arg->next) { + for (arg_ast_t *arg = arg_asts; arg; arg = arg->next) { type_t *t = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->value); args = new(arg_t, .name=arg->name, .type=t, .default_val=arg->value, .next=args); set_binding(scope, arg->name, t, CORD_EMPTY); } REVERSE_LIST(args); - type_t *ret = fn->ret_type ? parse_type_ast(scope, fn->ret_type) : Type(VoidType); + type_t *ret = ret_type ? parse_type_ast(scope, ret_type) : Type(VoidType); if (has_stack_memory(ret)) code_err(ast, "Functions can't return stack references because the reference may outlive its stack frame."); return Type(FunctionType, .args=args, .ret=ret); @@ -815,7 +830,7 @@ type_t *get_type(env_t *env, ast_t *ast) if (fn_type_t->tag == TypeInfoType) { type_t *t = Match(fn_type_t, TypeInfoType)->type; - binding_t *constructor = get_constructor(env, t, call->args, t); + binding_t *constructor = get_constructor(env, t, call->args); if (constructor) return t; else if (t->tag == StructType || t->tag == IntType || t->tag == BigIntType || t->tag == NumType @@ -914,7 +929,7 @@ type_t *get_type(env_t *env, ast_t *ast) // Early out if the type is knowable without any context from the block: switch (last->ast->tag) { - case UpdateAssign: case Assign: case Declare: case FunctionDef: case StructDef: case EnumDef: case LangDef: + case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: return Type(VoidType); default: break; } @@ -1236,7 +1251,7 @@ type_t *get_type(env_t *env, ast_t *ast) return Type(ClosureType, Type(FunctionType, .args=args, .ret=ret)); } - case FunctionDef: case StructDef: case EnumDef: case LangDef: { + case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: { return Type(VoidType); } @@ -1393,7 +1408,7 @@ type_t *get_type(env_t *env, ast_t *ast) PUREFUNC bool is_discardable(env_t *env, ast_t *ast) { switch (ast->tag) { - case UpdateAssign: case Assign: case Declare: case FunctionDef: case StructDef: case EnumDef: + case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Use: return true; default: break; diff --git a/types.c b/types.c index c8cb4e3..c943eb8 100644 --- a/types.c +++ b/types.c @@ -105,6 +105,16 @@ CORD type_to_cord(type_t *t) { } } +PUREFUNC const char *get_type_name(type_t *t) +{ + switch (t->tag) { + case TextType: return Match(t, TextType)->lang; + case StructType: return Match(t, StructType)->name; + case EnumType: return Match(t, EnumType)->name; + default: return NULL; + } +} + int printf_pointer_size(const struct printf_info *info, size_t n, int argtypes[n], int sizes[n]) { if (n < 1) return -1; diff --git a/types.h b/types.h index 882e7e2..1f39f53 100644 --- a/types.h +++ b/types.h @@ -137,6 +137,7 @@ struct type_s { int printf_pointer_size(const struct printf_info *info, size_t n, int argtypes[n], int size[n]); int printf_type(FILE *stream, const struct printf_info *info, const void *const args[]); CORD type_to_cord(type_t *t); +const char *get_type_name(type_t *t); PUREFUNC bool type_eq(type_t *a, type_t *b); PUREFUNC bool type_is_a(type_t *t, type_t *req); type_t *type_or_type(type_t *a, type_t *b);