diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-07-04 18:00:01 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-07-04 18:00:01 -0400 |
| commit | 6a105fbd801f10bd6c8cee32fd6d45a279f33e1b (patch) | |
| tree | dc03ec9eec0ac65b40d6cb6053d55475dc132fb2 | |
| parent | 78960b1461a8fb184de4ffddf2d2ec4df729fb05 (diff) | |
Add 'defer'
| -rw-r--r-- | ast.c | 1 | ||||
| -rw-r--r-- | ast.h | 4 | ||||
| -rw-r--r-- | compile.c | 52 | ||||
| -rw-r--r-- | environment.h | 8 | ||||
| -rw-r--r-- | parse.c | 10 | ||||
| -rw-r--r-- | test/defer.tm | 41 | ||||
| -rw-r--r-- | typecheck.c | 2 |
7 files changed, 112 insertions, 6 deletions
@@ -139,6 +139,7 @@ CORD ast_to_xml(ast_t *ast) T(Stop, "<Stop>%r</Stop>", data.target) T(PrintStatement, "<PrintStatement>%r</PrintStatement>", ast_list_to_xml(data.to_print)) T(Pass, "<Pass/>") + T(Defer, "<Defer>%r<Defer/>", ast_to_xml(data.body)) T(Return, "<Return>%r</Return>", ast_to_xml(data.value)) T(Extern, "<Extern name=\"%s\">%r</Extern>", data.name, type_ast_to_xml(data.type)) T(StructDef, "<StructDef name=\"%s\">%r<namespace>%r</namespace></StructDef>", data.name, arg_list_to_xml(data.fields), ast_to_xml(data.namespace)) @@ -111,6 +111,7 @@ typedef enum { For, While, If, When, Reduction, Skip, Stop, Pass, + Defer, Return, Extern, StructDef, EnumDef, LangDef, @@ -234,6 +235,9 @@ struct ast_s { } Skip, Stop; struct {} Pass; struct { + ast_t *body; + } Defer; + struct { ast_t *value; } Return; struct { @@ -649,7 +649,10 @@ CORD compile_statement(env_t *env, ast_t *ast) CORD_sprintf(&ctx->skip_label, "skip_%ld", skip_label_count); ++skip_label_count; } - return CORD_all("goto ", ctx->skip_label, ";"); + CORD code = CORD_EMPTY; + for (deferral_t *deferred = env->deferred; deferred && deferred != ctx->deferred; deferred = deferred->next) + code = CORD_all(code, compile_statement(deferred->defer_env, deferred->block)); + return CORD_all(code, "goto ", ctx->skip_label, ";"); } } if (env->loop_ctx) @@ -669,7 +672,10 @@ CORD compile_statement(env_t *env, ast_t *ast) CORD_sprintf(&ctx->stop_label, "stop_%ld", stop_label_count); ++stop_label_count; } - return CORD_all("goto ", ctx->stop_label, ";"); + CORD code = CORD_EMPTY; + for (deferral_t *deferred = env->deferred; deferred && deferred != ctx->deferred; deferred = deferred->next) + code = CORD_all(code, compile_statement(deferred->defer_env, deferred->block)); + return CORD_all(code, "goto ", ctx->stop_label, ";"); } } if (env->loop_ctx) @@ -680,6 +686,25 @@ CORD compile_statement(env_t *env, ast_t *ast) code_err(ast, "I couldn't figure out how to make this stop work!"); } case Pass: return ";"; + case Defer: { + ast_t *body = Match(ast, Defer)->body; + table_t *closed_vars = get_closed_vars(env, FakeAST(Lambda, .args=NULL, .body=body)); + + static int defer_id = 0; + env_t *defer_env = fresh_scope(env); + CORD code = CORD_EMPTY; + for (int64_t i = 1; i <= Table$length(*closed_vars); i++) { + struct { const char *name; binding_t *b; } *entry = Table$entry(*closed_vars, i); + if (entry->b->type->tag == ModuleType) + continue; + CORD defer_name = CORD_asprintf("defer$%d$%s", ++defer_id, entry->name); + code = CORD_all( + code, compile_declaration(entry->b->type, defer_name), " = ", entry->b->code, ";\n"); + set_binding(defer_env, entry->name, new(binding_t, .type=entry->b->type, .code=defer_name)); + } + env->deferred = new(deferral_t, .defer_env=defer_env, .block=body, .next=env->deferred); + return code; + } case PrintStatement: { ast_list_t *to_print = Match(ast, PrintStatement)->to_print; if (!to_print) @@ -700,6 +725,12 @@ CORD compile_statement(env_t *env, ast_t *ast) if (!env->fn_ctx) code_err(ast, "This return statement is not inside any function"); auto ret = Match(ast, Return)->value; assert(env->fn_ctx->return_type); + + CORD code = CORD_EMPTY; + for (deferral_t *deferred = env->deferred; deferred; deferred = deferred->next) { + code = CORD_all(code, compile_statement(deferred->defer_env, deferred->block)); + } + if (ret) { env = with_enum_scope(env, env->fn_ctx->return_type); type_t *ret_t = get_type(env, ret); @@ -707,11 +738,11 @@ CORD compile_statement(env_t *env, ast_t *ast) if (!promote(env, &value, ret_t, env->fn_ctx->return_type)) code_err(ast, "This function expects a return value of type %T, but this return has type %T", env->fn_ctx->return_type, ret_t); - return CORD_all("return ", value, ";"); + return CORD_all(code, "return ", value, ";"); } else { if (env->fn_ctx->return_type->tag != VoidType) code_err(ast, "This function expects a return value of type %T", env->fn_ctx->return_type->tag); - return "return;"; + return CORD_all(code, "return;"); } } case While: { @@ -719,6 +750,7 @@ CORD compile_statement(env_t *env, ast_t *ast) env_t *scope = fresh_scope(env); loop_ctx_t loop_ctx = (loop_ctx_t){ .loop_name="while", + .deferred=scope->deferred, .next=env->loop_ctx, }; scope->loop_ctx = &loop_ctx; @@ -740,6 +772,7 @@ CORD compile_statement(env_t *env, ast_t *ast) .loop_name="for", .key_name=for_->index ? Match(for_->index, Var)->name : CORD_EMPTY, .value_name=for_->value ? Match(for_->value, Var)->name : CORD_EMPTY, + .deferred=body_scope->deferred, .next=body_scope->loop_ctx, }; body_scope->loop_ctx = &loop_ctx; @@ -855,6 +888,7 @@ CORD compile_statement(env_t *env, ast_t *ast) case Block: { ast_list_t *stmts = Match(ast, Block)->statements; CORD code = "{\n"; + deferral_t *prev_deferred = env->deferred; env = fresh_scope(env); for (ast_list_t *stmt = stmts; stmt; stmt = stmt->next) prebind_statement(env, stmt->ast); @@ -862,6 +896,9 @@ CORD compile_statement(env_t *env, ast_t *ast) bind_statement(env, stmt->ast); code = CORD_all(code, compile_statement(env, stmt->ast), "\n"); } + for (deferral_t *deferred = env->deferred; deferred && deferred != prev_deferred; deferred = deferred->next) { + code = CORD_all(code, compile_statement(deferred->defer_env, deferred->block)); + } return CORD_cat(code, "}"); } case Comprehension: { @@ -1439,6 +1476,7 @@ CORD compile(env_t *env, ast_t *ast) return compile(env, stmts->ast); CORD code = "({\n"; + deferral_t *prev_deferred = env->deferred; env = fresh_scope(env); for (ast_list_t *stmt = stmts; stmt; stmt = stmt->next) prebind_statement(env, stmt->ast); @@ -1447,9 +1485,14 @@ CORD compile(env_t *env, ast_t *ast) if (stmt->next) { code = CORD_all(code, compile_statement(env, stmt->ast), "\n"); } else { + // TODO: put defer after evaluating block expression + for (deferral_t *deferred = env->deferred; deferred && deferred != prev_deferred; deferred = deferred->next) { + code = CORD_all(code, compile_statement(deferred->defer_env, deferred->block)); + } code = CORD_all(code, compile(env, stmt->ast), ";\n"); } } + return CORD_cat(code, "})"); } case Min: case Max: { @@ -2036,6 +2079,7 @@ CORD compile(env_t *env, ast_t *ast) } case Use: code_err(ast, "Compiling 'use' as expression!"); case Import: code_err(ast, "Compiling 'import' as expression!"); + case Defer: code_err(ast, "Compiling 'defer' as expression!"); case LinkerDirective: code_err(ast, "Linker directives are not supported yet"); case Extern: code_err(ast, "Externs are not supported as expressions"); case TableEntry: code_err(ast, "Table entries should not be compiled directly"); diff --git a/environment.h b/environment.h index 7ecedd6f..a06f3d42 100644 --- a/environment.h +++ b/environment.h @@ -20,9 +20,16 @@ typedef struct { table_t *closed_vars; } fn_ctx_t; +typedef struct deferral_s { + struct deferral_s *next; + struct env_s *defer_env; + ast_t *block; +} deferral_t; + typedef struct loop_ctx_s { struct loop_ctx_s *next; const char *loop_name, *key_name, *value_name; + deferral_t *deferred; CORD skip_label, stop_label; } loop_ctx_t; @@ -37,6 +44,7 @@ typedef struct env_s { compilation_unit_t *code; fn_ctx_t *fn_ctx; loop_ctx_t *loop_ctx; + deferral_t *deferred; CORD *libname; // Pointer to currently compiling library name (if any) namespace_t *namespace; const char *comprehension_var; @@ -49,7 +49,7 @@ int op_tightness[] = { static const char *keywords[] = { "yes", "xor", "while", "when", "use", "then", "struct", "stop", "skip", "return", "or", "not", "no", "mod1", "mod", "pass", "lang", "import", "inline", "in", "if", - "func", "for", "extern", "enum", "else", "do", "and", "_min_", "_max_", + "func", "for", "extern", "enum", "else", "do", "defer", "and", "_min_", "_max_", NULL, }; @@ -1186,6 +1186,13 @@ PARSER(parse_pass) { return match_word(&pos, "pass") ? NewAST(ctx->file, start, pos, Pass) : NULL; } +PARSER(parse_defer) { + const char *start = pos; + if (!match_word(&pos, "defer")) return NULL; + ast_t *body = expect(ctx, start, &pos, parse_block, "I expected a block to be deferred here"); + return NewAST(ctx->file, start, pos, Defer, .body=body); +} + PARSER(parse_skip) { const char *start = pos; if (!match_word(&pos, "skip")) return NULL; @@ -1266,6 +1273,7 @@ PARSER(parse_term_no_suffix) { || (term=parse_array(ctx, pos)) || (term=parse_reduction(ctx, pos)) || (term=parse_pass(ctx, pos)) + || (term=parse_defer(ctx, pos)) || (term=parse_skip(ctx, pos)) || (term=parse_stop(ctx, pos)) || (term=parse_return(ctx, pos)) diff --git a/test/defer.tm b/test/defer.tm new file mode 100644 index 00000000..fcc14f16 --- /dev/null +++ b/test/defer.tm @@ -0,0 +1,41 @@ +func main(): + x := 123 + nums := @[:Int] + do: + defer: + nums:insert(x) + x = 999 + + >> nums + = @[123] + >> x + = 999 + + defer: + say("All done!") + + for word in ["first", "second", "third"]: + defer: + say("Got {word} deferred") + + if word == "second": + say("<skipped>") + skip + else if word == "third": + say("<stopped>") + stop + + for i in 3: + defer: + say("Inner loop deferred {i}") + + if i == 2: + say("<skipped inner>") + skip + else if i == 3: + say("<stopped inner>") + stop + + say("Made it through inner loop") + + say("Made it through the loop") diff --git a/typecheck.c b/typecheck.c index 8a601c69..86d66252 100644 --- a/typecheck.c +++ b/typecheck.c @@ -738,7 +738,7 @@ type_t *get_type(env_t *env, ast_t *ast) case Return: case Stop: case Skip: case PrintStatement: { return Type(AbortType); } - case Pass: return Type(VoidType); + case Pass: case Defer: return Type(VoidType); case Length: return Type(IntType, .bits=64); case Negative: { ast_t *value = Match(ast, Negative)->value; |
