Fix for stride overflows in arrays

This commit is contained in:
Bruce Hill 2024-08-03 14:40:56 -04:00
parent b2e752ee32
commit 2b9bec18a4
3 changed files with 33 additions and 6 deletions

View File

@ -298,8 +298,28 @@ public array_t Array$to(array_t array, int64_t last)
};
}
public array_t Array$by(array_t array, int64_t stride)
public array_t Array$by(array_t array, int64_t stride, const TypeInfo *type)
{
// In the unlikely event that the stride value would be too large to fit in
// a 15-bit integer, fall back to creating a copy of the array:
if (__builtin_expect(array.stride*stride < MIN_STRIDE || array.stride*stride > MAX_STRIDE, 0)) {
void *copy = NULL;
int64_t item_size = get_item_size(type);
int64_t len = (stride < 0 ? array.length / -stride : array.length / stride) + ((array.length % stride) != 0);
if (len > 0) {
copy = array.atomic ? GC_MALLOC_ATOMIC(len * item_size) : GC_MALLOC(len * item_size);
void *start = (stride < 0 ? array.data + (array.stride * (array.length - 1)) : array.data);
for (int64_t i = 0; i < len; i++)
memcpy(copy + i*item_size, start + array.stride*stride*i, item_size);
}
return (array_t){
.data=copy,
.length=len,
.stride=item_size,
.atomic=array.atomic,
};
}
if (stride == 0)
return (array_t){.atomic=array.atomic};
@ -312,8 +332,15 @@ public array_t Array$by(array_t array, int64_t stride)
};
}
public array_t Array$reversed(array_t array)
public array_t Array$reversed(array_t array, const TypeInfo *type)
{
// Just in case negating the stride gives a value that doesn't fit into a
// 15-bit integer, fall back to Array$by()'s more general method of copying
// the array. This should only happen if array.stride is MIN_STRIDE to
// begin with (very unlikely).
if (__builtin_expect(-array.stride < MIN_STRIDE || -array.stride > MAX_STRIDE, 0))
return Array$by(array, -1, type);
array_t reversed = array;
reversed.stride = -array.stride;
reversed.data = array.data + (array.length-1)*array.stride;

View File

@ -66,8 +66,8 @@ void Array$compact(array_t *arr, const TypeInfo *type);
bool Array$contains(array_t array, void *item, const TypeInfo *type);
array_t Array$from(array_t array, int64_t first);
array_t Array$to(array_t array, int64_t last);
array_t Array$by(array_t array, int64_t stride);
array_t Array$reversed(array_t array);
array_t Array$by(array_t array, int64_t stride, const TypeInfo *type);
array_t Array$reversed(array_t array, const TypeInfo *type);
array_t Array$concat(array_t x, array_t y, const TypeInfo *type);
uint32_t Array$hash(const array_t *arr, const TypeInfo *type);
int32_t Array$compare(const array_t *x, const array_t *y, const TypeInfo *type);

View File

@ -1925,11 +1925,11 @@ CORD compile(env_t *env, ast_t *ast)
} else if (streq(call->name, "by")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
arg_t *arg_spec = new(arg_t, .name="stride", .type=Type(IntType, .bits=64));
return CORD_all("Array$by(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ")");
return CORD_all("Array$by(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", compile_type_info(env, self_value_t), ")");
} else if (streq(call->name, "reversed")) {
CORD self = compile_to_pointer_depth(env, call->self, 0, false);
(void)compile_arguments(env, ast, NULL, call->args);
return CORD_all("Array$reversed(", self, ")");
return CORD_all("Array$reversed(", self, ", ", compile_type_info(env, self_value_t), ")");
} else code_err(ast, "There is no '%s' method for arrays", call->name);
}
case TableType: {