More correct version of __ipairs and __pairs when iterating over tables

with holes.
This commit is contained in:
Bruce Hill 2018-04-23 15:16:49 -07:00
parent 86cf46e70b
commit 9144735dbb
2 changed files with 47 additions and 37 deletions

View File

@ -484,6 +484,10 @@ static int Lnexti(lua_State *L)
if (! info) { if (! info) {
luaL_error(L, "invalid type"); luaL_error(L, "invalid type");
} }
int i = lua_isnil(L, 2) ? 1 : lua_tointeger(L, 2)+1;
if (i > (int)info->len) {
return 0;
}
lua_rawgeti(L, -1, info->hash); lua_rawgeti(L, -1, info->hash);
// Stack: [mt, buckets, bucket] // Stack: [mt, buckets, bucket]
if (lua_isnil(L, -1)) { if (lua_isnil(L, -1)) {
@ -493,32 +497,10 @@ static int Lnexti(lua_State *L)
// Stack: [mt, buckets, bucket, inst_udata] // Stack: [mt, buckets, bucket, inst_udata]
lua_rawget(L, -2); lua_rawget(L, -2);
// Stack: [mt, buckets, bucket, inst_table] // Stack: [mt, buckets, bucket, inst_table]
lua_getfield(L, -4, "__fields"); lua_pushinteger(L, i);
// Stack: [mt, buckets, bucket, inst_table, fields] lua_rawgeti(L, -2, i);
if (lua_isnil(L, -1)) { // Stack: [mt, buckets, bucket, inst_table, i, table[i]]
lua_pop(L, 1); return 2;
// Stack: [mt, buckets, bucket, inst_table]
lua_pushvalue(L, 2);
// Stack: [mt, buckets, bucket, inst_table, i]
if (lua_next(L, -2) == 0) {
return 0;
} else {
return 2;
}
} else {
lua_pushvalue(L, 2);
// Stack: [mt, buckets, bucket, inst_table, fields, i]
if (lua_next(L, -2) == 0) {
return 0;
}
// Stack: [mt, buckets, bucket, inst_table, fields, i2, next_fieldname]
lua_pop(L, 1);
// Stack: [mt, buckets, bucket, inst_table, fields, i2]
int i = luaL_checkinteger(L, -1);
lua_rawgeti(L, -3, i);
// Stack: [mt, buckets, bucket, inst_table, fields, i2, value]
return 2;
}
} }
static int Lipairs(lua_State *L) static int Lipairs(lua_State *L)
@ -558,12 +540,14 @@ static int Lnext(lua_State *L)
if (lua_isnil(L, -1)) { if (lua_isnil(L, -1)) {
lua_pop(L, 1); lua_pop(L, 1);
// Stack: [mt, buckets, bucket, inst_table] // Stack: [mt, buckets, bucket, inst_table]
lua_pushvalue(L, 2); int i = lua_isnil(L, 2) ? 1 : lua_tointeger(L, 2)+1;
// Stack: [mt, buckets, bucket, inst_table, k] if (i > (int)info->len) {
if (lua_next(L, -2) == 0) {
return 0; return 0;
} }
// Stack: [mt, buckets, bucket, inst_table, k2, value] lua_pushinteger(L, i);
// Stack: [mt, buckets, bucket, inst_table, k]
lua_rawgeti(L, -2, i);
// Stack: [mt, buckets, bucket, inst_table, k, value]
return 2; return 2;
} else { } else {
lua_pushvalue(L, 2); lua_pushvalue(L, 2);

View File

@ -26,13 +26,13 @@ local function test(description, fn)
io.write(red) io.write(red)
local ok, err = pcall(fn) local ok, err = pcall(fn)
if not ok then if not ok then
io.write(reset..dim.."\r....................................") io.write(reset..dim.."\r.......................................")
io.write(reset.."["..bright..red.."FAILED"..reset.."]\r") io.write(reset.."["..bright..red.."FAILED"..reset.."]\r")
io.write("\r"..description.."\n"..reset) io.write("\r"..description.."\n"..reset)
print(reset..red..(err or "")..reset) print(reset..red..(err or "")..reset)
num_errors = num_errors + 1 num_errors = num_errors + 1
else else
io.write(reset..dim.."\r....................................") io.write(reset..dim.."\r.......................................")
io.write(reset.."["..green.."PASSED"..reset.."]\r") io.write(reset.."["..green.."PASSED"..reset.."]\r")
io.write(description.."\n") io.write(description.."\n")
end end
@ -253,14 +253,40 @@ test("Testing giant immutable table", function()
end) end)
test("Testing tuple iteration", function() test("Testing tuple iteration", function()
local T = immutable(nil) local T = immutable()
local t = T(1,4,9,16) local t = T(1,4,9,nil,16)
local checks = {1,4,9,16} local checks = {1,4,9,nil,16}
local passed = 0
for i,v in ipairs(t) do for i,v in ipairs(t) do
assert(checks[i] == v) assert(checks[i] == v)
checks[i] = nil passed = passed + 1
end end
assert(next(checks) == nil) assert(passed == 5)
passed = 0
for k,v in pairs(t) do
assert(checks[k] == v)
passed = passed + 1
end
assert(passed == 5)
end)
test("Testing table iteration", function()
local Foo = immutable({"x", "y", "z"})
local f = Foo(1,nil,2)
local checks = {1,nil,2}
local passed = 0
for i,v in ipairs(f) do
assert(checks[i] == v)
passed = passed + 1
end
assert(passed == 3)
passed = 0
checks = {x=1,y=nil,z=2}
for k,v in pairs(f) do
assert(checks[k] == v)
passed = passed + 1
end
assert(passed == 3)
end) end)
test("Testing __new", function() test("Testing __new", function()