Deprecate array:pairs() and switch iterator functions to use enums

This commit is contained in:
Bruce Hill 2024-07-20 16:45:13 -04:00
parent 907122a049
commit fb95bbb1d4
6 changed files with 69 additions and 116 deletions

View File

@ -320,49 +320,6 @@ public array_t Array$reversed(array_t array)
return reversed;
}
typedef struct {
array_t arr;
int64_t i, j, item_size;
bool self_pairs:1, ordered:1;
} pair_info_t;
static bool next_pair(void *x, void *y, pair_info_t *info)
{
if (info->i > info->arr.length || info->j > info->arr.length)
return false;
memcpy(x, info->arr.data + info->arr.stride * (info->i-1), info->item_size);
memcpy(y, info->arr.data + info->arr.stride * (info->j-1), info->item_size);
info->j += 1;
if (!info->self_pairs && info->j == info->i)
info->j += 1;
if (info->j > info->arr.length) {
info->i += 1;
if (info->ordered)
info->j = 1;
else if (info->self_pairs)
info->j = info->i;
else
info->j = info->i + 1;
}
return true;
}
public closure_t Array$pairs(array_t arr, bool self_pairs, bool ordered, const TypeInfo *type)
{
return (closure_t){
.fn=next_pair,
.userdata=new(pair_info_t,
.arr=arr,
.i=1,
.j=self_pairs ? 1 : 2,
.item_size=get_item_size(type),
.self_pairs=self_pairs,
.ordered=ordered),
};
}
public array_t Array$concat(array_t x, array_t y, const TypeInfo *type)
{
int64_t item_size = get_item_size(type);

View File

@ -69,7 +69,6 @@ array_t Array$to(array_t *array, int64_t last);
array_t Array$by(array_t *array, int64_t stride);
array_t Array$reversed(array_t array);
array_t Array$concat(array_t x, array_t y, const TypeInfo *type);
closure_t Array$pairs(array_t x, bool self_pairs, bool ordered, const TypeInfo *type);
uint32_t Array$hash(const array_t *arr, const TypeInfo *type);
int32_t Array$compare(const array_t *x, const array_t *y, const TypeInfo *type);
bool Array$equal(const array_t *x, const array_t *y, const TypeInfo *type);

View File

@ -899,17 +899,14 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
}
case FunctionType: case ClosureType: {
// Iterator function:
CORD code = "{\n";
auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
arg_t *next_arg = fn->args;
for (ast_list_t *var = for_->vars; var; var = var->next) {
const char *name = Match(var->ast, Var)->name;
type_t *t = Match(get_arg_type(env, next_arg), PointerType)->pointed;
code = CORD_all(code, compile_declaration(t, CORD_cat("$", name)), ";\n");
}
code = CORD_all(code, compile_declaration(iter_t, "next"), " = ", compile(env, for_->iter), ";\n");
auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
code = CORD_all(code, compile_declaration(fn->ret, "cur"), ";\n"); // Iteration enum
CORD next_fn;
if (iter_t->tag == ClosureType) {
type_t *fn_t = Match(iter_t, ClosureType)->fn;
@ -919,20 +916,14 @@ CORD compile_statement(env_t *env, ast_t *ast)
closure_fn_args = new(arg_t, .name="userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args);
REVERSE_LIST(closure_fn_args);
CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret));
next_fn = CORD_all("((", fn_type_code, ")next.fn)(");
next_fn = CORD_all("((", fn_type_code, ")next.fn)");
} else {
next_fn = "next(";
next_fn = "next";
}
for (ast_list_t *var = for_->vars; var; var = var->next) {
const char *name = Match(var->ast, Var)->name;
next_fn = CORD_all(next_fn, "&$", name);
if (var->next || iter_t->tag == ClosureType)
next_fn = CORD_all(next_fn, ", ");
}
if (iter_t->tag == ClosureType)
next_fn = CORD_all(next_fn, "next.userdata");
next_fn = CORD_all(next_fn, ")");
env_t *enum_env = Match(fn->ret, EnumType)->env;
next_fn = CORD_all("(cur=", next_fn, iter_t->tag == ClosureType ? "(next.userdata)" : "()", ").$tag == ",
namespace_prefix(enum_env->libname, enum_env->namespace), "tag$Next");
if (for_->empty) {
code = CORD_all(code, "if (", next_fn, ") {\n"
@ -1909,11 +1900,6 @@ CORD compile(env_t *env, ast_t *ast)
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
(void)compile_arguments(env, ast, NULL, call->args);
return CORD_all("Array$reversed(", self, ")");
} else if (streq(call->name, "pairs")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
arg_t *arg_spec = new(arg_t, .name="self_pairs", .default_val=FakeAST(Bool, false), .type=Type(BoolType),
.next=new(arg_t, .name="ordered", .default_val=FakeAST(Bool, false), .type=Type(BoolType)));
return CORD_all("Array$pairs(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")");
} else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case TableType: {
@ -2059,9 +2045,16 @@ CORD compile(env_t *env, ast_t *ast)
(long)(reduction->iter->end - reduction->iter->file->text)));
}
ast_t *item = FakeAST(Var, "$iter_value");
set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="$$iter_value"));
ast_t *body = FakeAST(InlineCCode, CORD_all("if (is_first) {\nreduction = $$iter_value;\nis_first = no;\n} else {\nreduction = ", compile(scope, reduction->combination), ";\n}\n"));
ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body, .empty=empty);
env_t *body_scope = for_scope(scope, loop);
body->__data.InlineCCode.code = CORD_all(
"if (is_first) {\n"
" reduction = ", compile(body_scope, item), ";\n"
" is_first = no;\n"
"} else {\n"
" reduction = ", compile(body_scope, reduction->combination), ";\n"
"}\n");
code = CORD_all(code, compile_statement(scope, loop), "\nreduction;})");
return code;
}

