aboutsummaryrefslogtreecommitdiff
path: root/compile.c
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-07-13 17:17:58 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-07-13 17:17:58 -0400
commit445f79cb70e72698283539b65e43fc71a47ad311 (patch)
tree9a1b0b027a1957fc0f6351e21ef53ce7ff53259a /compile.c
parent3db57b4d2e16ab25fbd07401ec7b3a738f8dae8a (diff)
Add iterator functions
Diffstat (limited to 'compile.c')
-rw-r--r--compile.c129
1 files changed, 90 insertions, 39 deletions
diff --git a/compile.c b/compile.c
index cac7a345..9db1642e 100644
--- a/compile.c
+++ b/compile.c
@@ -18,6 +18,7 @@ static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_dept
static env_t *with_enum_scope(env_t *env, type_t *t);
static CORD compile_math_method(env_t *env, ast_t *ast, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type);
static CORD compile_string(env_t *env, ast_t *ast, CORD color);
+static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args);
static bool promote(env_t *env, CORD *code, type_t *actual, type_t *needed)
{
@@ -647,8 +648,11 @@ CORD compile_statement(env_t *env, ast_t *ast)
case Skip: {
const char *target = Match(ast, Skip)->target;
for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
- if (!target || CORD_cmp(target, ctx->loop_name) == 0
- || CORD_cmp(target, ctx->key_name) == 0 || CORD_cmp(target, ctx->value_name) == 0) {
+ bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0;
+ for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next)
+ matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0);
+
+ if (matched) {
if (!ctx->skip_label) {
static int64_t skip_label_count = 1;
CORD_sprintf(&ctx->skip_label, "skip_%ld", skip_label_count);
@@ -670,8 +674,11 @@ CORD compile_statement(env_t *env, ast_t *ast)
case Stop: {
const char *target = Match(ast, Stop)->target;
for (loop_ctx_t *ctx = env->loop_ctx; ctx; ctx = ctx->next) {
- if (!target || CORD_cmp(target, ctx->loop_name) == 0
- || CORD_cmp(target, ctx->key_name) == 0 || CORD_cmp(target, ctx->value_name) == 0) {
+ bool matched = !target || CORD_cmp(target, ctx->loop_name) == 0;
+ for (ast_list_t *var = ctx->loop_vars; var && !matched; var = var->next)
+ matched = (CORD_cmp(target, Match(var->ast, Var)->name) == 0);
+
+ if (matched) {
if (!ctx->stop_label) {
static int64_t stop_label_count = 1;
CORD_sprintf(&ctx->stop_label, "stop_%ld", stop_label_count);
@@ -778,8 +785,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
env_t *body_scope = for_scope(env, ast);
loop_ctx_t loop_ctx = (loop_ctx_t){
.loop_name="for",
- .key_name=for_->index ? Match(for_->index, Var)->name : CORD_EMPTY,
- .value_name=for_->value ? Match(for_->value, Var)->name : CORD_EMPTY,
+ .loop_vars=for_->vars,
.deferred=body_scope->deferred,
.next=body_scope->loop_ctx,
};
@@ -792,8 +798,20 @@ CORD compile_statement(env_t *env, ast_t *ast)
switch (iter_t->tag) {
case ArrayType: {
type_t *item_t = Match(iter_t, ArrayType)->item_type;
- CORD index = for_->index ? compile(env, for_->index) : "i";
- CORD value = compile(env, for_->value);
+ CORD index = "i";
+ CORD value = "value";
+ if (for_->vars) {
+ if (for_->vars->next) {
+ if (for_->vars->next->next)
+ code_err(for_->vars->next->next->ast, "This is too many variables for this loop");
+
+ index = compile(env, for_->vars->ast);
+ value = compile(env, for_->vars->next->ast);
+ } else {
+ value = compile(env, for_->vars->ast);
+ }
+ }
+
CORD array = is_idempotent(for_->iter) ? compile(env, for_->iter) : "arr";
CORD loop = CORD_all("ARRAY_INCREF(", array, ");\n"
"for (int64_t ", index, " = 1; ", index, " <= ", array, ".length; ++", index, ") {\n",
@@ -814,18 +832,30 @@ CORD compile_statement(env_t *env, ast_t *ast)
CORD table = is_idempotent(for_->iter) ? compile(env, for_->iter) : "table";
CORD loop = CORD_all("ARRAY_INCREF(", table, ".entries);\n"
"for (int64_t i = 0; i < ",table,".entries.length; ++i) {\n");
- if (for_->index) {
- loop = CORD_all(loop, compile_type(key_t), " ", compile(env, for_->index), " = *(", compile_type(key_t), "*)(",
- table,".entries.data + i*", table, ".entries.stride);\n");
+ CORD key = CORD_EMPTY, value = CORD_EMPTY;
+ if (for_->vars) {
+ if (for_->vars->next) {
+ if (for_->vars->next->next)
+ code_err(for_->vars->next->next->ast, "This is too many variables for this loop");
+
+ key = compile(env, for_->vars->ast);
+ value = compile(env, for_->vars->next->ast);
+ } else {
+ key = compile(env, for_->vars->ast);
+ }
+ }
+
+ if (key) {
+ loop = CORD_all(loop, compile_type(key_t), " ", key, " = *(", compile_type(key_t), "*)(",
+ table,".entries.data + i*", table, ".entries.stride);\n");
+ }
+ if (value) {
size_t value_offset = type_size(key_t);
if (type_align(value_t) > 1 && value_offset % type_align(value_t))
value_offset += type_align(value_t) - (value_offset % type_align(value_t)); // padding
- loop = CORD_all(loop, compile_type(value_t), " ", compile(env, for_->value), " = *(", compile_type(value_t), "*)(",
+ loop = CORD_all(loop, compile_type(value_t), " ", value, " = *(", compile_type(value_t), "*)(",
table,".entries.data + i*", table, ".entries.stride + ", heap_strf("%zu", value_offset), ");\n");
- } else {
- loop = CORD_all(loop, compile_type(key_t), " ", compile(env, for_->value), " = *(", compile_type(key_t), "*)(",
- table,".entries.data + i*", table, ".entries.stride);\n");
}
loop = CORD_all(loop, body, "\n}");
if (for_->empty)
@@ -836,19 +866,9 @@ CORD compile_statement(env_t *env, ast_t *ast)
return loop;
}
case IntType: {
- CORD value = compile(env, for_->value);
+ CORD value = for_->vars ? compile(env, for_->vars->ast) : "i";
CORD n = compile(env, for_->iter);
- CORD index = for_->index ? compile(env, for_->index) : CORD_EMPTY;
- if (for_->empty && index) {
- return CORD_all(
- "{\n"
- "int64_t n = ", n, ";\n"
- "if (n > 0) {\n"
- "for (int64_t ", index, " = 1, ", value, "; (", value, "=", index,") <= n; ++", index, ") {\n"
- "\t", body, "\n}"
- "\n} else ", compile_statement(env, for_->empty),
- stop, "\n}");
- } else if (for_->empty) {
+ if (for_->empty) {
return CORD_all(
"{\n"
"int64_t n = ", n, ";\n"
@@ -859,13 +879,6 @@ CORD compile_statement(env_t *env, ast_t *ast)
"\n} else ", compile_statement(env, for_->empty),
stop,
"\n}");
- } else if (index) {
- return CORD_all(
- "for (int64_t ", value, ", ", index, " = 1, n = ", n, "; (", value, "=", index,") <= n; ++", value, ") {\n"
- "\t", body,
- "\n}",
- stop,
- "\n");
} else {
return CORD_all(
"for (int64_t ", value, " = 1, n = ", compile(env, for_->iter), "; ", value, " <= n; ++", value, ") {\n"
@@ -875,6 +888,44 @@ CORD compile_statement(env_t *env, ast_t *ast)
"\n");
}
}
+ case FunctionType: case ClosureType: {
+ CORD code = "{\n";
+ auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
+ arg_t *next_arg = fn->args;
+ for (ast_list_t *var = for_->vars; var; var = var->next) {
+ const char *name = Match(var->ast, Var)->name;
+ type_t *t = Match(get_arg_type(env, next_arg), PointerType)->pointed;
+ code = CORD_all(code, compile_declaration(t, CORD_cat("$", name)), ";\n");
+ }
+
+ code = CORD_all(code, compile_declaration(iter_t, "next"), " = ", compile(env, for_->iter), ";\n");
+
+ CORD next_fn;
+ if (iter_t->tag == ClosureType) {
+ type_t *fn_t = Match(iter_t, ClosureType)->fn;
+ arg_t *closure_fn_args = NULL;
+ for (arg_t *arg = Match(fn_t, FunctionType)->args; arg; arg = arg->next)
+ closure_fn_args = new(arg_t, .name=arg->name, .type=arg->type, .default_val=arg->default_val, .next=closure_fn_args);
+ closure_fn_args = new(arg_t, .name="userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args);
+ REVERSE_LIST(closure_fn_args);
+ CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret));
+ next_fn = CORD_all("((", fn_type_code, ")next.fn)");
+ } else {
+ next_fn = "next";
+ }
+
+ code = CORD_all(code, "for(; ", next_fn, "(");
+ for (ast_list_t *var = for_->vars; var; var = var->next) {
+ const char *name = Match(var->ast, Var)->name;
+ code = CORD_all(code, "&$", name);
+ if (var->next || iter_t->tag == ClosureType)
+ code = CORD_all(code, ", ");
+ }
+ if (iter_t->tag == ClosureType)
+ code = CORD_all(code, "next.userdata");
+ code = CORD_all(code, "); ) {\n\t", body, "}\n", stop, "\n}\n");
+ return code;
+ }
default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t);
}
}
@@ -914,7 +965,7 @@ CORD compile_statement(env_t *env, ast_t *ast)
assert(env->comprehension_var);
if (comp->expr->tag == Comprehension) { // Nested comprehension
ast_t *body = comp->filter ? WrapAST(ast, If, .condition=comp->filter, .body=comp->expr) : comp->expr;
- ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
return compile_statement(env, loop);
} else if (comp->expr->tag == TableEntry) { // Table comprehension
auto e = Match(comp->expr, TableEntry);
@@ -922,14 +973,14 @@ CORD compile_statement(env_t *env, ast_t *ast)
.args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value)));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
- ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
return compile_statement(env, loop);
} else { // Array comprehension
ast_t *body = WrapAST(comp->expr, MethodCall, .name="insert", .self=FakeAST(StackReference, FakeAST(Var, env->comprehension_var)),
.args=new(arg_ast_t, .value=comp->expr));
if (comp->filter)
body = WrapAST(body, If, .condition=comp->filter, .body=body);
- ast_t *loop = WrapAST(ast, For, .index=comp->key, .value=comp->value, .iter=comp->iter, .body=body);
+ ast_t *loop = WrapAST(ast, For, .vars=comp->vars, .iter=comp->iter, .body=body);
return compile_statement(env, loop);
}
}
@@ -1038,7 +1089,7 @@ env_t *with_enum_scope(env_t *env, type_t *t)
return env;
}
-static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args)
+CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args)
{
table_t used_args = {};
CORD code = CORD_EMPTY;
@@ -1976,7 +2027,7 @@ CORD compile(env_t *env, ast_t *ast)
ast_t *item = FakeAST(Var, "$iter_value");
set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="iter_value"));
ast_t *body = FakeAST(InlineCCode, CORD_all("reduction = $$i == 1 ? iter_value : ", compile(scope, reduction->combination), ";"));
- ast_t *loop = FakeAST(For, .index=i, .value=item, .iter=reduction->iter, .body=body, .empty=empty);
+ ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=i, .next=new(ast_list_t, .ast=item)), .iter=reduction->iter, .body=body, .empty=empty);
code = CORD_all(code, compile_statement(scope, loop), "\nreduction;})");
return code;
}