aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Bruce Hill <bruce@bruce-hill.com> 2025-03-01 16:04:14 -0500
committer Bruce Hill <bruce@bruce-hill.com> 2025-03-01 16:04:14 -0500
commit fab0083129c134a49c77a4f7e3b333f1e0f14d55 (patch)
tree f4d493712a7b9ebad9ab4136bd1fa25ec4e933d6
parent 90548ebb33ea69c6c9e0962d8bab6d3ec72ac90f (diff)
Support post-hoc definitions of escaping rules for DSLs
-rw-r--r-- compile.c 62
-rw-r--r-- environment.c 98
-rw-r--r-- environment.h 2
-rw-r--r-- test/lang.tm 9
-rw-r--r-- typecheck.c 12
5 files changed, 87 insertions, 96 deletions
diff --git a/compile.c b/compile.c
index 24c63bd1..7331483b 100644
--- a/compile.c
+++ b/compile.c
@@ -2680,23 +2680,25 @@ CORD compile(env_t *env, ast_t *ast)
type_t *chunk_t = get_type(env, chunk->ast);
if (chunk->ast->tag == TextLiteral) {
chunk_code = compile(env, chunk->ast);
- } else if (chunk_t->tag == TextType && streq(Match(chunk_t, TextType)->lang, lang)) {
- binding_t *esc = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast));
- if (esc) {
- arg_t *arg_spec = Match(esc->type, FunctionType)->args;
+ } else if (chunk_t->tag == TextType && streq(Match(chunk_t, TextType)->lang, lang)) { // Interp is same type as text literal
+ binding_t *constructor = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast), text_t);
+ if (constructor) {
+ arg_t *arg_spec = Match(constructor->type, FunctionType)->args;
arg_ast_t *args = new(arg_ast_t, .value=chunk->ast);
- chunk_code = CORD_all(esc->code, "(", compile_arguments(env, ast, arg_spec, args), ")");
+ chunk_code = CORD_all(constructor->code, "(", compile_arguments(env, ast, arg_spec, args), ")");
} else {
chunk_code = compile(env, chunk->ast);
}
- } else if (lang) {
- binding_t *esc = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast));
- if (!esc)
+ } else if (lang) { // Interp is different type from text literal (which is a DSL)
+ binding_t *constructor = get_constructor(env, text_t, new(arg_ast_t, .value=chunk->ast), text_t);
+ if (!constructor)
+ constructor = get_constructor(env, chunk_t, new(arg_ast_t, .value=chunk->ast), text_t);
+ if (!constructor)
code_err(chunk->ast, "I don't know how to convert %T to %T", chunk_t, text_t);
- arg_t *arg_spec = Match(esc->type, FunctionType)->args;
+ arg_t *arg_spec = Match(constructor->type, FunctionType)->args;
arg_ast_t *args = new(arg_ast_t, .value=chunk->ast);
- chunk_code = CORD_all(esc->code, "(", compile_arguments(env, ast, arg_spec, args), ")");
+ chunk_code = CORD_all(constructor->code, "(", compile_arguments(env, ast, arg_spec, args), ")");
} else {
chunk_code = compile_string(env, chunk->ast, "no");
}
@@ -3369,7 +3371,7 @@ CORD compile(env_t *env, ast_t *ast)
} else if (fn_t->tag == TypeInfoType) {
type_t *t = Match(fn_t, TypeInfoType)->type;
- binding_t *constructor = get_constructor(env, t, call->args);
+ binding_t *constructor = get_constructor(env, t, call->args, t);
if (constructor) {
arg_t *arg_spec = Match(constructor->type, FunctionType)->args;
return CORD_all(constructor->code, "(", compile_arguments(env, ast, arg_spec, call->args), ")");
@@ -4076,11 +4078,15 @@ CORD compile_cli_arg_call(env_t *env, CORD fn_name, type_t *fn_type)
CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs)
{
auto fndef = Match(ast, FunctionDef);
- bool is_private = Match(fndef->name, Var)->name[0] == '_';
- CORD name = compile(env, fndef->name);
- if (env->namespace && env->namespace->parent && env->namespace->name && streq(Match(fndef->name, Var)->name, env->namespace->name))
- name = CORD_asprintf("%r$%ld", name, get_line_number(ast->file, ast->start));
+ const char *raw_name = Match(fndef->name, Var)->name;
+ bool is_private = raw_name[0] == '_';
+ CORD name_code = CORD_all(namespace_prefix(env, env->namespace), raw_name);
+ binding_t *clobbered = get_binding(env, raw_name);
type_t *ret_t = fndef->ret_type ? parse_type_ast(env, fndef->ret_type) : Type(VoidType);
+ // Check for a constructor:
+ if (clobbered && clobbered->type->tag == TypeInfoType && type_eq(ret_t, Match(clobbered->type, TypeInfoType)->type)) {
+ name_code = CORD_asprintf("%r$%ld", name_code, get_line_number(ast->file, ast->start));
+ }
CORD arg_signature = "(";
Table_t used_names = {};
@@ -4097,13 +4103,13 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs)
CORD ret_type_code = compile_type(ret_t);
if (is_private)
- *staticdefs = CORD_all(*staticdefs, "static ", ret_type_code, " ", name, arg_signature, ";\n");
+ *staticdefs = CORD_all(*staticdefs, "static ", ret_type_code, " ", name_code, arg_signature, ";\n");
CORD code;
if (fndef->cache) {
- code = CORD_all("static ", ret_type_code, " ", name, "$uncached", arg_signature);
+ code = CORD_all("static ", ret_type_code, " ", name_code, "$uncached", arg_signature);
} else {
- code = CORD_all(ret_type_code, " ", name, arg_signature);
+ code = CORD_all(ret_type_code, " ", name_code, arg_signature);
if (fndef->is_inline)
code = CORD_cat("INLINE ", code);
if (!is_private)
@@ -4125,7 +4131,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs)
code_err(ast, "This function can reach the end without returning a %T value!", ret_t);
CORD body = compile_statement(body_scope, fndef->body);
- if (streq(Match(fndef->name, Var)->name, "main"))
+ if (streq(raw_name, "main"))
body = CORD_all("_$", env->namespace->name, "$$initialize();\n", body);
if (CORD_fetch(body, 0) != '{')
body = CORD_asprintf("{\n%r\n}", body);
@@ -4134,11 +4140,11 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs)
if (fndef->cache && fndef->args == NULL) { // no-args cache just uses a static var
CORD wrapper = CORD_all(
- is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name, "(void) {\n"
+ is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name_code, "(void) {\n"
"static ", compile_declaration(ret_t, "cached_result"), ";\n",
"static bool initialized = false;\n",
"if (!initialized) {\n"
- "\tcached_result = ", name, "$uncached();\n",
+ "\tcached_result = ", name_code, "$uncached();\n",
"\tinitialized = true;\n",
"}\n",
"return cached_result;\n"
@@ -4157,12 +4163,12 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs)
// Single-argument functions have simplified caching logic
type_t *arg_type = get_arg_ast_type(env, fndef->args);
CORD wrapper = CORD_all(
- is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name, arg_signature, "{\n"
+ is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name_code, arg_signature, "{\n"
"static Table_t cache = {};\n",
"const TypeInfo_t *table_type = Table$info(", compile_type_info(env, arg_type), ", ", compile_type_info(env, ret_t), ");\n",
compile_declaration(Type(PointerType, .pointed=ret_t), "cached"), " = Table$get_raw(cache, &_$", fndef->args->name, ", table_type);\n"
"if (cached) return *cached;\n",
- compile_declaration(ret_t, "ret"), " = ", name, "$uncached(_$", fndef->args->name, ");\n",
+ compile_declaration(ret_t, "ret"), " = ", name_code, "$uncached(_$", fndef->args->name, ");\n",
pop_code,
"Table$set(&cache, &_$", fndef->args->name, ", &ret, table_type);\n"
"return ret;\n"
@@ -4177,7 +4183,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs)
type_t *t = Type(StructType, .name=heap_strf("%s$args", fndef->name), .fields=fields, .env=env);
int64_t num_fields = used_names.entries.length;
- const char *short_name = Match(fndef->name, Var)->name;
+ const char *short_name = raw_name;
const char *metamethods = is_packed_data(t) ? "PackedData$metamethods" : "Struct$metamethods";
CORD args_typeinfo = CORD_asprintf("((TypeInfo_t[1]){{.size=%zu, .align=%zu, .metamethods=%s, "
".tag=StructInfo, .StructInfo.name=\"%s\", "
@@ -4198,13 +4204,13 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs)
all_args = CORD_all(all_args, "_$", arg->name, arg->next ? ", " : CORD_EMPTY);
CORD wrapper = CORD_all(
- is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name, arg_signature, "{\n"
+ is_private ? CORD_EMPTY : "public ", ret_type_code, " ", name_code, arg_signature, "{\n"
"static Table_t cache = {};\n",
args_type, " args = {", all_args, "};\n"
"const TypeInfo_t *table_type = Table$info(", args_typeinfo, ", ", compile_type_info(env, ret_t), ");\n",
compile_declaration(Type(PointerType, .pointed=ret_t), "cached"), " = Table$get_raw(cache, &args, table_type);\n"
"if (cached) return *cached;\n",
- compile_declaration(ret_t, "ret"), " = ", name, "$uncached(", all_args, ");\n",
+ compile_declaration(ret_t, "ret"), " = ", name_code, "$uncached(", all_args, ");\n",
pop_code,
"Table$set(&cache, &args, &ret, table_type);\n"
"return ret;\n"
@@ -4213,7 +4219,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs)
}
}
- CORD qualified_name = Match(fndef->name, Var)->name;
+ CORD qualified_name = raw_name;
if (env->namespace && env->namespace->parent && env->namespace->name)
qualified_name = CORD_all(env->namespace->name, ".", qualified_name);
CORD text = CORD_all("func ", qualified_name, "(");
@@ -4229,7 +4235,7 @@ CORD compile_function(env_t *env, ast_t *ast, CORD *staticdefs)
env->code->function_naming = CORD_all(
env->code->function_naming,
CORD_asprintf("register_function(%r, Text(\"%s.tm\"), %ld, Text(%r));\n",
- name, file_base_name(ast->file->filename), get_line_number(ast->file, ast->start), CORD_quoted(text)));
+ name_code, file_base_name(ast->file->filename), get_line_number(ast->file, ast->start), CORD_quoted(text)));
}
return definition;
}
diff --git a/environment.c b/environment.c
index 2e1c4add..393c0c30 100644
--- a/environment.c
+++ b/environment.c
@@ -636,6 +636,36 @@ env_t *for_scope(env_t *env, ast_t *ast)
}
}
+static env_t *get_namespace_by_type(env_t *env, type_t *t) // Find the namespace environment that belongs to type `t`.
+{ /* Returns the env_t stored for the type (after passing `t` through value_type()), or NULL if none is found.
+   *
+   * env: environment used to look up the bindings of the built-in type names.
+   * t:   the type whose namespace is wanted.
+   *
+   * The result is used in this file by get_namespace_binding() (name lookup via get_binding) and by
+   * get_constructor() (which reads ->namespace->constructors from it).
+   *
+   * NULL is returned for array types, table types, any tag not listed in the switch, and whenever the
+   * `env` field stored on a text/struct/enum/type-info type is itself NULL.
+   *
+   * NOTE(review): get_constructor() uses the result as `type_env->namespace` without a NULL check, while the
+   * code it replaced returned NULL for types with no name -- confirm that it can never be handed such a type,
+   * or add a guard there.
+   */
+ t = value_type(t); // dispatch on the tag of the value type rather than on `t` as passed in
+ switch (t->tag) {
+ case ArrayType: return NULL; // no namespace is returned for arrays
+ case TableType: return NULL; // no namespace is returned for tables
+ case CStringType: case MomentType:
+ case BoolType: case IntType: case BigIntType: case NumType: case ByteType: {
+ binding_t *b = get_binding(env, CORD_to_const_char_star(type_to_cord(t))); // look up the binding named after the type's textual form
+ assert(b); // these built-in type names are expected to always be bound
+ return Match(b->type, TypeInfoType)->env; // the namespace is the env stored in that binding's type info
+ }
+ case TextType: return Match(t, TextType)->env; // NOTE(review): the removed code fell back to namespace_env(env, lang or "Text") when this env was NULL; here NULL is returned instead -- confirm intended
+ case StructType: {
+ auto struct_ = Match(t, StructType);
+ return struct_->env; // may be NULL; the removed code checked for that before calling get_binding
+ }
+ case EnumType: {
+ auto enum_ = Match(t, EnumType);
+ return enum_->env; // may be NULL, as for structs
+ }
+ case TypeInfoType: {
+ auto info = Match(t, TypeInfoType);
+ return info->env; // may be NULL, as for structs
+ }
+ default: break; // every other tag falls through to the NULL return below
+ }
+ return NULL;
+}
+
env_t *namespace_env(env_t *env, const char *namespace_name)
{
binding_t *b = get_binding(env, namespace_name);
@@ -660,82 +690,26 @@ binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name)
type_t *self_type = get_type(env, self);
if (!self_type)
code_err(self, "I couldn't get this type");
- type_t *cls_type = value_type(self_type);
- switch (cls_type->tag) {
- case ArrayType: return NULL;
- case TableType: return NULL;
- case CStringType: case MomentType:
- case BoolType: case IntType: case BigIntType: case NumType: case ByteType: {
- binding_t *b = get_binding(env, CORD_to_const_char_star(type_to_cord(cls_type)));
- assert(b);
- return get_binding(Match(b->type, TypeInfoType)->env, name);
- }
- case TextType: {
- auto text = Match(cls_type, TextType);
- env_t *text_env = text->env ? text->env : namespace_env(env, text->lang ? text->lang : "Text");
- assert(text_env);
- return get_binding(text_env, name);
- }
- case StructType: {
- auto struct_ = Match(cls_type, StructType);
- return struct_->env ? get_binding(struct_->env, name) : NULL;
- }
- case EnumType: {
- auto enum_ = Match(cls_type, EnumType);
- return enum_->env ? get_binding(enum_->env, name) : NULL;
- }
- case TypeInfoType: {
- auto info = Match(cls_type, TypeInfoType);
- return info->env ? get_binding(info->env, name) : NULL;
- }
- default: break;
- }
- return NULL;
+ env_t *ns_env = get_namespace_by_type(env, self_type);
+ return ns_env ? get_binding(ns_env, name) : NULL;
}
-PUREFUNC binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args)
+PUREFUNC binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args, type_t *constructed_type)
{
- const char *type_name;
- t = value_type(t);
- switch (t->tag) {
- case TextType: {
- type_name = Match(t, TextType)->lang;
- if (type_name == NULL) type_name = "Text";
- break;
- }
- case StructType: {
- type_name = Match(t, StructType)->name;
- break;
- }
- case EnumType: {
- type_name = Match(t, EnumType)->name;
- break;
- }
- default: {
- type_name = NULL;
- break;
- }
- }
-
- if (!type_name)
- return NULL;
-
- binding_t *typeinfo = get_binding(env, type_name);
- assert(typeinfo && typeinfo->type->tag == TypeInfoType);
- env_t *type_env = Match(typeinfo->type, TypeInfoType)->env;
+ env_t *type_env = get_namespace_by_type(env, t);
Array_t constructors = type_env->namespace->constructors;
// Prioritize exact matches:
for (int64_t i = 0; i < constructors.length; i++) {
binding_t *b = constructors.data + i*constructors.stride;
auto fn = Match(b->type, FunctionType);
- if (is_valid_call(env, fn->args, args, false))
+ if (type_eq(fn->ret, constructed_type) && is_valid_call(env, fn->args, args, false))
return b;
}
// Fall back to promotion:
for (int64_t i = 0; i < constructors.length; i++) {
binding_t *b = constructors.data + i*constructors.stride;
auto fn = Match(b->type, FunctionType);
- if (is_valid_call(env, fn->args, args, true))
+ if (type_eq(fn->ret, constructed_type) && is_valid_call(env, fn->args, args, true))
return b;
}
return NULL;
diff --git a/environment.h b/environment.h
index c3d897b8..b5da5506 100644
--- a/environment.h
+++ b/environment.h
@@ -65,7 +65,7 @@ env_t *namespace_env(env_t *env, const char *namespace_name);
__attribute__((format(printf, 4, 5)))
_Noreturn void compiler_err(file_t *f, const char *start, const char *end, const char *fmt, ...);
binding_t *get_binding(env_t *env, const char *name);
-binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args);
+binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args, type_t *constructed_type);
void set_binding(env_t *env, const char *name, type_t *type, CORD code);
binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name);
#define code_err(ast, ...) compiler_err((ast)->file, (ast)->start, (ast)->end, __VA_ARGS__)
diff --git a/test/lang.tm b/test/lang.tm
index 77fd7958..5d8c619e 100644
--- a/test/lang.tm
+++ b/test/lang.tm
@@ -17,6 +17,10 @@ lang HTML:
func paragraph(content:HTML->HTML):
return $HTML"<p>$content</p>"
+struct Bold(text:Text):
+ func HTML(b:Bold -> HTML):
+ return $HTML"<b>$(b.text)</b>"
+
func main():
>> HTML.HEADER
= $HTML"<!DOCTYPE HTML>"
@@ -44,3 +48,8 @@ func main():
>> Text(html)
= '$HTML"Hello I &lt;3 hax!"'
+
+ >> b := Bold("Some <text> with junk")
+ >> $HTML"Your text: $b"
+ = $HTML"Your text: <b>Some &lt;text&gt; with junk</b>"
+
diff --git a/typecheck.c b/typecheck.c
index 26a8ff9f..91343179 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -299,15 +299,17 @@ void bind_statement(env_t *env, ast_t *statement)
auto def = Match(statement, FunctionDef);
const char *name = Match(def->name, Var)->name;
type_t *type = get_function_def_type(env, statement);
- if (env->namespace && env->namespace->parent && env->namespace->name && streq(name, env->namespace->name)) {
- CORD code = CORD_asprintf("%r%ld", namespace_prefix(env, env->namespace), get_line_number(statement->file, statement->start));
+ binding_t *clobber = get_binding(env, name);
+ if (clobber && clobber->type->tag == TypeInfoType && type_eq(Match(clobber->type, TypeInfoType)->type, Match(type, FunctionType)->ret)) {
+ CORD code = CORD_asprintf("%r%r$%ld", namespace_prefix(env, env->namespace), name,
+ get_line_number(statement->file, statement->start));
binding_t binding = {.type=type, .code=code};
Array$insert(&env->namespace->constructors, &binding, I(0), sizeof(binding));
break;
}
- if (get_binding(env, name))
- code_err(def->name, "A %T called '%s' has already been defined", get_binding(env, name)->type, name);
+ if (clobber)
+ code_err(def->name, "A %T called '%s' has already been defined", clobber->type, name);
CORD code = CORD_all(namespace_prefix(env, env->namespace), name);
set_binding(env, name, type, code);
break;
@@ -792,7 +794,7 @@ type_t *get_type(env_t *env, ast_t *ast)
if (fn_type_t->tag == TypeInfoType) {
type_t *t = Match(fn_type_t, TypeInfoType)->type;
- binding_t *constructor = get_constructor(env, t, call->args);
+ binding_t *constructor = get_constructor(env, t, call->args, t);
if (constructor)
return t;
else if (t->tag == StructType || t->tag == IntType || t->tag == BigIntType || t->tag == NumType