From fb95bbb1d49dab882e5b4a962b7dd9b2438fdacb Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sat, 20 Jul 2024 16:45:13 -0400 Subject: [PATCH] Deprecate array:pairs() and switch iterator functions to use enums --- builtins/array.c | 43 ------------------------------------------- builtins/array.h | 1 - compile.c | 43 ++++++++++++++++++------------------------- environment.c | 41 ++++++++++++++++++++++++++--------------- test/arrays.tm | 13 ------------- typecheck.c | 44 +++++++++++++++++++++++++------------------- 6 files changed, 69 insertions(+), 116 deletions(-) diff --git a/builtins/array.c b/builtins/array.c index 7086ff7..2d45591 100644 --- a/builtins/array.c +++ b/builtins/array.c @@ -320,49 +320,6 @@ public array_t Array$reversed(array_t array) return reversed; } -typedef struct { - array_t arr; - int64_t i, j, item_size; - bool self_pairs:1, ordered:1; -} pair_info_t; - -static bool next_pair(void *x, void *y, pair_info_t *info) -{ - if (info->i > info->arr.length || info->j > info->arr.length) - return false; - - memcpy(x, info->arr.data + info->arr.stride * (info->i-1), info->item_size); - memcpy(y, info->arr.data + info->arr.stride * (info->j-1), info->item_size); - info->j += 1; - if (!info->self_pairs && info->j == info->i) - info->j += 1; - - if (info->j > info->arr.length) { - info->i += 1; - if (info->ordered) - info->j = 1; - else if (info->self_pairs) - info->j = info->i; - else - info->j = info->i + 1; - } - return true; -} - -public closure_t Array$pairs(array_t arr, bool self_pairs, bool ordered, const TypeInfo *type) -{ - return (closure_t){ - .fn=next_pair, - .userdata=new(pair_info_t, - .arr=arr, - .i=1, - .j=self_pairs ? 1 : 2, - .item_size=get_item_size(type), - .self_pairs=self_pairs, - .ordered=ordered), - }; -} - public array_t Array$concat(array_t x, array_t y, const TypeInfo *type) { int64_t item_size = get_item_size(type); diff --git a/builtins/array.h b/builtins/array.h index 1b7fa4c..56794a2 100644 --- a/builtins/array.h +++ b/builtins/array.h @@ -69,7 +69,6 @@ array_t Array$to(array_t *array, int64_t last); array_t Array$by(array_t *array, int64_t stride); array_t Array$reversed(array_t array); array_t Array$concat(array_t x, array_t y, const TypeInfo *type); -closure_t Array$pairs(array_t x, bool self_pairs, bool ordered, const TypeInfo *type); uint32_t Array$hash(const array_t *arr, const TypeInfo *type); int32_t Array$compare(const array_t *x, const array_t *y, const TypeInfo *type); bool Array$equal(const array_t *x, const array_t *y, const TypeInfo *type); diff --git a/compile.c b/compile.c index 938c1c0..6bd384d 100644 --- a/compile.c +++ b/compile.c @@ -899,17 +899,14 @@ CORD compile_statement(env_t *env, ast_t *ast) } } case FunctionType: case ClosureType: { + // Iterator function: CORD code = "{\n"; - auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType); - arg_t *next_arg = fn->args; - for (ast_list_t *var = for_->vars; var; var = var->next) { - const char *name = Match(var->ast, Var)->name; - type_t *t = Match(get_arg_type(env, next_arg), PointerType)->pointed; - code = CORD_all(code, compile_declaration(t, CORD_cat("$", name)), ";\n"); - } code = CORD_all(code, compile_declaration(iter_t, "next"), " = ", compile(env, for_->iter), ";\n"); + auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType); + code = CORD_all(code, compile_declaration(fn->ret, "cur"), ";\n"); // Iteration enum + CORD next_fn; if (iter_t->tag == ClosureType) { type_t *fn_t = Match(iter_t, ClosureType)->fn; @@ -919,20 +916,14 @@ CORD compile_statement(env_t *env, ast_t *ast) closure_fn_args = new(arg_t, .name="userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args); REVERSE_LIST(closure_fn_args); CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret)); - next_fn = CORD_all("((", fn_type_code, ")next.fn)("); + next_fn = CORD_all("((", fn_type_code, ")next.fn)"); } else { - next_fn = "next("; + next_fn = "next"; } - for (ast_list_t *var = for_->vars; var; var = var->next) { - const char *name = Match(var->ast, Var)->name; - next_fn = CORD_all(next_fn, "&$", name); - if (var->next || iter_t->tag == ClosureType) - next_fn = CORD_all(next_fn, ", "); - } - if (iter_t->tag == ClosureType) - next_fn = CORD_all(next_fn, "next.userdata"); - next_fn = CORD_all(next_fn, ")"); + env_t *enum_env = Match(fn->ret, EnumType)->env; + next_fn = CORD_all("(cur=", next_fn, iter_t->tag == ClosureType ? "(next.userdata)" : "()", ").$tag == ", + namespace_prefix(enum_env->libname, enum_env->namespace), "tag$Next"); if (for_->empty) { code = CORD_all(code, "if (", next_fn, ") {\n" @@ -1909,11 +1900,6 @@ CORD compile(env_t *env, ast_t *ast) CORD self = compile_to_pointer_depth(env, call->self, 0, false); (void)compile_arguments(env, ast, NULL, call->args); return CORD_all("Array$reversed(", self, ")"); - } else if (streq(call->name, "pairs")) { - CORD self = compile_to_pointer_depth(env, call->self, 0, false); - arg_t *arg_spec = new(arg_t, .name="self_pairs", .default_val=FakeAST(Bool, false), .type=Type(BoolType), - .next=new(arg_t, .name="ordered", .default_val=FakeAST(Bool, false), .type=Type(BoolType))); - return CORD_all("Array$pairs(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")"); } else code_err(ast, "There is no '%s' method for arrays", call->name); } case TableType: { @@ -2059,9 +2045,16 @@ CORD compile(env_t *env, ast_t *ast) (long)(reduction->iter->end - reduction->iter->file->text))); } ast_t *item = FakeAST(Var, "$iter_value"); - set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="$$iter_value")); - ast_t *body = FakeAST(InlineCCode, CORD_all("if (is_first) {\nreduction = $$iter_value;\nis_first = no;\n} else {\nreduction = ", compile(scope, reduction->combination), ";\n}\n")); + ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body, .empty=empty); + env_t *body_scope = for_scope(scope, loop); + body->__data.InlineCCode.code = CORD_all( + "if (is_first) {\n" + " reduction = ", compile(body_scope, item), ";\n" + " is_first = no;\n" + "} else {\n" + " reduction = ", compile(body_scope, reduction->combination), ";\n" + "}\n"); code = CORD_all(code, compile_statement(scope, loop), "\nreduction;})"); return code; } diff --git a/environment.c b/environment.c index 1fe698d..a5d935f 100644 --- a/environment.c +++ b/environment.c @@ -334,22 +334,33 @@ env_t *for_scope(env_t *env, ast_t *ast) } case FunctionType: case ClosureType: { auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType); - arg_t *next_arg = fn->args; - for (ast_list_t *var = for_->vars; var; var = var->next) { - if (next_arg == NULL) - code_err(var->ast, "This is too many variables for this iterator function"); - const char *name = Match(var->ast, Var)->name; - type_t *t = get_arg_type(env, next_arg); - if (t->tag != PointerType) - code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t); - auto ptr = Match(t, PointerType); - if (!ptr->is_stack || ptr->is_readonly) - code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t); - set_binding(scope, name, new(binding_t, .type=ptr->pointed, .code=CORD_cat("$", name))); - next_arg = next_arg->next; + if (fn->ret->tag != EnumType) + code_err(for_->iter, "Iterator functions must return an enum with a Done and Next field"); + auto iter_enum = Match(fn->ret, EnumType); + type_t *next_type = NULL; + for (tag_t *tag = iter_enum->tags; tag; tag = tag->next) { + if (streq(tag->name, "Done")) { + if (Match(tag->type, StructType)->fields) + code_err(for_->iter, "This iterator function returns an enum with a Done field that has values, when none are allowed"); + } else if (streq(tag->name, "Next")) { + next_type = tag->type; + } else { + code_err(for_->iter, "This iterator function returns an enum with a value that isn't Done or Next: %s", tag->name); + } + } + + if (!next_type) + code_err(for_->iter, "This iterator function returns an enum that doesn't have a Next field"); + + arg_t *iter_field = Match(next_type, StructType)->fields; + for (ast_list_t *var = for_->vars; var; var = var->next) { + if (!iter_field) + code_err(var->ast, "This is one variable too many for this iterator, which returns a %T", fn->ret); + const char *name = Match(var->ast, Var)->name; + type_t *t = get_arg_type(env, iter_field); + set_binding(scope, name, new(binding_t, .type=t, .code=CORD_cat("cur.Next.", iter_field->name))); + iter_field = iter_field->next; } - if (next_arg) - code_err(ast, "There are not enough variables given for this loop with an iterator that has type %T", iter_t); return scope; } default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t); diff --git a/test/arrays.tm b/test/arrays.tm index ed43acb..fb6d16c 100644 --- a/test/arrays.tm +++ b/test/arrays.tm @@ -143,16 +143,3 @@ func main(): >> [i*10 for i in 10]:by(2):by(-1) = [90, 70, 50, 30, 10] - - do: - strs := ["A", "B", "C"] - >> ["{x}{y}" for x, y in strs:pairs()] - = ["AB", "AC", "BC"] - >> ["{x}{y}" for x, y in strs:pairs(self_pairs=yes)] - = ["AA", "AB", "AC", "BB", "BC", "CC"] - >> ["{x}{y}" for x, y in strs:pairs(ordered=yes)] - = ["AB", "AC", "BA", "BC", "CA", "CB"] - >> ["{x}{y}" for x, y in strs:pairs(self_pairs=yes, ordered=yes)] - = ["AA", "AB", "AC", "BA", "BB", "BC", "CA", "CB", "CC"] - >> ["!" for x,y in [:Text]:pairs()] - = [] diff --git a/typecheck.c b/typecheck.c index e61873d..0dc7cdc 100644 --- a/typecheck.c +++ b/typecheck.c @@ -1,9 +1,9 @@ // Logic for getting a type from an AST node -#include #include +#include +#include #include #include -#include #include #include @@ -678,11 +678,7 @@ type_t *get_type(env_t *env, ast_t *ast) else if (streq(call->name, "heapify")) return Type(VoidType); else if (streq(call->name, "heap_push")) return Type(VoidType); else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type; - else if (streq(call->name, "pairs")) { - type_t *ref_t = Type(PointerType, .pointed=Match(self_value_t, ArrayType)->item_type, .is_stack=true); - arg_t *args = new(arg_t, .name="x", .type=ref_t, .next=new(arg_t, .name="y", .type=ref_t)); - return Type(ClosureType, .fn=Type(FunctionType, .args=args, .ret=Type(BoolType))); - } else code_err(ast, "There is no '%s' method for arrays", call->name); + else code_err(ast, "There is no '%s' method for arrays", call->name); } case TableType: { auto table = Match(self_value_t, TableType); @@ -740,7 +736,7 @@ type_t *get_type(env_t *env, ast_t *ast) case Return: { ast_t *val = Match(ast, Return)->value; // Support unqualified enum return values: - if (env->fn_ctx && env->fn_ctx->return_type->tag == EnumType) { + if (env->fn_ctx && env->fn_ctx->return_type && env->fn_ctx->return_type->tag == EnumType) { env = fresh_scope(env); auto enum_ = Match(env->fn_ctx->return_type, EnumType); env_t *ns_env = enum_->env; @@ -903,25 +899,35 @@ type_t *get_type(env_t *env, ast_t *ast) case ArrayType: value_t = Match(iter_value_t, ArrayType)->item_type; break; case TableType: value_t = Match(iter_value_t, TableType)->key_type; break; case FunctionType: case ClosureType: { + // Iterator function auto fn = iter_value_t->tag == ClosureType ? Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType); - if (!fn->args || fn->args->next) - code_err(reduction->iter, "I expected this iterable to have exactly one argument, not %T", iter_value_t); - type_t *arg_type = get_arg_type(env, fn->args); - if (arg_type->tag != PointerType) - code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type); - auto ptr = Match(arg_type, PointerType); - if (!ptr->is_stack || ptr->is_optional || ptr->is_readonly) - code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type); - value_t = ptr->pointed; + if (fn->args) + code_err(reduction->iter, "I expected this iterator function to not take any arguments, but it's %T", iter_value_t); + if (fn->ret->tag != EnumType) + code_err(reduction->iter, "I expected this iterator function to return an enum, but it's %T", iter_value_t); + value_t = NULL; + for (tag_t *tag = Match(fn->ret, EnumType)->tags; tag; tag = tag->next) { + if (streq(tag->name, "Next")) { + arg_t *fields = Match(tag->type, StructType)->fields; + if (!fields || fields->next) + code_err(reduction->iter, + "I expected this iterator function to return an enum with a Next() that has exactly one value, not %T", + tag->type); + value_t = fields->type; + break; + } + } + if (!value_t) + code_err(reduction->iter, "This iterator function doesn't return an enum with a Next() value"); break; } default: code_err(reduction->iter, "I don't know how to do a reduction over %T values", iter_t); } env_t *scope = fresh_scope(env); - set_binding(scope, "$reduction", new(binding_t, .type=value_t, .code="reduction")); - set_binding(scope, "$iter_value", new(binding_t, .type=value_t, .code="iter_value")); + set_binding(scope, "$reduction", new(binding_t, .type=value_t)); + set_binding(scope, "$iter_value", new(binding_t, .type=value_t)); type_t *t = get_type(scope, reduction->combination); if (!reduction->fallback) return t;