diff --git a/limmutable.c b/limmutable.c index 5e519d3..f6caeaa 100644 --- a/limmutable.c +++ b/limmutable.c @@ -21,6 +21,10 @@ #define lua_objlen(L, i) lua_rawlen(L, i) #define luaL_register(L, _, R) luaL_setfuncs(L, R, 0) #endif +// Lua 5.3 introduced lua_isinteger, fall back to lua_isnumber +#if LUA_VERSION_NUM < 503 +#define lua_isinteger(L, i) lua_isnumber(L, i) +#endif #if !defined(LUAI_HASHLIMIT) #define LUAI_HASHLIMIT 5 @@ -37,21 +41,79 @@ static int WEAK_VALUE_METATABLE; // {__new=Lcreate_instance, __tostring=function(cls) return cls.name or 'immutable(...)' end} static int SHARED_CLASS_METATABLE; +typedef struct { + unsigned long long hash; + size_t len; +} immutable_info_t; + +static int Lhash(lua_State *L) +{ + lua_Integer hash = 0x9a937c4d; // Seed + lua_Integer n = luaL_checkinteger(L, 1); + for (lua_Integer i=2; i <= lua_gettop(L); i++) { + lua_Integer item_hash; + int type = i > n ? LUA_TNIL : lua_type(L, i); + switch (type) { + case LUA_TNIL: case LUA_TNONE: + // Arbitrarily chosen value + item_hash = 0x97167da9; + break; + case LUA_TNUMBER: + { + // Cast float bits to integer + lua_Number num = lua_tonumber(L, i); + item_hash = *((lua_Integer*)&num); + if (item_hash == 0) { + item_hash = 0x2887c992; + } + break; + } + case LUA_TBOOLEAN: + // Arbitrarily chosen values + item_hash = lua_toboolean(L, i)? 0x82684f71 : 0x88d66f2a; + break; + case LUA_TTABLE: + case LUA_TFUNCTION: + case LUA_TUSERDATA: + case LUA_TTHREAD: + case LUA_TLIGHTUSERDATA: + item_hash = (1000003 * type) ^ (lua_Integer)lua_topointer(L, i); + break; + case LUA_TSTRING: + { + // Algorithm taken from Lua 5.3's implementation + size_t len; + const char *str = lua_tolstring(L, i, &len); + item_hash = len ^ 0xd2e9e9ac; // Arbitrary seed + size_t step = (len >> LUAI_HASHLIMIT) + 1; + for (; len >= step; len -= step) + item_hash ^= ((item_hash<<5) + (item_hash>>2) + (unsigned char)(str[len - 1])); + break; + } + default: + item_hash = 0; + } + hash = (1000003 * hash) ^ item_hash; + } + lua_pushinteger(L, hash); + return 1; +} + static int Lcreate_instance(lua_State *L) { size_t n_args = lua_gettop(L)-1; // arg 1: class table, ... lua_getfield(L, 1, "__fields"); - size_t n = lua_objlen(L, -1); + size_t n = lua_isnil(L, -1) ? n_args : lua_objlen(L, -1); if (n_args > n) { luaL_error(L, "Too many args: expected %d, but got %d", n, n_args); } lua_pop(L, 1); // Compute the hash: - lua_Integer hash = (lua_Integer)lua_topointer(L, 1); // Hash depends on the metatable used in creation + unsigned long long hash = 0x9a937c4d; // Seed for (lua_Integer i=1; i <=(lua_Integer)n; i++) { - lua_Integer item_hash; + unsigned long long item_hash; int type = n > n_args ? LUA_TNIL : lua_type(L, 1+i); switch (type) { case LUA_TNIL: case LUA_TNONE: @@ -63,6 +125,9 @@ static int Lcreate_instance(lua_State *L) // Cast float bits to integer lua_Number num = lua_tonumber(L, 1+i); item_hash = *((lua_Integer*)&num); + if (item_hash == 0) { + item_hash = 0x2887c992; + } break; } case LUA_TBOOLEAN: @@ -74,7 +139,7 @@ static int Lcreate_instance(lua_State *L) case LUA_TUSERDATA: case LUA_TTHREAD: case LUA_TLIGHTUSERDATA: - item_hash = (lua_Integer)lua_topointer(L, 1+i); + item_hash = (1000003 * type) ^ (lua_Integer)lua_topointer(L, i); break; case LUA_TSTRING: { @@ -152,8 +217,9 @@ static int Lcreate_instance(lua_State *L) // Failed to find an existing instance, so create a new one // Stack: [buckets, bucket] - lua_Integer* userdata = (lua_Integer*)lua_newuserdata(L, sizeof(lua_Integer)); - *userdata = hash; + immutable_info_t *userdata = (immutable_info_t*)lua_newuserdata(L, sizeof(immutable_info_t)); + userdata->hash = hash; + userdata->len = n; // Stack [buckets, bucket, inst_userdata] lua_pushvalue(L, 1); @@ -199,29 +265,44 @@ static int Lfrom_table(lua_State *L) lua_pushvalue(L, 1); // Stack: [mt] lua_getfield(L, -1, "__fields"); - int n = lua_objlen(L, -1); - if (! lua_checkstack(L, n)) { - luaL_error(L, "Insufficient stack space!"); + int n; + if (lua_isnil(L, -1)) { + lua_getfield(L, 2, "n"); + if (lua_isnil(L, -1)) { + luaL_error(L, "table needs an 'n' field to track its length"); + } + n = luaL_checkinteger(L, -1); + lua_pop(L, 1); + if (! lua_checkstack(L, n)) { + luaL_error(L, "Insufficient stack space!"); + } + for (int i = 1; i <= n; i++) { + lua_rawgeti(L, 2, i); + } + // Stack: [mt, table[1], table[2], ... table[table.n]] + } else { + n = lua_objlen(L, -1); + if (! lua_checkstack(L, n)) { + luaL_error(L, "Insufficient stack space!"); + } + // Stack: [mt, fields] + lua_pushnil(L); + while (lua_next(L, -2) != 0) { + // Stack: [mt, fields, i, field_i] + lua_gettable(L, 2); + // Stack: [mt, fields, i, table[field_i]] + lua_insert(L, -3); + // Stack: [mt, table[field_i], fields, i] + } + // Stack: [mt, table[field], ..., fields] + lua_pop(L, 1); + // Stack: [mt, table[field], ...] } - // Stack: [mt, fields] - lua_pushnil(L); - int num_args = 0; - while (lua_next(L, -2) != 0) { - // Stack: [mt, fields, i, field_i] - lua_gettable(L, 2); - // Stack: [mt, fields, i, table[field_i]] - lua_insert(L, -3); - // Stack: [mt, table[field_i], fields, i] - num_args++; - } - // Stack: [mt, table[field], ..., fields] - lua_pop(L, 1); - // Stack: [mt, table[field], ...] lua_pushcfunction(L, Lcreate_instance); // Stack: [mt, table[field], ..., create] - lua_insert(L, -(num_args+2)); + lua_insert(L, -(n+2)); // Stack: [create, mt, table[field_1], ...] - lua_call(L, num_args+1, 1); + lua_call(L, n+1, 1); return 1; } @@ -238,57 +319,32 @@ static int Lis_instance(lua_State *L) static int Llen(lua_State *L) { - if (! lua_getmetatable(L, 1)) { - luaL_error(L, "invalid type"); - } - // Stack: [mt] - lua_getfield(L, -1, "__fields"); - // Stack: [mt, fields] - lua_pushinteger(L, lua_objlen(L, -1)); + luaL_checktype(L, 1, LUA_TUSERDATA); + immutable_info_t *info = (immutable_info_t *)lua_touserdata(L, 1); + lua_pushinteger(L, info->len); return 1; } static int Lindex(lua_State *L) { + luaL_checktype(L, 1, LUA_TUSERDATA); if (! lua_getmetatable(L, 1)) { luaL_error(L, "invalid type"); } + immutable_info_t *info = (immutable_info_t*)lua_touserdata(L, 1); + if (! info) { + luaL_error(L, "invalid type"); + } + // Stack: [mt] lua_getfield(L, -1, "__indices"); // Stack: [mt, indices] - lua_pushvalue(L, 2); - // Stack: [mt, indices, k] - lua_gettable(L, -2); - // Stack: [mt, indices, i] - if (! lua_isnil(L, -1)) { // Found the field name - // Stack: [mt, indices, i] - lua_getfield(L, -3, "__instances"); - // Stack: [mt, indices, i, buckets] - lua_Integer* hash_address = (lua_Integer*)lua_touserdata(L, 1); - if (! hash_address) { - luaL_error(L, "invalid type"); - } - lua_rawgeti(L, -1, *hash_address); - // Stack: [mt, indices, i, buckets, bucket] - if (lua_isnil(L, -1)) { - luaL_error(L, "Failed to find hash bucket for hash: %p", (void*)*hash_address); - } - lua_pushvalue(L, 1); - // Stack: [mt, indices, i, buckets, bucket, inst_udata] - lua_rawget(L, -2); - // Stack: [mt, indices, i, buckets, bucket, inst_table] - lua_rawgeti(L, -1, lua_tointeger(L, -4)); - return 1; - } else if (lua_type(L, 2) == LUA_TNUMBER) { - lua_pop(L, 2); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); // Stack: [mt] lua_getfield(L, -1, "__instances"); // Stack: [mt, buckets] - lua_Integer* hash_address = (lua_Integer*)lua_touserdata(L, 1); - if (! hash_address) { - luaL_error(L, "invalid type"); - } - lua_rawgeti(L, -1, *hash_address); + lua_rawgeti(L, -1, info->hash); // Stack: [mt, buckets, bucket] if (lua_isnil(L, -1)) { luaL_error(L, "Failed to find hash bucket"); @@ -297,21 +353,43 @@ static int Lindex(lua_State *L) // Stack: [mt, buckets, bucket, inst_udata] lua_rawget(L, -2); // Stack: [mt, buckets, bucket, inst_table] - lua_rawgeti(L, -1, lua_tointeger(L, 2)); - if (! lua_isnil(L, -1)) { - // Found numeric index - return 1; - } else { - // Fall back to class - // Stack: [mt, buckets, bucket, inst_table, v, k] + int i = luaL_checkinteger(L, 2); + lua_rawgeti(L, -1, i); + return 1; + } + + // Stack: [mt, indices] + lua_pushvalue(L, 2); + // Stack: [mt, indices, k] + lua_gettable(L, -2); + // Stack: [mt, indices, i] + if (lua_isnil(L, -1) && lua_isinteger(L, 2)) { + int i = lua_tointeger(L, 2); + if (1 <= i && i <= (int)info->len) { + // Use the raw value of i + lua_pop(L, 1); lua_pushvalue(L, 2); - // Stack: [mt, key] - lua_gettable(L, -6); - return 1; } + } + if (! lua_isnil(L, -1)) { // Found the field name + // Stack: [mt, indices, i] + lua_getfield(L, -3, "__instances"); + // Stack: [mt, indices, i, buckets] + lua_rawgeti(L, -1, info->hash); + // Stack: [mt, indices, i, buckets, bucket] + if (lua_isnil(L, -1)) { + luaL_error(L, "Failed to find hash bucket for hash: %p", (void*)info->hash); + } + lua_pushvalue(L, 1); + // Stack: [mt, indices, i, buckets, bucket, inst_udata] + lua_rawget(L, -2); + // Stack: [mt, indices, i, buckets, bucket, inst_table] + int i = luaL_checkinteger(L, -4); + lua_rawgeti(L, -1, i); + return 1; } else { // Fall back to class: - // Stack: [mt, indices, i] + // Stack: [mt, indices, nil] lua_pop(L, 2); // Stack: [mt] lua_pushvalue(L, 2); @@ -341,11 +419,11 @@ static int Ltostring(lua_State *L) lua_getfield(L, -1, "__instances"); // Stack: [mt, buckets] - lua_Integer* hash_address = (lua_Integer*)lua_touserdata(L, 1); - if (! hash_address) { + immutable_info_t *info = (immutable_info_t*)lua_touserdata(L, 1); + if (! info) { luaL_error(L, "invalid type"); } - lua_rawgeti(L, -1, *hash_address); + lua_rawgeti(L, -1, info->hash); // Stack: [mt, buckets, bucket] if (lua_isnil(L, -1)) { luaL_error(L, "Failed to find hash bucket"); @@ -360,43 +438,59 @@ static int Ltostring(lua_State *L) int tostring_index = lua_gettop(L); lua_getfield(L, -5, "__fields"); // Stack: [mt, buckets, bucket, inst_table, tostring, fields] - int num_fields = lua_objlen(L, -1); - int fields_index = lua_gettop(L); - - int needs_comma = 0; - for (int i = 1; i <= num_fields; i++) { - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ...] - if (needs_comma) { - luaL_addstring(&b, ", "); - } else { - needs_comma = 1; - } - lua_rawgeti(L, fields_index, i); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ..., fieldname] - if (lua_type(L, -1) == LUA_TNUMBER && lua_tointeger(L, -1) == i) { - lua_pop(L, 1); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ...] - } else { + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + // Stack: [mt, buckets, bucket, inst_table, tostring] + immutable_info_t *info = (immutable_info_t*)lua_touserdata(L, 1); + int n = info->len; + int needs_comma = 0; + for (int i = 1; i <= n; i++) { + // Stack: [mt, buckets, bucket, inst_table, tostring, ???] + if (needs_comma) { + luaL_addstring(&b, ", "); + } else { + needs_comma = 1; + } lua_pushvalue(L, tostring_index); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ..., fieldname, tostring] - lua_insert(L, -2); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ..., tostring, fieldname] + // Stack: [mt, buckets, bucket, inst_table, tostring, ???, tostring] + lua_rawgeti(L, inst_table_index, i); + // Stack: [mt, buckets, bucket, inst_table, tostring, ???, tostring, value] lua_call(L, 1, 1); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ..., field string] + // Stack: [mt, buckets, bucket, inst_table, tostring, ???, value string] luaL_addvalue(&b); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ...] - luaL_addstring(&b, "="); + // Stack: [mt, buckets, bucket, inst_table, tostring, ???] + } + } else { + int num_fields = lua_objlen(L, -1); + int fields_index = lua_gettop(L); + int needs_comma = 0; + for (int i = 1; i <= num_fields; i++) { + // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ???] + if (needs_comma) { + luaL_addstring(&b, ", "); + } else { + needs_comma = 1; + } + lua_pushvalue(L, tostring_index); + // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ???, tostring] + lua_rawgeti(L, fields_index, i); + // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ???, tostring, fieldname] + lua_call(L, 1, 1); + // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ???, field string] + luaL_addvalue(&b); + // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ???] + + luaL_addstring(&b, "="); + + lua_pushvalue(L, tostring_index); + // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ???, tostring] + lua_rawgeti(L, inst_table_index, i); + // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ???, tostring, value] + lua_call(L, 1, 1); + // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ???, value string] + luaL_addvalue(&b); + // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ???] } - lua_rawgeti(L, inst_table_index, i); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ..., value] - lua_pushvalue(L, tostring_index); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ..., value, tostring] - lua_insert(L, -2); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ..., tostring, value] - lua_call(L, 1, 1); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ..., value string] - luaL_addvalue(&b); - // Stack: [mt, buckets, bucket, inst_table, tostring, fields, ...] } luaL_addstring(&b, ")"); luaL_pushresult(&b); @@ -411,11 +505,11 @@ static int Lnexti(lua_State *L) // Stack: [mt] lua_getfield(L, -1, "__instances"); // Stack: [mt, buckets] - lua_Integer* hash_address = (lua_Integer*)lua_touserdata(L, 1); - if (! hash_address) { + immutable_info_t *info = (immutable_info_t*)lua_touserdata(L, 1); + if (! info) { luaL_error(L, "invalid type"); } - lua_rawgeti(L, -1, *hash_address); + lua_rawgeti(L, -1, info->hash); // Stack: [mt, buckets, bucket] if (lua_isnil(L, -1)) { luaL_error(L, "Failed to find hash bucket"); @@ -426,17 +520,30 @@ static int Lnexti(lua_State *L) // Stack: [mt, buckets, bucket, inst_table] lua_getfield(L, -4, "__fields"); // Stack: [mt, buckets, bucket, inst_table, fields] - lua_pushvalue(L, 2); - // Stack: [mt, buckets, bucket, inst_table, fields, i] - if (lua_next(L, -2) == 0) { - return 0; + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + // Stack: [mt, buckets, bucket, inst_table] + lua_pushvalue(L, 2); + // Stack: [mt, buckets, bucket, inst_table, i] + if (lua_next(L, -2) == 0) { + return 0; + } else { + return 2; + } + } else { + lua_pushvalue(L, 2); + // Stack: [mt, buckets, bucket, inst_table, fields, i] + if (lua_next(L, -2) == 0) { + return 0; + } + // Stack: [mt, buckets, bucket, inst_table, fields, i2, next_fieldname] + lua_pop(L, 1); + // Stack: [mt, buckets, bucket, inst_table, fields, i2] + int i = luaL_checkinteger(L, -1); + lua_rawgeti(L, -3, i); + // Stack: [mt, buckets, bucket, inst_table, fields, i2, value] + return 2; } - // Stack: [mt, buckets, bucket, inst_table, fields, i2, next_fieldname] - lua_pop(L, 1); - // Stack: [mt, buckets, bucket, inst_table, fields, i2] - lua_rawgeti(L, -3, lua_tonumber(L, -1)); - // Stack: [mt, buckets, bucket, inst_table, fields, i2, value] - return 2; } static int Lipairs(lua_State *L) @@ -458,11 +565,11 @@ static int Lnext(lua_State *L) // Stack: [mt] lua_getfield(L, -1, "__instances"); // Stack: [mt, buckets] - lua_Integer* hash_address = (lua_Integer*)lua_touserdata(L, 1); - if (! hash_address) { + immutable_info_t *info = (immutable_info_t*)lua_touserdata(L, 1); + if (! info) { luaL_error(L, "invalid type"); } - lua_rawgeti(L, -1, *hash_address); + lua_rawgeti(L, -1, info->hash); // Stack: [mt, buckets, bucket] if (lua_isnil(L, -1)) { luaL_error(L, "Failed to find hash bucket"); @@ -473,25 +580,37 @@ static int Lnext(lua_State *L) // Stack: [mt, buckets, bucket, inst_table] lua_getfield(L, -4, "__indices"); // Stack: [mt, buckets, bucket, inst_table, fields] - lua_pushvalue(L, 2); - // Stack: [mt, buckets, bucket, inst_table, fields, k] - if (lua_next(L, -2) == 0) { - return 0; + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + // Stack: [mt, buckets, bucket, inst_table] + lua_pushvalue(L, 2); + // Stack: [mt, buckets, bucket, inst_table, k] + if (lua_next(L, -2) == 0) { + return 0; + } + // Stack: [mt, buckets, bucket, inst_table, k2, value] + return 2; + } else { + lua_pushvalue(L, 2); + // Stack: [mt, buckets, bucket, inst_table, fields, k] + if (lua_next(L, -2) == 0) { + return 0; + } + // Stack: [mt, buckets, bucket, inst_table, fields, k2, next_i] + lua_gettable(L, -4); + // Stack: [mt, buckets, bucket, inst_table, fields, k2, value] + return 2; } - // Stack: [mt, buckets, bucket, inst_table, fields, k2, next_i] - lua_gettable(L, -4); - // Stack: [mt, buckets, bucket, inst_table, fields, k2, value] - return 2; } static int Lpairs(lua_State *L) { lua_pushcfunction(L, Lnext); - // Stack: [Lnexti] + // Stack: [Lnext] lua_pushvalue(L, 1); - // Stack: [Lnexti, inst_udata] + // Stack: [Lnext, inst_udata] lua_pushnil(L); - // Stack: [Lnexti, inst_udata, nil] + // Stack: [Lnext, inst_udata, nil] return 3; } @@ -577,37 +696,8 @@ static int Lmake_class(lua_State *L) lua_setfield(L, -2, "__indices"); break; } - case LUA_TNUMBER: { - // If no fields were passed in, make them empty (i.e. a singleton) - lua_Integer n = lua_tointeger(L, 1); - if (n < 0) { - luaL_error(L, "immutable table size must be positive"); - } - lua_createtable(L, n, 0); - lua_createtable(L, n, 0); - // Stack: [CLS, __fields, __indices] - for (lua_Integer i = 1; i <= n; i++) { - lua_pushinteger(L, i); - // Stack: [CLS, __fields, __indices, i] - lua_rawseti(L, -2, i); - // Stack: [CLS, __fields, __indices] - lua_pushinteger(L, i); - // Stack: [CLS, __fields, __indices, i] - lua_rawseti(L, -3, i); - // Stack: [CLS, __fields, __indices] - } - // Stack: [CLS, __fields, __indices] - lua_setfield(L, -3, "__indices"); - // Stack: [CLS, __fields] - lua_setfield(L, -2, "__fields"); - break; - } case LUA_TNIL: case LUA_TNONE: { - // If no fields were passed in, make them empty (i.e. a singleton) - lua_createtable(L, 0, 0); - lua_setfield(L, -2, "__fields"); - lua_createtable(L, 0, 0); - lua_setfield(L, -2, "__indices"); + // If no fields were passed in, so leave __fields and __indices empty break; } default: { @@ -659,5 +749,7 @@ LUALIB_API int luaopen_immutable(lua_State *L) lua_settable(L, LUA_REGISTRYINDEX); lua_pushcfunction(L, Lmake_class); + lua_pushcfunction(L, Lhash); + lua_setglobal(L, "extract_hash"); return 1; } diff --git a/tests.lua b/tests.lua index 5cc32f7..c6dec3f 100644 --- a/tests.lua +++ b/tests.lua @@ -215,26 +215,26 @@ if _VERSION == "Lua 5.3" then end test("Testing immutable(n)", function() - local Tup3 = immutable(3, {name="Tuple"}) + local Tup3 = immutable(nil, {name="Tuple"}) assert(tostring(Tup3(1,2,3)) == "Tuple(1, 2, 3)") end) test("Testing tostring(class)", function() - local C1 = immutable(0, {name="MYNAME"}) + local C1 = immutable(nil, {name="MYNAME"}) assert(tostring(C1) == "MYNAME") local C2 = immutable() assert(tostring(C2):match("immutable type: 0x.*")) end) test("Testing tuple tostring", function() - local tup3 = immutable(3) + local tup3 = immutable(nil) assert(tostring(tup3(1,2,3)) == "(1, 2, 3)") assert(tostring(tup3(1,tup3(2,3,4),5)) == "(1, (2, 3, 4), 5)") end) test("Testing giant immutable table", function() local keys = {} - local N = 100000 + local N = 10000 for i=1,N do keys[i] = "key_"..tostring(i) end local T = immutable(keys) local values = {} @@ -247,5 +247,5 @@ if num_errors == 0 then else print(bright..red.."*** "..tostring(num_errors).." test"..(num_errors > 1 and "s" or "").." failed! ***"..reset) io.write(reset) - os.exit(false, true) + error() end