aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-07-20 16:45:13 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-07-20 16:45:13 -0400
commitfb95bbb1d49dab882e5b4a962b7dd9b2438fdacb (patch)
treef69201ed2d604b2281f3919a2a56fc470f6481df
parent907122a049572f02880713620e0e6b024a5cff7f (diff)
Deprecate array:pairs() and switch iterator functions to use enums
-rw-r--r--builtins/array.c43
-rw-r--r--builtins/array.h1
-rw-r--r--compile.c43
-rw-r--r--environment.c37
-rw-r--r--test/arrays.tm13
-rw-r--r--typecheck.c44
6 files changed, 67 insertions, 114 deletions
diff --git a/builtins/array.c b/builtins/array.c
index 7086ff7f..2d45591c 100644
--- a/builtins/array.c
+++ b/builtins/array.c
@@ -320,49 +320,6 @@ public array_t Array$reversed(array_t array)
return reversed;
}
-typedef struct {
- array_t arr;
- int64_t i, j, item_size;
- bool self_pairs:1, ordered:1;
-} pair_info_t;
-
-static bool next_pair(void *x, void *y, pair_info_t *info)
-{
- if (info->i > info->arr.length || info->j > info->arr.length)
- return false;
-
- memcpy(x, info->arr.data + info->arr.stride * (info->i-1), info->item_size);
- memcpy(y, info->arr.data + info->arr.stride * (info->j-1), info->item_size);
- info->j += 1;
- if (!info->self_pairs && info->j == info->i)
- info->j += 1;
-
- if (info->j > info->arr.length) {
- info->i += 1;
- if (info->ordered)
- info->j = 1;
- else if (info->self_pairs)
- info->j = info->i;
- else
- info->j = info->i + 1;
- }
- return true;
-}
-
-public closure_t Array$pairs(array_t arr, bool self_pairs, bool ordered, const TypeInfo *type)
-{
- return (closure_t){
- .fn=next_pair,
- .userdata=new(pair_info_t,
- .arr=arr,
- .i=1,
- .j=self_pairs ? 1 : 2,
- .item_size=get_item_size(type),
- .self_pairs=self_pairs,
- .ordered=ordered),
- };
-}
-
public array_t Array$concat(array_t x, array_t y, const TypeInfo *type)
{
int64_t item_size = get_item_size(type);
diff --git a/builtins/array.h b/builtins/array.h
index 1b7fa4cb..56794a22 100644
--- a/builtins/array.h
+++ b/builtins/array.h
@@ -69,7 +69,6 @@ array_t Array$to(array_t *array, int64_t last);
array_t Array$by(array_t *array, int64_t stride);
array_t Array$reversed(array_t array);
array_t Array$concat(array_t x, array_t y, const TypeInfo *type);
-closure_t Array$pairs(array_t x, bool self_pairs, bool ordered, const TypeInfo *type);
uint32_t Array$hash(const array_t *arr, const TypeInfo *type);
int32_t Array$compare(const array_t *x, const array_t *y, const TypeInfo *type);
bool Array$equal(const array_t *x, const array_t *y, const TypeInfo *type);
diff --git a/compile.c b/compile.c
index 938c1c0c..6bd384d1 100644
--- a/compile.c
+++ b/compile.c
@@ -899,17 +899,14 @@ CORD compile_statement(env_t *env, ast_t *ast)
}
}
case FunctionType: case ClosureType: {
+ // Iterator function:
CORD code = "{\n";
- auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
- arg_t *next_arg = fn->args;
- for (ast_list_t *var = for_->vars; var; var = var->next) {
- const char *name = Match(var->ast, Var)->name;
- type_t *t = Match(get_arg_type(env, next_arg), PointerType)->pointed;
- code = CORD_all(code, compile_declaration(t, CORD_cat("$", name)), ";\n");
- }
code = CORD_all(code, compile_declaration(iter_t, "next"), " = ", compile(env, for_->iter), ";\n");
+ auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
+ code = CORD_all(code, compile_declaration(fn->ret, "cur"), ";\n"); // Iteration enum
+
CORD next_fn;
if (iter_t->tag == ClosureType) {
type_t *fn_t = Match(iter_t, ClosureType)->fn;
@@ -919,20 +916,14 @@ CORD compile_statement(env_t *env, ast_t *ast)
closure_fn_args = new(arg_t, .name="userdata", .type=Type(PointerType, .pointed=Type(MemoryType)), .next=closure_fn_args);
REVERSE_LIST(closure_fn_args);
CORD fn_type_code = compile_type(Type(FunctionType, .args=closure_fn_args, .ret=Match(fn_t, FunctionType)->ret));
- next_fn = CORD_all("((", fn_type_code, ")next.fn)(");
+ next_fn = CORD_all("((", fn_type_code, ")next.fn)");
} else {
- next_fn = "next(";
+ next_fn = "next";
}
- for (ast_list_t *var = for_->vars; var; var = var->next) {
- const char *name = Match(var->ast, Var)->name;
- next_fn = CORD_all(next_fn, "&$", name);
- if (var->next || iter_t->tag == ClosureType)
- next_fn = CORD_all(next_fn, ", ");
- }
- if (iter_t->tag == ClosureType)
- next_fn = CORD_all(next_fn, "next.userdata");
- next_fn = CORD_all(next_fn, ")");
+ env_t *enum_env = Match(fn->ret, EnumType)->env;
+ next_fn = CORD_all("(cur=", next_fn, iter_t->tag == ClosureType ? "(next.userdata)" : "()", ").$tag == ",
+ namespace_prefix(enum_env->libname, enum_env->namespace), "tag$Next");
if (for_->empty) {
code = CORD_all(code, "if (", next_fn, ") {\n"
@@ -1909,11 +1900,6 @@ CORD compile(env_t *env, ast_t *ast)
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
(void)compile_arguments(env, ast, NULL, call->args);
return CORD_all("Array$reversed(", self, ")");
- } else if (streq(call->name, "pairs")) {
- CORD self = compile_to_pointer_depth(env, call->self, 0, false);
- arg_t *arg_spec = new(arg_t, .name="self_pairs", .default_val=FakeAST(Bool, false), .type=Type(BoolType),
- .next=new(arg_t, .name="ordered", .default_val=FakeAST(Bool, false), .type=Type(BoolType)));
- return CORD_all("Array$pairs(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")");
} else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case TableType: {
@@ -2059,9 +2045,16 @@ CORD compile(env_t *env, ast_t *ast)
(long)(reduction->iter->end - reduction->iter->file->text)));
}
ast_t *item = FakeAST(Var, "$iter_value");
- set_binding(scope, "$iter_value", new(binding_t, .type=t, .code="$$iter_value"));
- ast_t *body = FakeAST(InlineCCode, CORD_all("if (is_first) {\nreduction = $$iter_value;\nis_first = no;\n} else {\nreduction = ", compile(scope, reduction->combination), ";\n}\n"));
+ ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder
ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body, .empty=empty);
+ env_t *body_scope = for_scope(scope, loop);
+ body->__data.InlineCCode.code = CORD_all(
+ "if (is_first) {\n"
+ " reduction = ", compile(body_scope, item), ";\n"
+ " is_first = no;\n"
+ "} else {\n"
+ " reduction = ", compile(body_scope, reduction->combination), ";\n"
+ "}\n");
code = CORD_all(code, compile_statement(scope, loop), "\nreduction;})");
return code;
}
diff --git a/environment.c b/environment.c
index 1fe698d2..a5d935f3 100644
--- a/environment.c
+++ b/environment.c
@@ -334,22 +334,33 @@ env_t *for_scope(env_t *env, ast_t *ast)
}
case FunctionType: case ClosureType: {
auto fn = iter_t->tag == ClosureType ? Match(Match(iter_t, ClosureType)->fn, FunctionType) : Match(iter_t, FunctionType);
- arg_t *next_arg = fn->args;
+ if (fn->ret->tag != EnumType)
+ code_err(for_->iter, "Iterator functions must return an enum with a Done and Next field");
+ auto iter_enum = Match(fn->ret, EnumType);
+ type_t *next_type = NULL;
+ for (tag_t *tag = iter_enum->tags; tag; tag = tag->next) {
+ if (streq(tag->name, "Done")) {
+ if (Match(tag->type, StructType)->fields)
+ code_err(for_->iter, "This iterator function returns an enum with a Done field that has values, when none are allowed");
+ } else if (streq(tag->name, "Next")) {
+ next_type = tag->type;
+ } else {
+ code_err(for_->iter, "This iterator function returns an enum with a value that isn't Done or Next: %s", tag->name);
+ }
+ }
+
+ if (!next_type)
+ code_err(for_->iter, "This iterator function returns an enum that doesn't have a Next field");
+
+ arg_t *iter_field = Match(next_type, StructType)->fields;
for (ast_list_t *var = for_->vars; var; var = var->next) {
- if (next_arg == NULL)
- code_err(var->ast, "This is too many variables for this iterator function");
+ if (!iter_field)
+ code_err(var->ast, "This is one variable too many for this iterator, which returns a %T", fn->ret);
const char *name = Match(var->ast, Var)->name;
- type_t *t = get_arg_type(env, next_arg);
- if (t->tag != PointerType)
- code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t);
- auto ptr = Match(t, PointerType);
- if (!ptr->is_stack || ptr->is_readonly)
- code_err(for_->iter, "This iterator has type %T, but I need all its arguments to be mutable stack pointers", iter_t);
- set_binding(scope, name, new(binding_t, .type=ptr->pointed, .code=CORD_cat("$", name)));
- next_arg = next_arg->next;
+ type_t *t = get_arg_type(env, iter_field);
+ set_binding(scope, name, new(binding_t, .type=t, .code=CORD_cat("cur.Next.", iter_field->name)));
+ iter_field = iter_field->next;
}
- if (next_arg)
- code_err(ast, "There are not enough variables given for this loop with an iterator that has type %T", iter_t);
return scope;
}
default: code_err(for_->iter, "Iteration is not implemented for type: %T", iter_t);
diff --git a/test/arrays.tm b/test/arrays.tm
index ed43acb0..fb6d16ca 100644
--- a/test/arrays.tm
+++ b/test/arrays.tm
@@ -143,16 +143,3 @@ func main():
>> [i*10 for i in 10]:by(2):by(-1)
= [90, 70, 50, 30, 10]
-
- do:
- strs := ["A", "B", "C"]
- >> ["{x}{y}" for x, y in strs:pairs()]
- = ["AB", "AC", "BC"]
- >> ["{x}{y}" for x, y in strs:pairs(self_pairs=yes)]
- = ["AA", "AB", "AC", "BB", "BC", "CC"]
- >> ["{x}{y}" for x, y in strs:pairs(ordered=yes)]
- = ["AB", "AC", "BA", "BC", "CA", "CB"]
- >> ["{x}{y}" for x, y in strs:pairs(self_pairs=yes, ordered=yes)]
- = ["AA", "AB", "AC", "BA", "BB", "BC", "CA", "CB", "CC"]
- >> ["!" for x,y in [:Text]:pairs()]
- = []
diff --git a/typecheck.c b/typecheck.c
index e61873d1..0dc7cdc0 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -1,9 +1,9 @@
// Logic for getting a type from an AST node
-#include <gc.h>
#include <ctype.h>
+#include <gc.h>
+#include <signal.h>
#include <stdarg.h>
#include <stdlib.h>
-#include <signal.h>
#include <string.h>
#include <sys/stat.h>
@@ -678,11 +678,7 @@ type_t *get_type(env_t *env, ast_t *ast)
else if (streq(call->name, "heapify")) return Type(VoidType);
else if (streq(call->name, "heap_push")) return Type(VoidType);
else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type;
- else if (streq(call->name, "pairs")) {
- type_t *ref_t = Type(PointerType, .pointed=Match(self_value_t, ArrayType)->item_type, .is_stack=true);
- arg_t *args = new(arg_t, .name="x", .type=ref_t, .next=new(arg_t, .name="y", .type=ref_t));
- return Type(ClosureType, .fn=Type(FunctionType, .args=args, .ret=Type(BoolType)));
- } else code_err(ast, "There is no '%s' method for arrays", call->name);
+ else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case TableType: {
auto table = Match(self_value_t, TableType);
@@ -740,7 +736,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case Return: {
ast_t *val = Match(ast, Return)->value;
// Support unqualified enum return values:
- if (env->fn_ctx && env->fn_ctx->return_type->tag == EnumType) {
+ if (env->fn_ctx && env->fn_ctx->return_type && env->fn_ctx->return_type->tag == EnumType) {
env = fresh_scope(env);
auto enum_ = Match(env->fn_ctx->return_type, EnumType);
env_t *ns_env = enum_->env;
@@ -903,25 +899,35 @@ type_t *get_type(env_t *env, ast_t *ast)
case ArrayType: value_t = Match(iter_value_t, ArrayType)->item_type; break;
case TableType: value_t = Match(iter_value_t, TableType)->key_type; break;
case FunctionType: case ClosureType: {
+ // Iterator function
auto fn = iter_value_t->tag == ClosureType ?
Match(Match(iter_value_t, ClosureType)->fn, FunctionType) : Match(iter_value_t, FunctionType);
- if (!fn->args || fn->args->next)
- code_err(reduction->iter, "I expected this iterable to have exactly one argument, not %T", iter_value_t);
- type_t *arg_type = get_arg_type(env, fn->args);
- if (arg_type->tag != PointerType)
- code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type);
- auto ptr = Match(arg_type, PointerType);
- if (!ptr->is_stack || ptr->is_optional || ptr->is_readonly)
- code_err(reduction->iter, "I expected this iterable to have exactly one stack reference argument, not %T", arg_type);
- value_t = ptr->pointed;
+ if (fn->args)
+ code_err(reduction->iter, "I expected this iterator function to not take any arguments, but it's %T", iter_value_t);
+ if (fn->ret->tag != EnumType)
+ code_err(reduction->iter, "I expected this iterator function to return an enum, but it's %T", iter_value_t);
+ value_t = NULL;
+ for (tag_t *tag = Match(fn->ret, EnumType)->tags; tag; tag = tag->next) {
+ if (streq(tag->name, "Next")) {
+ arg_t *fields = Match(tag->type, StructType)->fields;
+ if (!fields || fields->next)
+ code_err(reduction->iter,
+ "I expected this iterator function to return an enum with a Next() that has exactly one value, not %T",
+ tag->type);
+ value_t = fields->type;
+ break;
+ }
+ }
+ if (!value_t)
+ code_err(reduction->iter, "This iterator function doesn't return an enum with a Next() value");
break;
}
default: code_err(reduction->iter, "I don't know how to do a reduction over %T values", iter_t);
}
env_t *scope = fresh_scope(env);
- set_binding(scope, "$reduction", new(binding_t, .type=value_t, .code="reduction"));
- set_binding(scope, "$iter_value", new(binding_t, .type=value_t, .code="iter_value"));
+ set_binding(scope, "$reduction", new(binding_t, .type=value_t));
+ set_binding(scope, "$iter_value", new(binding_t, .type=value_t));
type_t *t = get_type(scope, reduction->combination);
if (!reduction->fallback)
return t;