diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2025-03-01 16:04:14 -0500 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2025-03-01 16:04:14 -0500 |
| commit | fab0083129c134a49c77a4f7e3b333f1e0f14d55 (patch) | |
| tree | f4d493712a7b9ebad9ab4136bd1fa25ec4e933d6 | |
| parent | 90548ebb33ea69c6c9e0962d8bab6d3ec72ac90f (diff) | |
Support post-hoc definitions of escaping rules for DSLs
| -rw-r--r-- | compile.c | 62 | ||||
| -rw-r--r-- | environment.c | 98 | ||||
| -rw-r--r-- | environment.h | 2 | ||||
| -rw-r--r-- | test/lang.tm | 9 | ||||
| -rw-r--r-- | typecheck.c | 12 |
5 files changed, 87 insertions, 96 deletions
@@ -2680,23 +2680,25 @@ CORD compile(env_t *env, ast_t *ast) type_t *chunk_t = get_type(env, chunk->ast); if (chunk->ast->tag == TextLiteral) { chunk_code = compile(env, chunk->ast); - } else if (chunk_t->tag == TextType && streq(Match(chunk_t, TextType)->lang, lang)) { - binding_t *esc = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast)); - if (esc) { - arg_t *arg_spec = Match(esc->type, FunctionType)->args; + } else if (chunk_t->tag == TextType && streq(Match(chunk_t, TextType)->lang, lang)) { // Interp is same type as text literal + binding_t *constructor = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast), text_t); + if (constructor) { + arg_t *arg_spec = Match(constructor->type, FunctionType)->args; arg_ast_t *args = new(arg_ast_t, .value=chunk->ast); - chunk_code = CORD_all(esc->code, "(", compile_arguments(env, ast, arg_spec, args), ")"); + chunk_code = CORD_all(constructor->code, "(", compile_arguments(env, ast, arg_spec, args), ")"); } else { chunk_code = compile(env, chunk->ast); } - } else if (lang) { - binding_t *esc = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast)); - if (!esc) + } else if (lang) { // Interp is different type from text literal (which is a DSL) + binding_t *constructor = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast), text_t); + if (!constructor) + constructor = get_constructor(env, chunk_t, new(arg_ast_t, .value=chunk->ast), text_t); + if (!constructor) code_err(chunk->ast, "I don't know how to convert %T to %T", chunk_t, text_t); - arg_t *arg_spec = Match(esc->type, FunctionType)->args; + arg_t *arg_spec = Match(constructor->type, FunctionType)->args; arg_ast_t *args = new(arg_ast_t, .value=chunk->ast); - chunk_code = CORD_all(esc->code, "(", compile_arguments(env, ast, arg_spec, args), ")"); + chunk_code = CORD_all(constructor->code, "(", compile_arguments(env, ast, arg_spec, args), ")"); } else { chunk_code = compile_string(env, chunk->ast, "no"); } @@ -3369,7 +3371,7 @@ CORD compile(env_t *env, ast_t *ast) } else if (fn_t->tag == TypeInfoType) { type_t *t = Match(fn_t, TypeInfoType)->type; - binding_t *constructor = get_constructor(env, t, call->args); + binding_t *constructor = get_constructor(env, t, call->args, t); if (constructor) { arg_t *arg_spec = Match(constructor->type, FunctionType)->args; return CORD_all(constructor->code, "(", compile_arguments(env, ast, arg_spec, call->args), ")"); @@ -4076,11 +4078,15 @@ CORD compile_cli_arg_call(env_t *env, CORD fn_name, type_t *fn_type) CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) { auto fndef = Match(ast, FunctionDef); - bool is_private = Match(fndef->name, Var)->name[0] == '_'; - CORD name = compile(env, fndef->name); - if (env->namespace && env->namespace->parent && env->namespace->name && streq(Match(fndef->name, Var)->name, env->namespace->name)) - name = CORD_asprintf("%r$%ld", name, get_line_number(ast->file, ast->start)); + const char *raw_name = Match(fndef->name, Var)->name; + bool is_private = raw_name[0] == '_'; + CORD name_code = CORD_all(namespace_prefix(env, env->namespace), raw_name); + binding_t *clobbered = get_binding(env, raw_name); type_t *ret_t = fndef->ret_type ? parse_type_ast(env, fndef->ret_type) : Type(VoidType); + // Check for a constructor: + if (clobbered && clobbered->type->tag == TypeInfoType && type_eq(ret_t, Match(clobbered->type, TypeInfoType)->type)) { + name_code = CORD_asprintf("%r$%ld", name_code, get_line_number(ast->file, ast->start)); + } CORD arg_signature = "("; Table_t used_names = {}; @@ -4097,13 +4103,13 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) CORD ret_type_code = compile_type(ret_t); if (is_private) - *staticdefs = CORD_all(*staticdefs, "static ", ret_type_code, " ", name, arg_signature, ";\n"); + *staticdefs = CORD_all(*staticdefs, "static ", ret_type_code, " ", name_code, arg_signature, ";\n"); CORD code; if (fndef->cache) { - code = CORD_all("static ", ret_type_code, " ", name, "$uncached", arg_signature); + code = CORD_all("static ", ret_type_code, " ", name_code, "$uncached", arg_signature); } else { - code = CORD_all(ret_type_code, " ", name, arg_signature); + code = CORD_all(ret_type_code, " ", name_code, arg_signature); if (fndef->is_inline) code = CORD_cat("INLINE ", code); if (!is_private) @@ -4125,7 +4131,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) code_err(ast, "This function can reach the end without returning a %T value!", ret_t); CORD body = compile_statement(body_scope, fndef->body); - if (streq(Match(fndef->name, Var)->name, "main")) + if (streq(raw_name, "main")) body = CORD_all("_$", env->namespace->name, "$$initialize();\n", body); if (CORD_fetch(body, 0) != '{') body = CORD_asprintf("{\n%r\n}", body); @@ -4134,11 +4140,11 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) if (fndef->cache && fndef->args == NULL) { // no-args cache just uses a static var CORD wrapper = CORD_all( - is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name, "(void) {\n" + is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name_code, "(void) {\n" "static ", compile_declaration(ret_t, "cached_result"), ";\n", "static bool initialized = false;\n", "if (!initialized) {\n" - "\tcached_result = ", name, "$uncached();\n", + "\tcached_result = ", name_code, "$uncached();\n", "\tinitialized = true;\n", "}\n", "return cached_result;\n" @@ -4157,12 +4163,12 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) // Single-argument functions have simplified caching logic type_t *arg_type = get_arg_ast_type(env, fndef->args); CORD wrapper = CORD_all( - is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name, arg_signature, "{\n" + is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name_code, arg_signature, "{\n" "static Table_t cache = {};\n", "const TypeInfo_t *table_type = Table$info(", compile_type_info(env, arg_type), ", ", compile_type_info(env, ret_t), ");\n", compile_declaration(Type(PointerType, .pointed=ret_t), "cached"), " = Table$get_raw(cache, &_$", fndef->args->name, ", table_type);\n" "if (cached) return *cached;\n", - compile_declaration(ret_t, "ret"), " = ", name, "$uncached(_$", fndef->args->name, ");\n", + compile_declaration(ret_t, "ret"), " = ", name_code, "$uncached(_$", fndef->args->name, ");\n", pop_code, "Table$set(&cache, &_$", fndef->args->name, ", &ret, table_type);\n" "return ret;\n" @@ -4177,7 +4183,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) type_t *t = Type(StructType, .name=heap_strf("%s$args", fndef->name), .fields=fields, .env=env); int64_t num_fields = used_names.entries.length; - const char *short_name = Match(fndef->name, Var)->name; + const char *short_name = raw_name; const char *metamethods = is_packed_data(t) ? "PackedData$metamethods" : "Struct$metamethods"; CORD args_typeinfo = CORD_asprintf("((TypeInfo_t[1]){{.size=%zu, .align=%zu, .metamethods=%s, " ".tag=StructInfo, .StructInfo.name=\"%s\", " @@ -4198,13 +4204,13 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) all_args = CORD_all(all_args, "_$", arg->name, arg->next ? ", " : CORD_EMPTY); CORD wrapper = CORD_all( - is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name, arg_signature, "{\n" + is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name_code, arg_signature, "{\n" "static Table_t cache = {};\n", args_type, " args = {", all_args, "};\n" "const TypeInfo_t *table_type = Table$info(", args_typeinfo, ", ", compile_type_info(env, ret_t), ");\n", compile_declaration(Type(PointerType, .pointed=ret_t), "cached"), " = Table$get_raw(cache, &args, table_type);\n" "if (cached) return *cached;\n", - compile_declaration(ret_t, "ret"), " = ", name, "$uncached(", all_args, ");\n", + compile_declaration(ret_t, "ret"), " = ", name_code, "$uncached(", all_args, ");\n", pop_code, "Table$set(&cache, &args, &ret, table_type);\n" "return ret;\n" @@ -4213,7 +4219,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) } } - CORD qualified_name = Match(fndef->name, Var)->name; + CORD qualified_name = raw_name; if (env->namespace && env->namespace->parent && env->namespace->name) qualified_name = CORD_all(env->namespace->name, ".", qualified_name); CORD text = CORD_all("func ", qualified_name, "("); @@ -4229,7 +4235,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs) env->code->function_naming = CORD_all( env->code->function_naming, CORD_asprintf("register_function(%r, Text(\"%s.tm\"), %ld, Text(%r));\n", - name, file_base_name(ast->file->filename), get_line_number(ast->file, ast->start), CORD_quoted(text))); + name_code, file_base_name(ast->file->filename), get_line_number(ast->file, ast->start), CORD_quoted(text))); } return definition; } diff --git a/environment.c b/environment.c index 2e1c4add..393c0c30 100644 --- a/environment.c +++ b/environment.c @@ -636,6 +636,36 @@ env_t *for_scope(env_t *env, ast_t *ast) } } +static env_t *get_namespace_by_type(env_t *env, type_t *t) +{ + t = value_type(t); + switch (t->tag) { + case ArrayType: return NULL; + case TableType: return NULL; + case CStringType: case MomentType: + case BoolType: case IntType: case BigIntType: case NumType: case ByteType: { + binding_t *b = get_binding(env, CORD_to_const_char_star(type_to_cord(t))); + assert(b); + return Match(b->type, TypeInfoType)->env; + } + case TextType: return Match(t, TextType)->env; + case StructType: { + auto struct_ = Match(t, StructType); + return struct_->env; + } + case EnumType: { + auto enum_ = Match(t, EnumType); + return enum_->env; + } + case TypeInfoType: { + auto info = Match(t, TypeInfoType); + return info->env; + } + default: break; + } + return NULL; +} + env_t *namespace_env(env_t *env, const char *namespace_name) { binding_t *b = get_binding(env, namespace_name); @@ -660,82 +690,26 @@ binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name) type_t *self_type = get_type(env, self); if (!self_type) code_err(self, "I couldn't get this type"); - type_t *cls_type = value_type(self_type); - switch (cls_type->tag) { - case ArrayType: return NULL; - case TableType: return NULL; - case CStringType: case MomentType: - case BoolType: case IntType: case BigIntType: case NumType: case ByteType: { - binding_t *b = get_binding(env, CORD_to_const_char_star(type_to_cord(cls_type))); - assert(b); - return get_binding(Match(b->type, TypeInfoType)->env, name); - } - case TextType: { - auto text = Match(cls_type, TextType); - env_t *text_env = text->env ? text->env : namespace_env(env, text->lang ? text->lang : "Text"); - assert(text_env); - return get_binding(text_env, name); - } - case StructType: { - auto struct_ = Match(cls_type, StructType); - return struct_->env ? get_binding(struct_->env, name) : NULL; - } - case EnumType: { - auto enum_ = Match(cls_type, EnumType); - return enum_->env ? get_binding(enum_->env, name) : NULL; - } - case TypeInfoType: { - auto info = Match(cls_type, TypeInfoType); - return info->env ? get_binding(info->env, name) : NULL; - } - default: break; - } - return NULL; + env_t *ns_env = get_namespace_by_type(env, self_type); + return ns_env ? get_binding(ns_env, name) : NULL; } -PUREFUNC binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args) +PUREFUNC binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args, type_t *constructed_type) { - const char *type_name; - t = value_type(t); - switch (t->tag) { - case TextType: { - type_name = Match(t, TextType)->lang; - if (type_name == NULL) type_name = "Text"; - break; - } - case StructType: { - type_name = Match(t, StructType)->name; - break; - } - case EnumType: { - type_name = Match(t, EnumType)->name; - break; - } - default: { - type_name = NULL; - break; - } - } - - if (!type_name) - return NULL; - - binding_t *typeinfo = get_binding(env, type_name); - assert(typeinfo && typeinfo->type->tag == TypeInfoType); - env_t *type_env = Match(typeinfo->type, TypeInfoType)->env; + env_t *type_env = get_namespace_by_type(env, t); Array_t constructors = type_env->namespace->constructors; // Prioritize exact matches: for (int64_t i = 0; i < constructors.length; i++) { binding_t *b = constructors.data + i*constructors.stride; auto fn = Match(b->type, FunctionType); - if (is_valid_call(env, fn->args, args, false)) + if (type_eq(fn->ret, constructed_type) && is_valid_call(env, fn->args, args, false)) return b; } // Fall back to promotion: for (int64_t i = 0; i < constructors.length; i++) { binding_t *b = constructors.data + i*constructors.stride; auto fn = Match(b->type, FunctionType); - if (is_valid_call(env, fn->args, args, true)) + if (type_eq(fn->ret, constructed_type) && is_valid_call(env, fn->args, args, true)) return b; } return NULL; diff --git a/environment.h b/environment.h index c3d897b8..b5da5506 100644 --- a/environment.h +++ b/environment.h @@ -65,7 +65,7 @@ env_t *namespace_env(env_t *env, const char *namespace_name); __attribute__((format(printf, 4, 5))) _Noreturn void compiler_err(file_t *f, const char *start, const char *end, const char *fmt, ...); binding_t *get_binding(env_t *env, const char *name); -binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args); +binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args, type_t *constructed_type); void set_binding(env_t *env, const char *name, type_t *type, CORD code); binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name); #define code_err(ast, ...) compiler_err((ast)->file, (ast)->start, (ast)->end, __VA_ARGS__) diff --git a/test/lang.tm b/test/lang.tm index 77fd7958..5d8c619e 100644 --- a/test/lang.tm +++ b/test/lang.tm @@ -17,6 +17,10 @@ lang HTML: func paragraph(content:HTML->HTML): return $HTML"<p>$content</p>" +struct Bold(text:Text): + func HTML(b:Bold -> HTML): + return $HTML"<b>$(b.text)</b>" + func main(): >> HTML.HEADER = $HTML"<!DOCTYPE HTML>" @@ -44,3 +48,8 @@ func main(): >> Text(html) = '$HTML"Hello I <3 hax!"' + + >> b := Bold("Some <text> with junk") + >> $HTML"Your text: $b" + = $HTML"Your text: <b>Some <text> with junk</b>" + diff --git a/typecheck.c b/typecheck.c index 26a8ff9f..91343179 100644 --- a/typecheck.c +++ b/typecheck.c @@ -299,15 +299,17 @@ void bind_statement(env_t *env, ast_t *statement) auto def = Match(statement, FunctionDef); const char *name = Match(def->name, Var)->name; type_t *type = get_function_def_type(env, statement); - if (env->namespace && env->namespace->parent && env->namespace->name && streq(name, env->namespace->name)) { - CORD code = CORD_asprintf("%r%ld", namespace_prefix(env, env->namespace), get_line_number(statement->file, statement->start)); + binding_t *clobber = get_binding(env, name); + if (clobber && clobber->type->tag == TypeInfoType && type_eq(Match(clobber->type, TypeInfoType)->type, Match(type, FunctionType)->ret)) { + CORD code = CORD_asprintf("%r%r$%ld", namespace_prefix(env, env->namespace), name, + get_line_number(statement->file, statement->start)); binding_t binding = {.type=type, .code=code}; Array$insert(&env->namespace->constructors, &binding, I(0), sizeof(binding)); break; } - if (get_binding(env, name)) - code_err(def->name, "A %T called '%s' has already been defined", get_binding(env, name)->type, name); + if (clobber) + code_err(def->name, "A %T called '%s' has already been defined", clobber->type, name); CORD code = CORD_all(namespace_prefix(env, env->namespace), name); set_binding(env, name, type, code); break; @@ -792,7 +794,7 @@ type_t *get_type(env_t *env, ast_t *ast) if (fn_type_t->tag == TypeInfoType) { type_t *t = Match(fn_type_t, TypeInfoType)->type; - binding_t *constructor = get_constructor(env, t, call->args); + binding_t *constructor = get_constructor(env, t, call->args, t); if (constructor) return t; else if (t->tag == StructType || t->tag == IntType || t->tag == BigIntType || t->tag == NumType |