View File

@ -334,22 +334,33 @@ env_t *for_scope(env_t *env, ast_t *ast)
}
case FunctionType: case ClosureType: {
auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
arg_t *next_arg = fn->args;
for (ast_list_t *var = for_->vars; var; var = var->next) {
if (next_arg == NULL)
code_err(var->ast, "This is too many variables for this iterator function");
const char *name = Match(var->ast, Var)->name;
type_t *t = get_arg_type(env, next_arg);
if (t->tag != PointerType)
code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t);
auto ptr = Match(t, PointerType);
if (!ptr->is_stack || ptr->is_readonly)
code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t);
set_binding(scope, name, new(binding_t, .type=ptr->pointed, .code=CORD_cat("$", name)));
next_arg = next_arg->next;
if (fn->ret->tag != EnumType)
code_err(for_->iter, "Iterator functions must return an enum with a Done and Next field");
auto iter_enum = Match(fn->ret, EnumType);
type_t *next_type = NULL;
for (tag_t *tag = iter_enum->tags; tag; tag = tag->next) {
if (streq(tag->name, "Done")) {
if (Match(tag->type, StructType)->fields)
code_err(for_->iter, "This iterator function returns an enum with a Done field that has values, when none are allowed");
} else if (streq(tag->name, "Next")) {
next_type = tag->type;
} else {
code_err(for_->iter, "This iterator function returns an enum with a value that isn't Done or Next: %s", tag->name);
}
}
if (!next_type)
code_err(for_->iter, "This iterator function returns an enum that doesn't have a Next field");
arg_t *iter_field = Match(next_type, StructType)->fields;
for (ast_list_t *var = for_->vars; var; var = var->next) {
if (!iter_field)
code_err(var->ast, "This is one variable too many for this iterator, which returns a %T", fn->ret);
const char *name = Match(var->ast, Var)->name;
type_t *t = get_arg_type(env, iter_field);
set_binding(scope, name, new(binding_t, .type=t, .code=CORD_cat("cur.Next.", iter_field->name)));
iter_field = iter_field->next;
}
if (next_arg)
code_err(ast, "There are not enough variables given for this loop with an iterator that has type %T", iter_t);
return scope;
}
default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t);

View File

