aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--builtins/array.c11
-rw-r--r--builtins/array.h1
-rw-r--r--compile.c7
-rw-r--r--test/arrays.tm11
-rw-r--r--typecheck.c10
5 files changed, 36 insertions, 4 deletions
diff --git a/builtins/array.c b/builtins/array.c
index 9c393699..bf1fe4d8 100644
--- a/builtins/array.c
+++ b/builtins/array.c
@@ -221,6 +221,17 @@ public Int_t Array$find(array_t arr, void *item, const TypeInfo *type)
return I(0);
}
+public void *Array$first(array_t arr, closure_t predicate)
+{
+ bool (*is_good)(void*, void*) = (void*)predicate.fn;
+ for (int64_t i = 0; i < arr.length; i++) {
+ if (is_good(arr.data + i*arr.stride, predicate.userdata))
+ return arr.data + i*arr.stride;
+ }
+ return NULL;
+}
+
+
public void Array$sort(array_t *arr, closure_t comparison, int64_t padded_item_size)
{
if (arr->data_refcount != 0 || (int64_t)arr->stride != padded_item_size)
diff --git a/builtins/array.h b/builtins/array.h
index 8f56d1c5..47d10fd1 100644
--- a/builtins/array.h
+++ b/builtins/array.h
@@ -69,6 +69,7 @@ void Array$remove_item(array_t *arr, void *item, Int_t max_removals, const TypeI
#define Array$remove_item_value(arr, item_expr, max, type) ({ __typeof(item_expr) item = item_expr; Array$remove_item(arr, &item, max, type); })
Int_t Array$find(array_t arr, void *item, const TypeInfo *type);
#define Array$find_value(arr, item_expr, type) ({ __typeof(item_expr) item = item_expr; Array$find(arr, &item, type); })
+void *Array$first(array_t arr, closure_t predicate);
void Array$sort(array_t *arr, closure_t comparison, int64_t padded_item_size);
array_t Array$sorted(array_t arr, closure_t comparison, int64_t padded_item_size);
void Array$shuffle(array_t *arr, int64_t padded_item_size);
diff --git a/compile.c b/compile.c
index 17d19f30..748f5541 100644
--- a/compile.c
+++ b/compile.c
@@ -2263,6 +2263,13 @@ CORD compile(env_t *env, ast_t *ast)
arg_t *arg_spec = new(arg_t, .name="item", .type=item_t);
return CORD_all("Array$find_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args),
", ", compile_type_info(env, self_value_t), ")");
+ } else if (streq(call->name, "first")) {
+ CORD self = compile_to_pointer_depth(env, call->self, 0, false);
+ type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true);
+ type_t *predicate_type = Type(
+ ClosureType, .fn=Type(FunctionType, .args=new(arg_t, .name="item", .type=item_ptr), .ret=Type(BoolType)));
+ arg_t *arg_spec = new(arg_t, .name="predicate", .type=predicate_type);
+ return CORD_all("Array$first(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ")");
} else if (streq(call->name, "from")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
arg_t *arg_spec = new(arg_t, .name="first", .type=INT_TYPE);
diff --git a/test/arrays.tm b/test/arrays.tm
index ac243018..c693996e 100644
--- a/test/arrays.tm
+++ b/test/arrays.tm
@@ -158,3 +158,14 @@ func main():
>> nums:sort(func(a,b:&%Int): a:abs() <> b:abs())
>> [nums:binary_search(i, func(a,b:&Int): a:abs() <> b:abs()) for i in nums]
= [1, 2, 3, 4, 5]
+
+ >> [10, 20, 30]:find(20)
+ = 2
+ >> [10, 20, 30]:find(999)
+ = 0
+
+ >> [10, 20]:first(func(i:&Int): i:is_prime())
+ = !Int
+ >> [4, 5, 6]:first(func(i:&Int): i:is_prime())
+ = @%5?
+
diff --git a/typecheck.c b/typecheck.c
index 739535c4..1018f33e 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -728,19 +728,21 @@ type_t *get_type(env_t *env, ast_t *ast)
type_t *self_value_t = value_type(get_type(env, call->self));
switch (self_value_t->tag) {
case ArrayType: {
+ type_t *item_type = Match(self_value_t, ArrayType)->item_type;
if (streq(call->name, "binary_search")) return INT_TYPE;
else if (streq(call->name, "by")) return self_value_t;
else if (streq(call->name, "clear")) return Type(VoidType);
- else if (streq(call->name, "counts")) return Type(TableType, .key_type=Match(self_value_t, ArrayType)->item_type, .value_type=INT_TYPE);
+ else if (streq(call->name, "counts")) return Type(TableType, .key_type=item_type, .value_type=INT_TYPE);
else if (streq(call->name, "find")) return INT_TYPE;
+ else if (streq(call->name, "first")) return Type(PointerType, .pointed=item_type, .is_optional=true, .is_readonly=true);
else if (streq(call->name, "from")) return self_value_t;
else if (streq(call->name, "has")) return Type(BoolType);
- else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type;
+ else if (streq(call->name, "heap_pop")) return item_type;
else if (streq(call->name, "heap_push")) return Type(VoidType);
else if (streq(call->name, "heapify")) return Type(VoidType);
else if (streq(call->name, "insert")) return Type(VoidType);
else if (streq(call->name, "insert_all")) return Type(VoidType);
- else if (streq(call->name, "random")) return Match(self_value_t, ArrayType)->item_type;
+ else if (streq(call->name, "random")) return item_type;
else if (streq(call->name, "remove_at")) return Type(VoidType);
else if (streq(call->name, "remove_item")) return Type(VoidType);
else if (streq(call->name, "reversed")) return self_value_t;
@@ -750,7 +752,7 @@ type_t *get_type(env_t *env, ast_t *ast)
else if (streq(call->name, "sort")) return Type(VoidType);
else if (streq(call->name, "sorted")) return self_value_t;
else if (streq(call->name, "to")) return self_value_t;
- else if (streq(call->name, "unique")) return Type(SetType, .item_type=Match(self_value_t, ArrayType)->item_type);
+ else if (streq(call->name, "unique")) return Type(SetType, .item_type=item_type);
else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case SetType: {