diff options
| -rw-r--r-- | builtins/array.c | 62 | ||||
| -rw-r--r-- | builtins/array.h | 3 | ||||
| -rw-r--r-- | compile.c | 12 | ||||
| -rw-r--r-- | test/arrays.tm | 25 | ||||
| -rw-r--r-- | typecheck.c | 3 |
5 files changed, 61 insertions, 44 deletions
diff --git a/builtins/array.c b/builtins/array.c index 69d5cff1..1c31be80 100644 --- a/builtins/array.c +++ b/builtins/array.c @@ -261,51 +261,39 @@ public array_t Array$sample(array_t arr, int64_t n, array_t weights, const TypeI return selected; } -public array_t Array$slice(array_t *array, int64_t first, int64_t length, int64_t stride, const TypeInfo *type) +public array_t Array$from(array_t *array, int64_t first, int64_t last) { - if (stride > MAX_STRIDE || stride < MIN_STRIDE) - fail("Stride is too big: %ld", stride); + if (first < 0) + first = array->length + first + 1; - if (stride == 0 || length <= 0) { - // Zero stride + if (last < 0) + last = array->length + last + 1; + + if (first < 1 || first > array->length || last < first) return (array_t){.atomic=array->atomic}; - } else if (stride < 0) { - if (first == INT64_MIN) first = array->length; - if (first > array->length) { - // Range starting after array - int64_t residual = first % -stride; - first = array->length - (array->length % -stride) + residual; - } - if (first > array->length) first += stride; - if (first < 1) { - // Range outside array - return (array_t){.atomic=array->atomic}; - } - } else { - if (first == INT64_MIN) first = 1; - if (first < 1) { - // Range starting before array - first = first % stride; - } - while (first < 1) first += stride; - if (first > array->length) { - // Range outside array - return (array_t){.atomic=array->atomic}; - } - } - if (length > array->length/labs(stride) + 1) length = array->length/labs(stride) + 1; - if (length < 0) length = -length; + if (last > array->length) + last = array->length; + + return (array_t){ + .atomic=array->atomic, + .data=array->data + array->stride*(first-1), + .length=last - first + 1, + .stride=array->stride, + .data_refcount=array->data_refcount, + }; +} - // Saturating add: - array->data_refcount |= (array->data_refcount << 1) | 1; +public array_t Array$by(array_t *array, int64_t stride) +{ + if (stride == 0) + return (array_t){.atomic=array->atomic}; - int64_t item_size = get_item_size(type); return (array_t){ .atomic=array->atomic, - .data=array->data + item_size*(first-1), - .length=length, - .stride=(array->stride * stride), + .data=(stride < 0 ? array->data + (array->stride * (array->length - 1)) : array->data), + .length=(stride < 0 ? array->length / -stride : array->length / stride) + ((array->length % stride) != 0), + .stride=array->stride * stride, .data_refcount=array->data_refcount, }; } diff --git a/builtins/array.h b/builtins/array.h index fbf37a53..b2884d99 100644 --- a/builtins/array.h +++ b/builtins/array.h @@ -64,7 +64,8 @@ array_t Array$sample(array_t arr, int64_t n, array_t weights, const TypeInfo *ty void Array$clear(array_t *array); void Array$compact(array_t *arr, const TypeInfo *type); bool Array$contains(array_t array, void *item, const TypeInfo *type); -array_t Array$slice(array_t *array, int64_t first, int64_t length, int64_t stride, const TypeInfo *type); +array_t Array$from(array_t *array, int64_t first, int64_t last); +array_t Array$by(array_t *array, int64_t stride); array_t Array$reversed(array_t array); array_t Array$concat(array_t x, array_t y, const TypeInfo *type); uint32_t Array$hash(const array_t *arr, const TypeInfo *type); @@ -1813,13 +1813,15 @@ CORD compile(env_t *env, ast_t *ast) CORD self = compile_to_pointer_depth(env, call->self, 1, false); (void)compile_arguments(env, ast, NULL, call->args); return CORD_all("Array$clear(", self, ")"); - } else if (streq(call->name, "slice")) { + } else if (streq(call->name, "from")) { CORD self = compile_to_pointer_depth(env, call->self, 1, false); arg_t *arg_spec = new(arg_t, .name="first", .type=Type(IntType, .bits=64), .default_val=FakeAST(Int, .i=1, .bits=64), - .next=new(arg_t, .name="length", .type=Type(IntType, .bits=64), .default_val=FakeAST(Int, .i=INT64_MAX, .bits=64), - .next=new(arg_t, .name="stride", .type=Type(IntType, .bits=64), .default_val=FakeAST(Int, .i=1, .bits=64)))); - return CORD_all("Array$slice(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", - compile_type_info(env, self_value_t), ")"); + .next=new(arg_t, .name="last", .type=Type(IntType, .bits=64), .default_val=FakeAST(Int, .i=-1, .bits=64))); + return CORD_all("Array$from(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ")"); + } else if (streq(call->name, "by")) { + CORD self = compile_to_pointer_depth(env, call->self, 1, false); + arg_t *arg_spec = new(arg_t, .name="stride", .type=Type(IntType, .bits=64)); + return CORD_all("Array$by(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ")"); } else if (streq(call->name, "reversed")) { CORD self = compile_to_pointer_depth(env, call->self, 0, false); (void)compile_arguments(env, ast, NULL, call->args); diff --git a/test/arrays.tm b/test/arrays.tm index b609aa9b..9eea9dd9 100644 --- a/test/arrays.tm +++ b/test/arrays.tm @@ -118,3 +118,28 @@ func main(): >> sorted == sorted:sorted() = yes + do: + >> [i*10 for i in 5]:from(3) + = [30, 40, 50] + >> [i*10 for i in 5]:from(last=3) + = [10, 20, 30] + >> [i*10 for i in 5]:from(last=-2) + = [10, 20, 30, 40] + >> [i*10 for i in 5]:from(-2) + = [40, 50] + + >> [i*10 for i in 5]:by(2) + = [10, 30, 50] + >> [i*10 for i in 5]:by(-1) + = [50, 40, 30, 20, 10] + + >> [10, 20, 30, 40]:by(2) + = [10, 30] + >> [10, 20, 30, 40]:by(-2) + = [40, 20] + + >> [i*10 for i in 10]:by(2):by(2) + = [10, 50, 90] + + >> [i*10 for i in 10]:by(2):by(-1) + = [90, 70, 50, 30, 10] diff --git a/typecheck.c b/typecheck.c index 9a93211d..5cc9d664 100644 --- a/typecheck.c +++ b/typecheck.c @@ -669,7 +669,8 @@ type_t *get_type(env_t *env, ast_t *ast) return Type(PointerType, .pointed=Match(self_value_t, ArrayType)->item_type, .is_optional=true, .is_readonly=true); else if (streq(call->name, "sample")) return self_value_t; else if (streq(call->name, "clear")) return Type(VoidType); - else if (streq(call->name, "slice")) return self_value_t; + else if (streq(call->name, "from")) return self_value_t; + else if (streq(call->name, "by")) return self_value_t; else if (streq(call->name, "reversed")) return self_value_t; else if (streq(call->name, "heapify")) return Type(VoidType); else if (streq(call->name, "heap_push")) return Type(VoidType); |