@ -143,16 +143,3 @@ func main():
>> [i*10 for i in 10]:by(2):by(-1)
= [90, 70, 50, 30, 10]
do:
strs := ["A", "B", "C"]
>> ["{x}{y}" for x, y in strs:pairs()]
= ["AB", "AC", "BC"]
>> ["{x}{y}" for x, y in strs:pairs(self_pairs=yes)]
= ["AA", "AB", "AC", "BB", "BC", "CC"]
>> ["{x}{y}" for x, y in strs:pairs(ordered=yes)]
= ["AB", "AC", "BA", "BC", "CA", "CB"]
>> ["{x}{y}" for x, y in strs:pairs(self_pairs=yes, ordered=yes)]
= ["AA", "AB", "AC", "BA", "BB", "BC", "CA", "CB", "CC"]
>> ["!" for x,y in [:Text]:pairs()]
= []

View File

@ -1,9 +1,9 @@
// Logic for getting a type from an AST node
#include <gc.h>
#include <ctype.h>
#include <gc.h>
#include <signal.h>
#include <stdarg.h>
#include <stdlib.h>
#include <signal.h>
#include <string.h>
#include <sys/stat.h>
@ -678,11 +678,7 @@ type_t *get_type(env_t *env, ast_t *ast)
else if (streq(call->name, "heapify")) return Type(VoidType);
else if (streq(call->name, "heap_push")) return Type(VoidType);
else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type;
else if (streq(call->name, "pairs")) {
type_t *ref_t = Type(PointerType, .pointed=Match(self_value_t, ArrayType)->item_type, .is_stack=true);
arg_t *args = new(arg_t, .name="x", .type=ref_t, .next=new(arg_t, .name="y", .type=ref_t));
return Type(ClosureType, .fn=Type(FunctionType, .args=args, .ret=Type(BoolType)));
} else code_err(ast, "There is no '%s' method for arrays", call->name);
else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case TableType: {
auto table = Match(self_value_t, TableType);
@ -740,7 +736,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case Return: {
ast_t *val = Match(ast, Return)->value;
// Support unqualified enum return values:
if (env->fn_ctx && env->fn_ctx->return_type->tag == EnumType) {
if (env->fn_ctx && env->fn_ctx->return_type && env->fn_ctx->return_type->tag == EnumType) {
env = fresh_scope(env);
auto enum_ = Match(env->fn_ctx->return_type, EnumType);
env_t *ns_env = enum_->env;
@ -903,25 +899,35 @@ type_t *get_type(env_t *env, ast_t *ast)
case ArrayType: value_t = Match(iter_value_t, ArrayType)->item_type; break;
case TableType: value_t = Match(iter_value_t, TableType)->key_type; break;
case FunctionType: case ClosureType: {
// Iterator function
auto fn = iter_value_t->tag == ClosureType ?
Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType);
if (!fn->args || fn->args->next)
code_err(reduction->iter, "I expected this iterable to have exactly one argument, not %T", iter_value_t);
type_t *arg_type = get_arg_type(env, fn->args);
if (arg_type->tag != PointerType)
code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type);
auto ptr = Match(arg_type, PointerType);
if (!ptr->is_stack || ptr->is_optional || ptr->is_readonly)
code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type);
value_t = ptr->pointed;
if (fn->args)
code_err(reduction->iter, "I expected this iterator function to not take any arguments, but it's %T", iter_value_t);
if (fn->ret->tag != EnumType)
code_err(reduction->iter, "I expected this iterator function to return an enum, but it's %T", iter_value_t);
value_t = NULL;
for (tag_t *tag = Match(fn->ret, EnumType)->tags; tag; tag = tag->next) {
if (streq(tag->name, "Next")) {
arg_t *fields = Match(tag->type, StructType)->fields;
if (!fields || fields->next)
code_err(reduction->iter,
"I expected this iterator function to return an enum with a Next() that has exactly one value, not %T",
tag->type);
value_t = fields->type;
break;
}
}
if (!value_t)
code_err(reduction->iter, "This iterator function doesn't return an enum with a Next() value");
break;
}
default: code_err(reduction->iter, "I don't know how to do a reduction over %T values", iter_t);
}
env_t *scope = fresh_scope(env);
set_binding(scope, "$reduction", new(binding_t, .type=value_t, .code="reduction"));
set_binding(scope, "$iter_value", new(binding_t, .type=value_t, .code="iter_value"));
set_binding(scope, "$reduction", new(binding_t, .type=value_t));
set_binding(scope, "$iter_value", new(binding_t, .type=value_t));
type_t *t = get_type(scope, reduction->combination);
if (!reduction->fallback)
return t;