diff --git a/limmutable.c b/limmutable.c index 8a7b782..7a5f704 100644 --- a/limmutable.c +++ b/limmutable.c @@ -484,6 +484,10 @@ static int Lnexti(lua_State *L) if (! info) { luaL_error(L, "invalid type"); } + int i = lua_isnil(L, 2) ? 1 : lua_tointeger(L, 2)+1; + if (i > (int)info->len) { + return 0; + } lua_rawgeti(L, -1, info->hash); // Stack: [mt, buckets, bucket] if (lua_isnil(L, -1)) { @@ -493,32 +497,10 @@ static int Lnexti(lua_State *L) // Stack: [mt, buckets, bucket, inst_udata] lua_rawget(L, -2); // Stack: [mt, buckets, bucket, inst_table] - lua_getfield(L, -4, "__fields"); - // Stack: [mt, buckets, bucket, inst_table, fields] - if (lua_isnil(L, -1)) { - lua_pop(L, 1); - // Stack: [mt, buckets, bucket, inst_table] - lua_pushvalue(L, 2); - // Stack: [mt, buckets, bucket, inst_table, i] - if (lua_next(L, -2) == 0) { - return 0; - } else { - return 2; - } - } else { - lua_pushvalue(L, 2); - // Stack: [mt, buckets, bucket, inst_table, fields, i] - if (lua_next(L, -2) == 0) { - return 0; - } - // Stack: [mt, buckets, bucket, inst_table, fields, i2, next_fieldname] - lua_pop(L, 1); - // Stack: [mt, buckets, bucket, inst_table, fields, i2] - int i = luaL_checkinteger(L, -1); - lua_rawgeti(L, -3, i); - // Stack: [mt, buckets, bucket, inst_table, fields, i2, value] - return 2; - } + lua_pushinteger(L, i); + lua_rawgeti(L, -2, i); + // Stack: [mt, buckets, bucket, inst_table, i, table[i]] + return 2; } static int Lipairs(lua_State *L) @@ -558,12 +540,14 @@ static int Lnext(lua_State *L) if (lua_isnil(L, -1)) { lua_pop(L, 1); // Stack: [mt, buckets, bucket, inst_table] - lua_pushvalue(L, 2); - // Stack: [mt, buckets, bucket, inst_table, k] - if (lua_next(L, -2) == 0) { + int i = lua_isnil(L, 2) ? 1 : lua_tointeger(L, 2)+1; + if (i > (int)info->len) { return 0; } - // Stack: [mt, buckets, bucket, inst_table, k2, value] + lua_pushinteger(L, i); + // Stack: [mt, buckets, bucket, inst_table, k] + lua_rawgeti(L, -2, i); + // Stack: [mt, buckets, bucket, inst_table, k, value] return 2; } else { lua_pushvalue(L, 2); diff --git a/tests.lua b/tests.lua index f201cb3..7f5d2b9 100644 --- a/tests.lua +++ b/tests.lua @@ -26,13 +26,13 @@ local function test(description, fn) io.write(red) local ok, err = pcall(fn) if not ok then - io.write(reset..dim.."\r....................................") + io.write(reset..dim.."\r.......................................") io.write(reset.."["..bright..red.."FAILED"..reset.."]\r") io.write("\r"..description.."\n"..reset) print(reset..red..(err or "")..reset) num_errors = num_errors + 1 else - io.write(reset..dim.."\r....................................") + io.write(reset..dim.."\r.......................................") io.write(reset.."["..green.."PASSED"..reset.."]\r") io.write(description.."\n") end @@ -253,14 +253,40 @@ test("Testing giant immutable table", function() end) test("Testing tuple iteration", function() - local T = immutable(nil) - local t = T(1,4,9,16) - local checks = {1,4,9,16} + local T = immutable() + local t = T(1,4,9,nil,16) + local checks = {1,4,9,nil,16} + local passed = 0 for i,v in ipairs(t) do assert(checks[i] == v) - checks[i] = nil + passed = passed + 1 end - assert(next(checks) == nil) + assert(passed == 5) + passed = 0 + for k,v in pairs(t) do + assert(checks[k] == v) + passed = passed + 1 + end + assert(passed == 5) +end) + +test("Testing table iteration", function() + local Foo = immutable({"x", "y", "z"}) + local f = Foo(1,nil,2) + local checks = {1,nil,2} + local passed = 0 + for i,v in ipairs(f) do + assert(checks[i] == v) + passed = passed + 1 + end + assert(passed == 3) + passed = 0 + checks = {x=1,y=nil,z=2} + for k,v in pairs(f) do + assert(checks[k] == v) + passed = passed + 1 + end + assert(passed == 3) end) test("Testing __new", function()