diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-08-19 13:23:02 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-08-19 13:23:02 -0400 |
| commit | 4e732a718dc57f3c06af5ca9e43e4744b87ba72d (patch) | |
| tree | 878cce5dfcb2a7c53df0dcbafdac2727ed29c052 | |
| parent | 3ddaf9250586db0cf0d3e40106c836bb7ec33add (diff) | |
Restructure things so that DSL constructors do proper escaping
| -rw-r--r-- | compile.c | 53 | ||||
| -rw-r--r-- | environment.c | 21 | ||||
| -rw-r--r-- | environment.h | 1 |
3 files changed, 50 insertions, 25 deletions
@@ -1778,7 +1778,6 @@ CORD compile(env_t *env, ast_t *ast) type_t *text_t = Table$str_get(*env->types, lang ? lang : "Text"); if (!text_t || text_t->tag != TextType) code_err(ast, "%s is not a valid text language name", lang); - env_t *lang_env = lang ? Match(get_binding(env, lang)->type, TypeInfoType)->env : NULL; ast_list_t *chunks = Match(ast, TextJoin)->children; if (!chunks) { return "(CORD)CORD_EMPTY"; @@ -1793,25 +1792,14 @@ CORD compile(env_t *env, ast_t *ast) chunk_code = compile(env, chunk->ast); } else if (chunk_t->tag == TextType && streq(Match(chunk_t, TextType)->lang, lang)) { chunk_code = compile(env, chunk->ast); - } else if (lang && lang_env) { - // Get conversion function: - chunk_code = compile(env, chunk->ast); - for (int64_t i = 1; i <= Table$length(*lang_env->locals); i++) { - struct {const char *name; binding_t *b; } *entry = Table$entry(*lang_env->locals, i); - if (entry->b->type->tag != FunctionType) continue; - if (!(streq(entry->name, "escape") || strncmp(entry->name, "escape_", strlen("escape_")) == 0)) - continue; - auto fn = Match(entry->b->type, FunctionType); - if (!fn->args || fn->args->next) continue; - if (fn->ret->tag != TextType || !streq(Match(fn->ret, TextType)->lang, lang)) - continue; - if (!promote(env, &chunk_code, chunk_t, fn->args->type)) - continue; - chunk_code = CORD_all(entry->b->code, "(", chunk_code, ")"); - goto found_conversion; - } - code_err(chunk->ast, "I don't know how to convert %T to %T", chunk_t, text_t); - found_conversion:; + } else if (lang) { + binding_t *esc = get_lang_escape_function(env, lang, chunk_t); + if (!esc) + code_err(chunk->ast, "I don't know how to convert %T to %T", chunk_t, text_t); + + arg_t *arg_spec = Match(esc->type, FunctionType)->args; + arg_ast_t *args = new(arg_ast_t, .value=chunk->ast); + chunk_code = CORD_all(esc->code, "(", compile_arguments(env, ast, arg_spec, args), ")"); } else { chunk_code = compile_string(env, chunk->ast, "no"); } @@ -2498,6 +2486,7 @@ CORD compile(env_t *env, ast_t *ast) fn_t = Type(FunctionType, .args=Match(t, StructType)->fields, .ret=t); return CORD_all("((", compile_type(t), "){", compile_arguments(env, ast, Match(fn_t, FunctionType)->args, call->args), "})"); } else if (t->tag == NumType || t->tag == BigIntType) { + if (!call->args) code_err(ast, "This constructor needs a value"); type_t *actual = get_type(env, call->args->value); arg_t *args = new(arg_t, .name="i", .type=actual); // No truncation argument CORD arg_code = compile_arguments(env, ast, args, call->args); @@ -2509,11 +2498,25 @@ CORD compile(env_t *env, ast_t *ast) CORD arg_code = compile_arguments(env, ast, args, call->args); return CORD_all(type_to_cord(actual), "_to_", type_to_cord(t), "(", arg_code, ")"); } else if (t->tag == TextType) { - // Text constructor: - if (!call->args || call->args->next) - code_err(call->fn, "This constructor takes exactly 1 argument"); - type_t *actual = get_type(env, call->args->value); - return expr_as_text(env, compile(env, call->args->value), actual, "no"); + if (!call->args) code_err(ast, "This constructor needs a value"); + const char *lang = Match(t, TextType)->lang; + if (lang) { // Escape for DSL + type_t *first_type = get_type(env, call->args->value); + binding_t *esc = get_lang_escape_function(env, lang, first_type); + if (!esc) + code_err(ast, "I don't know how to convert %T to %T", first_type, t); + + arg_t *arg_spec = Match(esc->type, FunctionType)->args; + return CORD_all(esc->code, "(", compile_arguments(env, ast, arg_spec, call->args), ")"); + } else { + // Text constructor: + if (!call->args || call->args->next) + code_err(call->fn, "This constructor takes exactly 1 argument"); + type_t *actual = get_type(env, call->args->value); + if (type_eq(actual, t)) + return compile(env, call->args->value); + return expr_as_text(env, compile(env, call->args->value), actual, "no"); + } } else if (t->tag == CStringType) { // C String constructor: if (!call->args || call->args->next) diff --git a/environment.c b/environment.c index 4345770d..250e3153 100644 --- a/environment.c +++ b/environment.c @@ -531,6 +531,27 @@ binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name) return NULL; } +binding_t *get_lang_escape_function(env_t *env, const char *lang_name, type_t *type_to_escape) +{ + binding_t *typeinfo = get_binding(env, lang_name); + assert(typeinfo && typeinfo->type->tag == TypeInfoType); + env_t *lang_env = Match(typeinfo->type, TypeInfoType)->env; + for (int64_t i = 1; i <= Table$length(*lang_env->locals); i++) { + struct {const char *name; binding_t *b; } *entry = Table$entry(*lang_env->locals, i); + if (entry->b->type->tag != FunctionType) continue; + if (!(streq(entry->name, "escape") || strncmp(entry->name, "escape_", strlen("escape_")) == 0)) + continue; + auto fn = Match(entry->b->type, FunctionType); + if (!fn->args || fn->args->next) continue; + if (fn->ret->tag != TextType || !streq(Match(fn->ret, TextType)->lang, lang_name)) + continue; + if (!can_promote(type_to_escape, fn->args->type)) + continue; + return entry->b; + } + return NULL; +} + void set_binding(env_t *env, const char *name, binding_t *binding) { if (name && binding) diff --git a/environment.h b/environment.h index 0f5d34ef..b69e0a9e 100644 --- a/environment.h +++ b/environment.h @@ -71,6 +71,7 @@ env_t *namespace_env(env_t *env, const char *namespace_name); __attribute__((noreturn)) void compiler_err(file_t *f, const char *start, const char *end, const char *fmt, ...); binding_t *get_binding(env_t *env, const char *name); +binding_t *get_lang_escape_function(env_t *env, const char *lang_name, type_t *type_to_escape); void set_binding(env_t *env, const char *name, binding_t *binding); binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name); #define code_err(ast, ...) compiler_err((ast)->file, (ast)->start, (ast)->end, __VA_ARGS__) |
