code / tomo

Lines41.3K C23.7K Markdown9.7K YAML5.0K Tomo2.3K
7 others 763
Python231 Shell230 make212 INI47 Text21 SVG16 Lua6
(798 lines)
// Basic operations on AST nodes: converting them to S-expression strings for
// debugging, walking them with visitor callbacks, and ordering top-level
// definitions topologically.
4 #include <stdarg.h>
6 #include "ast.h"
7 #include "stdlib/datatypes.h"
8 #include "stdlib/optionals.h"
9 #include "stdlib/tables.h"
10 #include "stdlib/text.h"
12 const int op_tightness[NUM_AST_TAGS] = {
13 [Power] = 9,
14 [Multiply] = 8,
15 [Divide] = 8,
16 [Mod] = 8,
17 [Mod1] = 8,
18 [Plus] = 7,
19 [Minus] = 7,
20 [Concat] = 6,
21 [LeftShift] = 5,
22 [RightShift] = 5,
23 [UnsignedLeftShift] = 5,
24 [UnsignedRightShift] = 5,
25 [Min] = 4,
26 [Max] = 4,
27 [Equals] = 3,
28 [NotEquals] = 3,
29 [LessThan] = 2,
30 [LessThanOrEquals] = 2,
31 [GreaterThan] = 2,
32 [GreaterThanOrEquals] = 2,
33 [Compare] = 2,
34 [And] = 1,
35 [Or] = 1,
36 [Xor] = 1,
37 };
39 const binop_info_t binop_info[NUM_AST_TAGS] = {
40 [Power] = {"power", "^"},
41 [PowerUpdate] = {"power", "^="},
42 [Multiply] = {"times", "*"},
43 [MultiplyUpdate] = {"times", "*="},
44 [Divide] = {"divided_by", "/"},
45 [DivideUpdate] = {"divided_by", "/="},
46 [Mod] = {"modulo", "mod"},
47 [ModUpdate] = {"modulo", "mod="},
48 [Mod1] = {"modulo1", "mod1"},
49 [Mod1Update] = {"modulo1", "mod1="},
50 [Plus] = {"plus", "+"},
51 [PlusUpdate] = {"plus", "+="},
52 [Minus] = {"minus", "-"},
53 [MinusUpdate] = {"minus", "-="},
54 [Concat] = {"concatenated_with", "++"},
55 [ConcatUpdate] = {"concatenated_with", "++="},
56 [LeftShift] = {"left_shifted", "<<"},
57 [LeftShiftUpdate] = {"left_shifted", "<<="},
58 [RightShift] = {"right_shifted", ">>"},
59 [RightShiftUpdate] = {"right_shifted", ">>="},
60 [UnsignedLeftShift] = {"unsigned_left_shifted", NULL},
61 [UnsignedLeftShiftUpdate] = {"unsigned_left_shifted", NULL},
62 [UnsignedRightShift] = {"unsigned_right_shifted", NULL},
63 [UnsignedRightShiftUpdate] = {"unsigned_right_shifted", NULL},
64 [And] = {"bit_and", "and"},
65 [AndUpdate] = {"bit_and", "and="},
66 [Or] = {"bit_or", "or"},
67 [OrUpdate] = {"bit_or", "or="},
68 [Xor] = {"bit_xor", "xor"},
69 [XorUpdate] = {"bit_xor", "xor="},
70 [Equals] = {NULL, "=="},
71 [NotEquals] = {NULL, "!="},
72 [LessThan] = {NULL, "<"},
73 [LessThanOrEquals] = {NULL, "<="},
74 [GreaterThan] = {NULL, ">"},
75 [GreaterThanOrEquals] = {NULL, ">="},
76 [Min] = {NULL, "_min_"},
77 [Max] = {NULL, "_max_"},
78 };
80 static Text_t ast_list_to_sexp(ast_list_t *asts);
81 static Text_t arg_list_to_sexp(arg_ast_t *args);
82 static Text_t arg_defs_to_sexp(arg_ast_t *args);
83 static Text_t when_clauses_to_sexp(when_clause_t *clauses);
84 static Text_t tags_to_sexp(tag_ast_t *tags);
85 static Text_t optional_sexp(const char *tag, ast_t *ast);
86 static Text_t optional_type_sexp(const char *tag, type_ast_t *ast);
88 static Text_t quoted_text(const char *text) {
89 return Text$quoted(Text$from_str(text), false, Text("\""));
92 Text_t ast_list_to_sexp(ast_list_t *asts) {
93 Text_t c = EMPTY_TEXT;
94 for (; asts; asts = asts->next) {
95 c = Texts(c, " ", ast_to_sexp(asts->ast));
97 return c;
100 Text_t arg_defs_to_sexp(arg_ast_t *args) {
101 Text_t c = Text("(args");
102 for (arg_ast_t *arg = args; arg; arg = arg->next) {
103 c = Texts(c, " (arg ", arg->name ? quoted_text(arg->name) : Text("nil"), " ", type_ast_to_sexp(arg->type), " ",
104 ast_to_sexp(arg->value), ")");
106 return Texts(c, ")");
109 Text_t arg_list_to_sexp(arg_ast_t *args) {
110 Text_t c = EMPTY_TEXT;
111 for (arg_ast_t *arg = args; arg; arg = arg->next) {
112 assert(arg->value && !arg->type);
113 if (arg->name) c = Texts(c, " :", arg->name);
114 c = Texts(c, " ", ast_to_sexp(arg->value));
116 return c;
119 Text_t when_clauses_to_sexp(when_clause_t *clauses) {
120 Text_t c = EMPTY_TEXT;
121 for (; clauses; clauses = clauses->next) {
122 c = Texts(c, " (case ", ast_to_sexp(clauses->pattern), " ", ast_to_sexp(clauses->body), ")");
124 return c;
127 Text_t tags_to_sexp(tag_ast_t *tags) {
128 Text_t c = EMPTY_TEXT;
129 for (; tags; tags = tags->next) {
130 c = Texts(c, "(tag \"", tags->name, "\" ", arg_defs_to_sexp(tags->fields), ")");
132 return c;
135 Text_t type_ast_to_sexp(type_ast_t *t) {
136 if (!t) return Text("nil");
138 switch (t->tag) {
139 #define T(type, ...) \
140 case type: { \
141 __typeof(t->__data.type) data = t->__data.type; \
142 (void)data; \
143 return Texts(__VA_ARGS__); \
145 T(UnknownTypeAST, "(UnknownType)");
146 T(VarTypeAST, "(VarType \"", data.name, "\")");
147 T(PointerTypeAST, "(PointerType \"", data.is_stack ? "stack" : "heap", "\" ", type_ast_to_sexp(data.pointed),
148 ")");
149 T(ListTypeAST, "(ListType ", type_ast_to_sexp(data.item), ")");
150 T(TableTypeAST, "(TableType ", type_ast_to_sexp(data.key), " ", type_ast_to_sexp(data.value), ")");
151 T(FunctionTypeAST, "(FunctionType ", arg_defs_to_sexp(data.args), " ", type_ast_to_sexp(data.ret), ")");
152 T(OptionalTypeAST, "(OptionalType ", type_ast_to_sexp(data.type), ")");
153 T(EnumTypeAST, "(EnumType ", data.name, " ", tags_to_sexp(data.tags), ")");
154 #undef T
155 default: return EMPTY_TEXT;
159 Text_t optional_sexp(const char *name, ast_t *ast) {
160 return ast ? Texts(" :", name, " ", ast_to_sexp(ast)) : EMPTY_TEXT;
163 Text_t optional_type_sexp(const char *name, type_ast_t *ast) {
164 return ast ? Texts(" :", name, " ", type_ast_to_sexp(ast)) : EMPTY_TEXT;
167 Text_t ast_to_sexp(ast_t *ast) {
168 if (!ast) return Text("nil");
170 switch (ast->tag) {
171 #define T(type, ...) \
172 case type: { \
173 __typeof(ast->__data.type) data = ast->__data.type; \
174 (void)data; \
175 return Texts(__VA_ARGS__); \
177 T(Unknown, "(Unknown)");
178 T(None, "(None)");
179 T(Bool, "(Bool ", data.b ? "yes" : "no", ")");
180 T(Var, "(Var ", quoted_text(data.name), ")");
181 T(Int, "(Int ", Text$quoted(ast_source(ast), false, Text("\"")), ")");
182 T(Num, "(Num ", Text$quoted(ast_source(ast), false, Text("\"")), ")");
183 T(TextLiteral, Text$quoted(data.text, false, Text("\"")));
184 T(TextJoin, "(Text", data.lang ? Texts(" :lang ", quoted_text(data.lang)) : EMPTY_TEXT,
185 ast_list_to_sexp(data.children), ")");
186 T(Path, "(Path ", quoted_text(data.path), ")");
187 T(Declare, "(Declare ", ast_to_sexp(data.var), " ", type_ast_to_sexp(data.type), " ", ast_to_sexp(data.value),
188 ")");
189 T(Assign, "(Assign (targets ", ast_list_to_sexp(data.targets), ") (values ", ast_list_to_sexp(data.values),
190 "))");
191 #define BINOP(name) T(name, "(" #name " ", ast_to_sexp(data.lhs), " ", ast_to_sexp(data.rhs), ")")
192 BINOP(Power);
193 BINOP(PowerUpdate);
194 BINOP(Multiply);
195 BINOP(MultiplyUpdate);
196 BINOP(Divide);
197 BINOP(DivideUpdate);
198 BINOP(Mod);
199 BINOP(ModUpdate);
200 BINOP(Mod1);
201 BINOP(Mod1Update);
202 BINOP(Plus);
203 BINOP(PlusUpdate);
204 BINOP(Minus);
205 BINOP(MinusUpdate);
206 BINOP(Concat);
207 BINOP(ConcatUpdate);
208 BINOP(LeftShift);
209 BINOP(LeftShiftUpdate);
210 BINOP(RightShift);
211 BINOP(RightShiftUpdate);
212 BINOP(UnsignedLeftShift);
213 BINOP(UnsignedLeftShiftUpdate);
214 BINOP(UnsignedRightShift);
215 BINOP(UnsignedRightShiftUpdate);
216 BINOP(And);
217 BINOP(AndUpdate);
218 BINOP(Or);
219 BINOP(OrUpdate);
220 BINOP(Xor);
221 BINOP(XorUpdate);
222 BINOP(Compare);
223 BINOP(Equals);
224 BINOP(NotEquals);
225 BINOP(LessThan);
226 BINOP(LessThanOrEquals);
227 BINOP(GreaterThan);
228 BINOP(GreaterThanOrEquals);
229 #undef BINOP
230 T(Negative, "(Negative ", ast_to_sexp(data.value), ")");
231 T(Not, "(Not ", ast_to_sexp(data.value), ")");
232 T(HeapAllocate, "(HeapAllocate ", ast_to_sexp(data.value), ")");
233 T(StackReference, "(StackReference ", ast_to_sexp(data.value), ")");
234 T(Min, "(Min ", ast_to_sexp(data.lhs), " ", ast_to_sexp(data.rhs), optional_sexp("key", data.key), ")");
235 T(Max, "(Max ", ast_to_sexp(data.lhs), " ", ast_to_sexp(data.rhs), optional_sexp("key", data.key), ")");
236 T(List, "(List", ast_list_to_sexp(data.items), ")");
237 T(Table, "(Table", optional_sexp("default", data.default_value), optional_sexp("fallback", data.fallback),
238 ast_list_to_sexp(data.entries), ")");
239 T(TableEntry, "(TableEntry ", ast_to_sexp(data.key), " ", ast_to_sexp(data.value), ")");
240 T(Comprehension, "(Comprehension ", ast_to_sexp(data.expr), " (vars", ast_list_to_sexp(data.vars), ") ",
241 ast_to_sexp(data.iter), " ", optional_sexp("filter", data.filter), ")");
242 T(FunctionDef, "(FunctionDef ", ast_to_sexp(data.name), " ", arg_defs_to_sexp(data.args),
243 optional_type_sexp("return", data.ret_type), " ", ast_to_sexp(data.body), ")");
244 T(ConvertDef, "(ConvertDef ", arg_defs_to_sexp(data.args), " ", type_ast_to_sexp(data.ret_type), " ",
245 ast_to_sexp(data.body), ")");
246 T(Lambda, "(Lambda ", arg_defs_to_sexp(data.args), optional_type_sexp("return", data.ret_type), " ",
247 ast_to_sexp(data.body), ")");
248 T(FunctionCall, "(FunctionCall ", ast_to_sexp(data.fn), arg_list_to_sexp(data.args), ")");
249 T(MethodCall, "(MethodCall ", ast_to_sexp(data.self), " ", quoted_text(data.name), arg_list_to_sexp(data.args),
250 ")")
251 T(Block, "(Block", ast_list_to_sexp(data.statements), ")");
252 T(For, "(For (vars", ast_list_to_sexp(data.vars), ") ", ast_to_sexp(data.iter), " ", ast_to_sexp(data.body),
253 " ", ast_to_sexp(data.empty), ")");
254 T(While, "(While ", ast_to_sexp(data.condition), " ", ast_to_sexp(data.body), ")");
255 T(Repeat, "(Repeat ", ast_to_sexp(data.body), ")");
256 T(If, "(If ", ast_to_sexp(data.condition), " ", ast_to_sexp(data.body), optional_sexp("else", data.else_body),
257 ")");
258 T(When, "(When ", ast_to_sexp(data.subject), when_clauses_to_sexp(data.clauses),
259 optional_sexp("else", data.else_body), ")");
260 T(Reduction, "(Reduction ", quoted_text(binop_info[data.op].operator), " ", ast_to_sexp(data.key), " ",
261 ast_to_sexp(data.iter), ")");
262 T(Skip, "(Skip ", quoted_text(data.target), ")");
263 T(Stop, "(Stop ", quoted_text(data.target), ")");
264 T(Pass, "(Pass)");
265 T(Defer, "(Defer ", ast_to_sexp(data.body), ")");
266 T(Return, "(Return ", ast_to_sexp(data.value), ")");
267 T(StructDef, "(StructDef \"", data.name, "\" ", arg_defs_to_sexp(data.fields), " ", ast_to_sexp(data.namespace),
268 ")");
269 T(EnumDef, "(EnumDef \"", data.name, "\" (tags ", tags_to_sexp(data.tags), ") ", ast_to_sexp(data.namespace),
270 ")");
271 T(LangDef, "(LangDef \"", data.name, "\" ", ast_to_sexp(data.namespace), ")");
272 T(Index, "(Index ", ast_to_sexp(data.indexed), " ", ast_to_sexp(data.index), ")");
273 T(FieldAccess, "(FieldAccess ", ast_to_sexp(data.fielded), " \"", data.field, "\")");
274 T(NonOptional, "(NonOptional ", ast_to_sexp(data.value), ")");
275 T(DebugLog, "(DebugLog ", ast_list_to_sexp(data.values), ")");
276 T(Assert, "(Assert ", ast_to_sexp(data.expr), " ", optional_sexp("message", data.message), ")");
277 T(Use, "(Use ", optional_sexp("var", data.var), " ", quoted_text(data.path), ")");
278 T(InlineCCode, "(InlineCCode ", ast_list_to_sexp(data.chunks), optional_type_sexp("type", data.type_ast), ")");
279 T(Metadata, "((Metadata ", Text$quoted(data.key, false, Text("\"")), " ",
280 Text$quoted(data.value, false, Text("\"")), ")");
281 default: errx(1, "S-expressions are not implemented for this AST");
282 #undef T
286 const char *ast_to_sexp_str(ast_t *ast) {
287 return Text$as_c_string(ast_to_sexp(ast));
290 OptionalText_t ast_source(ast_t *ast) {
291 if (ast == NULL || ast->start == NULL || ast->end == NULL) return NONE_TEXT;
292 return Text$from_strn(ast->start, (size_t)(ast->end - ast->start));
295 PUREFUNC bool is_idempotent(ast_t *ast) {
296 switch (ast->tag) {
297 case Int:
298 case Bool:
299 case Num:
300 case Var:
301 case None:
302 case TextLiteral: return true;
303 case Index: {
304 DeclareMatch(index, ast, Index);
305 return is_idempotent(index->indexed) && index->index != NULL && is_idempotent(index->index);
307 case FieldAccess: {
308 DeclareMatch(access, ast, FieldAccess);
309 return is_idempotent(access->fielded);
311 default: return false;
315 void _visit_topologically(ast_t *ast, Table_t definitions, Table_t *visited, Closure_t fn) {
316 void (*visit)(void *, ast_t *) = (void *)fn.fn;
317 if (ast->tag == StructDef) {
318 DeclareMatch(def, ast, StructDef);
319 if (Table$str_get(*visited, def->name)) return;
321 Table$str_set(visited, def->name, (void *)_visit_topologically);
322 for (arg_ast_t *field = def->fields; field; field = field->next) {
323 if (field->type && field->type->tag == VarTypeAST) {
324 const char *field_type_name = Match(field->type, VarTypeAST)->name;
325 ast_t *dependency = Table$str_get(definitions, field_type_name);
326 if (dependency) {
327 _visit_topologically(dependency, definitions, visited, fn);
331 visit(fn.userdata, ast);
332 } else if (ast->tag == EnumDef) {
333 DeclareMatch(def, ast, EnumDef);
334 if (Table$str_get(*visited, def->name)) return;
336 Table$str_set(visited, def->name, (void *)_visit_topologically);
337 for (tag_ast_t *tag = def->tags; tag; tag = tag->next) {
338 for (arg_ast_t *field = tag->fields; field; field = field->next) {
339 if (field->type && field->type->tag == VarTypeAST) {
340 const char *field_type_name = Match(field->type, VarTypeAST)->name;
341 ast_t *dependency = Table$str_get(definitions, field_type_name);
342 if (dependency) {
343 _visit_topologically(dependency, definitions, visited, fn);
348 visit(fn.userdata, ast);
349 } else if (ast->tag == LangDef) {
350 DeclareMatch(def, ast, LangDef);
351 if (Table$str_get(*visited, def->name)) return;
352 visit(fn.userdata, ast);
353 } else {
354 visit(fn.userdata, ast);
358 void visit_topologically(ast_list_t *asts, Closure_t fn) {
359 // Visit each top-level statement in topological order:
360 // - 'use' statements first
361 // - then typedefs
362 // - visiting typedefs' dependencies first
363 // - then function/variable declarations
365 Table_t definitions = EMPTY_TABLE;
366 for (ast_list_t *stmt = asts; stmt; stmt = stmt->next) {
367 if (stmt->ast->tag == StructDef) {
368 DeclareMatch(def, stmt->ast, StructDef);
369 Table$str_set(&definitions, def->name, stmt->ast);
370 } else if (stmt->ast->tag == EnumDef) {
371 DeclareMatch(def, stmt->ast, EnumDef);
372 Table$str_set(&definitions, def->name, stmt->ast);
373 } else if (stmt->ast->tag == LangDef) {
374 DeclareMatch(def, stmt->ast, LangDef);
375 Table$str_set(&definitions, def->name, stmt->ast);
379 void (*visit)(void *, ast_t *) = (void *)fn.fn;
380 Table_t visited = EMPTY_TABLE;
381 // First: 'use' statements in order:
382 for (ast_list_t *stmt = asts; stmt; stmt = stmt->next) {
383 if (stmt->ast->tag == Use) visit(fn.userdata, stmt->ast);
385 // Then typedefs in topological order:
386 for (ast_list_t *stmt = asts; stmt; stmt = stmt->next) {
387 if (stmt->ast->tag == StructDef || stmt->ast->tag == EnumDef || stmt->ast->tag == LangDef)
388 _visit_topologically(stmt->ast, definitions, &visited, fn);
390 // Then everything else in order:
391 for (ast_list_t *stmt = asts; stmt; stmt = stmt->next) {
392 if (!(stmt->ast->tag == StructDef || stmt->ast->tag == EnumDef || stmt->ast->tag == LangDef
393 || stmt->ast->tag == Use)) {
394 visit(fn.userdata, stmt->ast);
399 CONSTFUNC bool is_binary_operation(ast_t *ast) {
400 switch (ast->tag) {
401 case Min:
402 case Max:
403 case BINOP_CASES: return true;
404 default: return false;
408 CONSTFUNC bool is_update_assignment(ast_t *ast) {
409 switch (ast->tag) {
410 case PowerUpdate:
411 case MultiplyUpdate:
412 case DivideUpdate:
413 case ModUpdate:
414 case Mod1Update:
415 case PlusUpdate:
416 case MinusUpdate:
417 case ConcatUpdate:
418 case LeftShiftUpdate:
419 case UnsignedLeftShiftUpdate:
420 case RightShiftUpdate:
421 case UnsignedRightShiftUpdate:
422 case AndUpdate:
423 case OrUpdate:
424 case XorUpdate: return true;
425 default: return false;
429 CONSTFUNC ast_e binop_tag(ast_e tag) {
430 switch (tag) {
431 case PowerUpdate: return Power;
432 case MultiplyUpdate: return Multiply;
433 case DivideUpdate: return Divide;
434 case ModUpdate: return Mod;
435 case Mod1Update: return Mod1;
436 case PlusUpdate: return Plus;
437 case MinusUpdate: return Minus;
438 case ConcatUpdate: return Concat;
439 case LeftShiftUpdate: return LeftShift;
440 case UnsignedLeftShiftUpdate: return UnsignedLeftShift;
441 case RightShiftUpdate: return RightShift;
442 case UnsignedRightShiftUpdate: return UnsignedRightShift;
443 case AndUpdate: return And;
444 case OrUpdate: return Or;
445 case XorUpdate: return Xor;
446 default: return Unknown;
450 static void ast_visit_list(ast_list_t *ast_list, visit_behavior_t (*visitor)(ast_t *, void *), void *userdata) {
451 for (ast_list_t *ast = ast_list; ast; ast = ast->next)
452 ast_visit(ast->ast, visitor, userdata);
455 static void ast_visit_args(arg_ast_t *args, visit_behavior_t (*visitor)(ast_t *, void *), void *userdata) {
456 for (arg_ast_t *arg = args; arg; arg = arg->next)
457 ast_visit(arg->value, visitor, userdata);
460 void ast_visit(ast_t *ast, visit_behavior_t (*visitor)(ast_t *, void *), void *userdata) {
461 if (!ast) return;
462 if (visitor(ast, userdata) == VISIT_STOP) return;
464 switch (ast->tag) {
465 case Unknown:
466 case None:
467 case Bool:
468 case Var:
469 case Int:
470 case Num:
471 case Path:
472 case TextLiteral:
473 case Metadata: return;
474 case TextJoin: ast_visit_list(Match(ast, TextJoin)->children, visitor, userdata); return;
475 case Declare: {
476 DeclareMatch(decl, ast, Declare);
477 ast_visit(decl->var, visitor, userdata);
478 ast_visit(decl->value, visitor, userdata);
479 return;
481 case Assign: {
482 DeclareMatch(assign, ast, Assign);
483 ast_visit_list(assign->targets, visitor, userdata);
484 ast_visit_list(assign->values, visitor, userdata);
485 return;
487 case BINOP_CASES: {
488 binary_operands_t op = BINARY_OPERANDS(ast);
489 ast_visit(op.lhs, visitor, userdata);
490 ast_visit(op.rhs, visitor, userdata);
491 return;
493 case Negative: {
494 ast_visit(Match(ast, Negative)->value, visitor, userdata);
495 return;
497 case Not: {
498 ast_visit(Match(ast, Not)->value, visitor, userdata);
499 return;
501 case HeapAllocate: {
502 ast_visit(Match(ast, HeapAllocate)->value, visitor, userdata);
503 return;
505 case StackReference: {
506 ast_visit(Match(ast, StackReference)->value, visitor, userdata);
507 return;
509 case Min: {
510 DeclareMatch(min, ast, Min);
511 ast_visit(min->lhs, visitor, userdata);
512 ast_visit(min->key, visitor, userdata);
513 ast_visit(min->rhs, visitor, userdata);
514 return;
516 case Max: {
517 DeclareMatch(max, ast, Max);
518 ast_visit(max->lhs, visitor, userdata);
519 ast_visit(max->key, visitor, userdata);
520 ast_visit(max->rhs, visitor, userdata);
521 return;
523 case List: {
524 ast_visit_list(Match(ast, List)->items, visitor, userdata);
525 return;
527 case Table: {
528 DeclareMatch(table, ast, Table);
529 ast_visit_list(table->entries, visitor, userdata);
530 ast_visit(table->default_value, visitor, userdata);
531 ast_visit(table->fallback, visitor, userdata);
532 return;
534 case TableEntry: {
535 DeclareMatch(entry, ast, TableEntry);
536 ast_visit(entry->key, visitor, userdata);
537 ast_visit(entry->value, visitor, userdata);
538 return;
540 case Comprehension: {
541 DeclareMatch(comp, ast, Comprehension);
542 ast_visit(comp->expr, visitor, userdata);
543 ast_visit_list(comp->vars, visitor, userdata);
544 ast_visit(comp->iter, visitor, userdata);
545 ast_visit(comp->filter, visitor, userdata);
546 return;
548 case FunctionDef: {
549 DeclareMatch(def, ast, FunctionDef);
550 ast_visit(def->name, visitor, userdata);
551 ast_visit_args(def->args, visitor, userdata);
552 ast_visit(def->body, visitor, userdata);
553 return;
555 case ConvertDef: {
556 DeclareMatch(def, ast, ConvertDef);
557 ast_visit_args(def->args, visitor, userdata);
558 ast_visit(def->body, visitor, userdata);
559 return;
561 case Lambda: {
562 DeclareMatch(lambda, ast, Lambda);
563 ast_visit_args(lambda->args, visitor, userdata);
564 ast_visit(lambda->body, visitor, userdata);
565 return;
567 case FunctionCall: {
568 DeclareMatch(call, ast, FunctionCall);
569 ast_visit(call->fn, visitor, userdata);
570 ast_visit_args(call->args, visitor, userdata);
571 return;
573 case MethodCall: {
574 DeclareMatch(call, ast, MethodCall);
575 ast_visit(call->self, visitor, userdata);
576 ast_visit_args(call->args, visitor, userdata);
577 return;
579 case Block: {
580 ast_visit_list(Match(ast, Block)->statements, visitor, userdata);
581 return;
583 case For: {
584 DeclareMatch(for_, ast, For);
585 ast_visit_list(for_->vars, visitor, userdata);
586 ast_visit(for_->iter, visitor, userdata);
587 ast_visit(for_->body, visitor, userdata);
588 ast_visit(for_->empty, visitor, userdata);
589 return;
591 case While: {
592 DeclareMatch(while_, ast, While);
593 ast_visit(while_->condition, visitor, userdata);
594 ast_visit(while_->body, visitor, userdata);
595 return;
597 case Repeat: {
598 ast_visit(Match(ast, Repeat)->body, visitor, userdata);
599 return;
601 case If: {
602 DeclareMatch(if_, ast, If);
603 ast_visit(if_->condition, visitor, userdata);
604 ast_visit(if_->body, visitor, userdata);
605 ast_visit(if_->else_body, visitor, userdata);
606 return;
608 case When: {
609 DeclareMatch(when, ast, When);
610 ast_visit(when->subject, visitor, userdata);
611 for (when_clause_t *clause = when->clauses; clause; clause = clause->next) {
612 ast_visit(clause->pattern, visitor, userdata);
613 ast_visit(clause->body, visitor, userdata);
615 ast_visit(when->else_body, visitor, userdata);
616 return;
618 case Reduction: {
619 DeclareMatch(reduction, ast, Reduction);
620 ast_visit(reduction->key, visitor, userdata);
621 ast_visit(reduction->iter, visitor, userdata);
622 return;
624 case Skip:
625 case Stop:
626 case Pass: return;
627 case Defer: {
628 ast_visit(Match(ast, Defer)->body, visitor, userdata);
629 return;
631 case Return: {
632 ast_visit(Match(ast, Return)->value, visitor, userdata);
633 return;
635 case StructDef: {
636 DeclareMatch(def, ast, StructDef);
637 ast_visit_args(def->fields, visitor, userdata);
638 ast_visit(def->namespace, visitor, userdata);
639 return;
641 case EnumDef: {
642 DeclareMatch(def, ast, EnumDef);
643 for (tag_ast_t *tag = def->tags; tag; tag = tag->next)
644 ast_visit_args(tag->fields, visitor, userdata);
645 ast_visit(def->namespace, visitor, userdata);
646 return;
648 case LangDef: {
649 ast_visit(Match(ast, LangDef)->namespace, visitor, userdata);
650 return;
652 case Index: {
653 DeclareMatch(index, ast, Index);
654 ast_visit(index->indexed, visitor, userdata);
655 ast_visit(index->index, visitor, userdata);
656 return;
658 case FieldAccess: {
659 ast_visit(Match(ast, FieldAccess)->fielded, visitor, userdata);
660 return;
662 case NonOptional: {
663 ast_visit(Match(ast, NonOptional)->value, visitor, userdata);
664 return;
666 case DebugLog: {
667 DeclareMatch(show, ast, DebugLog);
668 ast_visit_list(show->values, visitor, userdata);
669 return;
671 case Assert: {
672 DeclareMatch(assert, ast, Assert);
673 ast_visit(assert->expr, visitor, userdata);
674 ast_visit(assert->message, visitor, userdata);
675 return;
677 case Use: {
678 ast_visit(Match(ast, Use)->var, visitor, userdata);
679 return;
681 case InlineCCode: {
682 ast_visit_list(Match(ast, InlineCCode)->chunks, visitor, userdata);
683 return;
685 default: errx(1, "Visiting is not supported for this AST: %s", Text$as_c_string(ast_to_sexp(ast)));
686 #undef T
690 static void _recursive_type_ast_visit(type_ast_t *ast, void *userdata) {
691 if (ast == NULL) return;
693 visit_behavior_t (*visit)(type_ast_t *, void *) = ((Closure_t *)userdata)->fn;
694 void *visitor_userdata = ((Closure_t *)userdata)->userdata;
695 if (visit(ast, visitor_userdata) == VISIT_STOP) return;
697 switch (ast->tag) {
698 case UnknownTypeAST:
699 case VarTypeAST: break;
700 case PointerTypeAST: {
701 _recursive_type_ast_visit(Match(ast, PointerTypeAST)->pointed, userdata);
702 break;
704 case ListTypeAST: {
705 _recursive_type_ast_visit(Match(ast, ListTypeAST)->item, userdata);
706 break;
708 case TableTypeAST: {
709 DeclareMatch(table, ast, TableTypeAST);
710 _recursive_type_ast_visit(table->key, userdata);
711 _recursive_type_ast_visit(table->value, userdata);
712 break;
714 case FunctionTypeAST: {
715 DeclareMatch(fn, ast, FunctionTypeAST);
716 for (arg_ast_t *arg = fn->args; arg; arg = arg->next)
717 _recursive_type_ast_visit(arg->type, userdata);
718 _recursive_type_ast_visit(fn->ret, userdata);
719 break;
721 case OptionalTypeAST: {
722 _recursive_type_ast_visit(Match(ast, OptionalTypeAST)->type, userdata);
723 break;
725 case EnumTypeAST: {
726 for (tag_ast_t *tag = Match(ast, EnumTypeAST)->tags; tag; tag = tag->next) {
727 for (arg_ast_t *field = tag->fields; field; field = field->next) {
728 _recursive_type_ast_visit(field->type, userdata);
731 break;
733 default: errx(1, "Invalid type AST");
737 static visit_behavior_t _type_ast_visit(ast_t *ast, void *userdata) {
738 switch (ast->tag) {
739 case Declare: {
740 _recursive_type_ast_visit(Match(ast, Declare)->type, userdata);
741 break;
743 case FunctionDef: {
744 for (arg_ast_t *arg = Match(ast, FunctionDef)->args; arg; arg = arg->next)
745 _recursive_type_ast_visit(arg->type, userdata);
746 _recursive_type_ast_visit(Match(ast, FunctionDef)->ret_type, userdata);
747 break;
749 case Lambda: {
750 for (arg_ast_t *arg = Match(ast, Lambda)->args; arg; arg = arg->next)
751 _recursive_type_ast_visit(arg->type, userdata);
752 _recursive_type_ast_visit(Match(ast, Lambda)->ret_type, userdata);
753 break;
755 case ConvertDef: {
756 for (arg_ast_t *arg = Match(ast, ConvertDef)->args; arg; arg = arg->next)
757 _recursive_type_ast_visit(arg->type, userdata);
758 _recursive_type_ast_visit(Match(ast, ConvertDef)->ret_type, userdata);
759 break;
761 case StructDef: {
762 for (arg_ast_t *field = Match(ast, StructDef)->fields; field; field = field->next)
763 _recursive_type_ast_visit(field->type, userdata);
764 break;
766 case EnumDef: {
767 for (tag_ast_t *tag = Match(ast, EnumDef)->tags; tag; tag = tag->next) {
768 for (arg_ast_t *field = tag->fields; field; field = field->next) {
769 _recursive_type_ast_visit(field->type, userdata);
772 break;
774 case InlineCCode: {
775 _recursive_type_ast_visit(Match(ast, InlineCCode)->type_ast, userdata);
776 break;
778 default: break;
780 return VISIT_PROCEED;
783 void type_ast_visit(ast_t *ast, visit_behavior_t (*visitor)(type_ast_t *, void *), void *userdata) {
784 Closure_t fn = {.fn = visitor, .userdata = userdata};
785 ast_visit(ast, _type_ast_visit, &fn);
788 OptionalText_t ast_metadata(ast_t *ast, const char *key) {
789 if (ast->tag != Block) return NONE_TEXT;
790 Text_t key_text = Text$from_str(key);
791 for (ast_list_t *stmt = Match(ast, Block)->statements; stmt; stmt = stmt->next) {
792 if (stmt->ast->tag == Metadata) {
793 DeclareMatch(m, stmt->ast, Metadata);
794 if (Text$equal_values(m->key, key_text)) return m->value;
797 return NONE_TEXT;