From 7472837ee5a00bd9313e82f71f55b6f76ee7083b Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Sun, 18 Aug 2024 22:28:04 -0400 Subject: [PATCH] Add array:first(predicate:func(x:&T)->Bool)->@%T? --- builtins/array.c | 11 +++++++++++ builtins/array.h | 1 + compile.c | 7 +++++++ test/arrays.tm | 11 +++++++++++ typecheck.c | 10 ++++++---- 5 files changed, 36 insertions(+), 4 deletions(-) diff --git a/builtins/array.c b/builtins/array.c index 9c39369..bf1fe4d 100644 --- a/builtins/array.c +++ b/builtins/array.c @@ -221,6 +221,17 @@ public Int_t Array$find(array_t arr, void *item, const TypeInfo *type) return I(0); } +public void *Array$first(array_t arr, closure_t predicate) +{ + bool (*is_good)(void*, void*) = (void*)predicate.fn; + for (int64_t i = 0; i < arr.length; i++) { + if (is_good(arr.data + i*arr.stride, predicate.userdata)) + return arr.data + i*arr.stride; + } + return NULL; +} + + public void Array$sort(array_t *arr, closure_t comparison, int64_t padded_item_size) { if (arr->data_refcount != 0 || (int64_t)arr->stride != padded_item_size) diff --git a/builtins/array.h b/builtins/array.h index 8f56d1c..47d10fd 100644 --- a/builtins/array.h +++ b/builtins/array.h @@ -69,6 +69,7 @@ void Array$remove_item(array_t *arr, void *item, Int_t max_removals, const TypeI #define Array$remove_item_value(arr, item_expr, max, type) ({ __typeof(item_expr) item = item_expr; Array$remove_item(arr, &item, max, type); }) Int_t Array$find(array_t arr, void *item, const TypeInfo *type); #define Array$find_value(arr, item_expr, type) ({ __typeof(item_expr) item = item_expr; Array$find(arr, &item, type); }) +void *Array$first(array_t arr, closure_t predicate); void Array$sort(array_t *arr, closure_t comparison, int64_t padded_item_size); array_t Array$sorted(array_t arr, closure_t comparison, int64_t padded_item_size); void Array$shuffle(array_t *arr, int64_t padded_item_size); diff --git a/compile.c b/compile.c index 17d19f3..748f554 100644 --- a/compile.c +++ b/compile.c @@ -2263,6 +2263,13 @@ CORD compile(env_t *env, ast_t *ast) arg_t *arg_spec = new(arg_t, .name="item", .type=item_t); return CORD_all("Array$find_value(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")"); + } else if (streq(call->name, "first")) { + CORD self = compile_to_pointer_depth(env, call->self, 0, false); + type_t *item_ptr = Type(PointerType, .pointed=item_t, .is_stack=true); + type_t *predicate_type = Type( + ClosureType, .fn=Type(FunctionType, .args=new(arg_t, .name="item", .type=item_ptr), .ret=Type(BoolType))); + arg_t *arg_spec = new(arg_t, .name="predicate", .type=predicate_type); + return CORD_all("Array$first(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ")"); } else if (streq(call->name, "from")) { CORD self = compile_to_pointer_depth(env, call->self, 0, false); arg_t *arg_spec = new(arg_t, .name="first", .type=INT_TYPE); diff --git a/test/arrays.tm b/test/arrays.tm index ac24301..c693996 100644 --- a/test/arrays.tm +++ b/test/arrays.tm @@ -158,3 +158,14 @@ func main(): >> nums:sort(func(a,b:&%Int): a:abs() <> b:abs()) >> [nums:binary_search(i, func(a,b:&Int): a:abs() <> b:abs()) for i in nums] = [1, 2, 3, 4, 5] + + >> [10, 20, 30]:find(20) + = 2 + >> [10, 20, 30]:find(999) + = 0 + + >> [10, 20]:first(func(i:&Int): i:is_prime()) + = !Int + >> [4, 5, 6]:first(func(i:&Int): i:is_prime()) + = @%5? + diff --git a/typecheck.c b/typecheck.c index 739535c..1018f33 100644 --- a/typecheck.c +++ b/typecheck.c @@ -728,19 +728,21 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *self_value_t = value_type(get_type(env, call->self)); switch (self_value_t->tag) { case ArrayType: { + type_t *item_type = Match(self_value_t, ArrayType)->item_type; if (streq(call->name, "binary_search")) return INT_TYPE; else if (streq(call->name, "by")) return self_value_t; else if (streq(call->name, "clear")) return Type(VoidType); - else if (streq(call->name, "counts")) return Type(TableType, .key_type=Match(self_value_t, ArrayType)->item_type, .value_type=INT_TYPE); + else if (streq(call->name, "counts")) return Type(TableType, .key_type=item_type, .value_type=INT_TYPE); else if (streq(call->name, "find")) return INT_TYPE; + else if (streq(call->name, "first")) return Type(PointerType, .pointed=item_type, .is_optional=true, .is_readonly=true); else if (streq(call->name, "from")) return self_value_t; else if (streq(call->name, "has")) return Type(BoolType); - else if (streq(call->name, "heap_pop")) return Match(self_value_t, ArrayType)->item_type; + else if (streq(call->name, "heap_pop")) return item_type; else if (streq(call->name, "heap_push")) return Type(VoidType); else if (streq(call->name, "heapify")) return Type(VoidType); else if (streq(call->name, "insert")) return Type(VoidType); else if (streq(call->name, "insert_all")) return Type(VoidType); - else if (streq(call->name, "random")) return Match(self_value_t, ArrayType)->item_type; + else if (streq(call->name, "random")) return item_type; else if (streq(call->name, "remove_at")) return Type(VoidType); else if (streq(call->name, "remove_item")) return Type(VoidType); else if (streq(call->name, "reversed")) return self_value_t; @@ -750,7 +752,7 @@ type_t *get_type(env_t *env, ast_t *ast) else if (streq(call->name, "sort")) return Type(VoidType); else if (streq(call->name, "sorted")) return self_value_t; else if (streq(call->name, "to")) return self_value_t; - else if (streq(call->name, "unique")) return Type(SetType, .item_type=Match(self_value_t, ArrayType)->item_type); + else if (streq(call->name, "unique")) return Type(SetType, .item_type=item_type); else code_err(ast, "There is no '%s' method for arrays", call->name); } case SetType: {