/*
* immutable.c
* An immutable table library by Bruce Hill. This library returns a single function
* that can be used to declare immutable classes, like so:
*
*     local immutable = require 'immutable'
*     local Foo = immutable({"baz","qux"})
*     local foo = Foo("hello", 99)
*     assert(not pcall(function() foo.x = 'mutable' end))
*     local t = {[foo]="it works"}
*     assert(t[Foo("hello", 99)] == "it works")
*
* Instances are interned: constructing an instance whose field values are
* raw-equal to an existing instance's returns that same object, which is why
* a fresh Foo("hello", 99) finds the table entry above.
* Instances *are* garbage collected.
*/

#include "lua.h"
#include "lauxlib.h"

// The C API changed from 5.1 to 5.2, so these shims help the code compile on >=5.2
// by mapping the 5.1 names used in this file onto their 5.2+ replacements.
#if LUA_VERSION_NUM >= 502
// 5.1's lua_objlen became lua_rawlen in 5.2.
#define lua_objlen(L, i) lua_rawlen(L, i)
// 5.1's lua_equal was replaced by lua_compare(..., LUA_OPEQ).
#define lua_equal(L, i, j) lua_compare(L, i, j, LUA_OPEQ)
// NOTE(review): luaH_getnum/luaH_getint are Lua-internal functions, and neither
// this macro nor lua_equal is referenced in the visible code; possibly dead.
#define luaH_getnum(t, k) luaH_getint(t, k)
// Only the luaL_register(L, NULL, R) form is supported: the library-name
// argument is discarded and R is registered into the table on top of the stack.
#define luaL_register(L, _, R) luaL_setfuncs(L, R, 0)
#endif

static int Lcreate_instance(lua_State *L)
{
    int n_args = lua_gettop(L);
    // arg 1: class table, ...

    lua_getfield(L, 1, "__fields");
    // Stack: [__fields]
    size_t n = lua_objlen(L,-1);
    if ((size_t)n_args-1 != n) {
        lua_pushstring(L, "incorrect number of arguments: expected ");
        lua_pushinteger(L, n);
        lua_pushstring(L, ", but got ");
        lua_pushinteger(L, n_args-1);
        lua_concat(L, 4);
        lua_error(L);
    }
    lua_pop(L,1);
    // Stack: []

    lua_createtable(L, n, 0);
    // Stack [inst]
    
    // Copy in all the values, and simultaneously compute the hash:
    lua_Integer hash = 0;
    for (lua_Integer i=1; i <=(lua_Integer)n; i++) {
        lua_pushvalue(L, i+1);
        // Stack [inst, args[i+1]]
        lua_Integer item_hash;
        switch (lua_type(L, -1)) {
            case LUA_TNIL:
                item_hash = 0;
                break;
            case LUA_TNUMBER:
                item_hash = (lua_Integer)lua_tonumber(L, -1);
                break;
            case LUA_TBOOLEAN:
                item_hash = (lua_Integer)lua_toboolean(L, -1);
                break;
            case LUA_TTABLE:
            case LUA_TFUNCTION:
            case LUA_TUSERDATA:
            case LUA_TTHREAD:
            case LUA_TLIGHTUSERDATA:
                item_hash = (lua_Integer)lua_topointer(L, -1);
                break;
            case LUA_TSTRING:
                {
                    size_t strlen;
                    const char *str = lua_tolstring(L, -1, &strlen);
                    item_hash = *str << 7;
                    for (const char *end = &str[strlen]; str < end; str++)
                        item_hash = (1000003*item_hash) ^ *str;
                    item_hash ^= strlen;
                    break;
                }
            default:
                item_hash = 0;
        }
        hash = (1000003 * hash) ^ item_hash;
        lua_rawseti(L, -2, i);
        // Stack [inst]
    }

    lua_getfield(L, 1, "__instances");

    // Stack: [inst, buckets]
    // Find bucket
    lua_rawgeti(L, -1, hash);
    // Stack: [inst, buckets, bucket]
    if (lua_isnil(L, -1)) {
        // Make a new bucket
        // Stack: [inst, buckets, nil]
        lua_pop(L, 1);
        // Stack: [inst, buckets]
        lua_createtable(L, 1, 0);
        // Stack: [inst, buckets, bucket]
        lua_pushlightuserdata(L, (void*)Lcreate_instance);
        lua_gettable(L, LUA_REGISTRYINDEX);
        // Stack: [inst, buckets, bucket, {'__mode'='k'}]
        lua_setmetatable(L, -2);
        // Stack: [inst, buckets, bucket]
        lua_pushvalue(L, -1);
        // Stack: [inst, buckets, bucket, bucket]
        lua_rawseti(L, -3, hash);
        // Stack: [inst, buckets, bucket]
    }
    // Stack: [inst, buckets, bucket]
    // scan bucket
    lua_pushnil(L);
    while (lua_next(L, -2) != 0) { // for hash_collider_inst, hash_collider in pairs(bucket) do
        // Stack: [inst, buckets, bucket, hash_collider_inst, hash_collider]
        int bucket_item_matches = 1;

        // Perform a full equality check
        lua_pushnil(L);
        while (lua_next(L, -2) != 0) { // for collider_key, collider_value in pairs(hash_collider) do
            // Stack: [inst, buckets, bucket, hash_collider_inst, hash_collider, collider_key, collider_value]
            lua_pushvalue(L, -2);
            // Stack: [inst, buckets, bucket, hash_collider_inst, hash_collider, collider_key, collider_value, collider_key]
            lua_gettable(L, -8); // inst[collider_key]
            // Stack: [inst, buckets, bucket, hash_collider_inst, hash_collider, collider_key, collider_value, inst_value]
            if (!lua_rawequal(L, -1, -2)) {
                // go to next item in the bucket
                bucket_item_matches = 0;
                // Stack: [inst, buckets, bucket, hash_collider_inst, hash_collider, collider_key, collider_value, inst_value]
                lua_pop(L, 4);
                // Stack: [inst, buckets, bucket, hash_collider_inst]
                break;
            } else {
                // Stack: [inst, buckets, bucket, hash_collider_inst, hash_collider, collider_key, collider_value, inst_value]
                lua_pop(L, 2);
                // Stack: [inst, buckets, bucket, hash_collider_inst, hash_collider, collider_key]
            }
        }
        if (bucket_item_matches) {
            // Stack: [inst, buckets, bucket, hash_collider_inst, hash_collider]
            lua_pop(L, 1);
            // Found matching singleton
            return 1;
        }
    }

    // failed to find a singleton
    // Stack: [inst, buckets, bucket]

    lua_Integer* userdata = (lua_Integer*)lua_newuserdata(L, sizeof(lua_Integer));
    *userdata = hash;
    
    // Stack [inst, buckets, bucket, inst_userdata]
    lua_pushvalue(L, 1);
    // Stack [inst, buckets, bucket, inst_userdata, metatable]
    lua_setmetatable(L, -2);
    // Stack [inst, buckets, bucket, inst_userdata]
    
    lua_pushvalue(L, -1);
    // Stack [inst, buckets, bucket, inst_userdata, inst_userdata]
    lua_pushvalue(L, -5);
    // Stack [inst, buckets, bucket, inst_userdata, inst_userdata, inst]
    lua_settable(L, -4); // buckets[inst_userdata] = inst
    // Stack [inst, buckets, bucket, inst_userdata]
    return 1;
}

/*
 * Class:from_table(t) -> instance
 *
 * Builds an instance from t[fieldname] for each declared field, passing the
 * values to the constructor in __fields order (positions 1..#__fields).
 * Keys missing from t become nil fields. Errors if t is not given.
 */
static int Lfrom_table(lua_State *L)
{
    luaL_checkany(L, 2);
    lua_settop(L, 2);                     // [cls, t]

    lua_getfield(L, 1, "__fields");       // [cls, t, fields]
    int fields = lua_gettop(L);
    int n = (int)lua_objlen(L, fields);

    // Room for the function, the class and n values.
    luaL_checkstack(L, n + 2, "too many fields");
    lua_pushcfunction(L, Lcreate_instance);
    lua_pushvalue(L, 1);
    // Iterate by position, not lua_next: argument order must match __fields.
    for (int i = 1; i <= n; i++) {
        lua_rawgeti(L, fields, i);        // fieldname
        lua_gettable(L, 2);               // t[fieldname]
    }
    // Stack: [cls, t, fields, create, cls, t[field_1], ..., t[field_n]]
    lua_call(L, n + 1, 1);
    return 1;
}

/*
 * __len metamethod: #instance is the number of declared fields, i.e. the
 * raw length of the class's __fields list.
 */
static int Llen(lua_State *L)
{
    size_t count;

    lua_getmetatable(L, 1);                       // the class table
    lua_getfield(L, lua_gettop(L), "__fields");
    count = lua_objlen(L, -1);
    lua_pushinteger(L, count);
    return 1;
}

/*
 * __index metamethod for instances: (udata, key) -> value.
 * A declared field name resolves to the value stored for this instance.
 * Any other key is looked up on the class table (methods, "name", ...).
 */
static int Lindex(lua_State *L)
{
    int mt, slot;

    lua_getmetatable(L, 1);
    mt = lua_gettop(L);

    // slot = __indices[key]: the field's position, or nil for non-fields
    lua_getfield(L, mt, "__indices");
    lua_pushvalue(L, 2);
    lua_rawget(L, -2);
    slot = lua_gettop(L);

    if (lua_isnil(L, slot)) {
        // Not a field: defer to the class table.
        lua_pushvalue(L, 2);
        lua_gettable(L, mt);
        return 1;
    }

    // Field: the backing table lives at __instances[hash][udata].
    lua_getfield(L, mt, "__instances");
    lua_rawgeti(L, -1, *((lua_Integer*)lua_touserdata(L, 1)));
    lua_pushvalue(L, 1);
    lua_rawget(L, -2);
    lua_rawgeti(L, -1, lua_tointeger(L, slot));
    return 1;
}

/*
 * __tostring metamethod. Renders an instance as
 *     <name>(field1=value1, field2=value2, ...)
 * where <name> is the class's "name" entry (omitted when nil). Field names and
 * values are converted with the global `tostring` and appear in __fields order.
 *
 * Every stack slot the loop reads is pushed BEFORE luaL_buffinit. A
 * luaL_Buffer may keep its data on the stack, so between buffer operations the
 * stack must be left balanced. luaL_addvalue is the one exception and expects
 * its value directly on top.
 */
static int Ltostring(lua_State *L)
{
    luaL_Buffer b;

    lua_getmetatable(L, 1);
    int mt = lua_gettop(L);

    // inst_table = __instances[hash][udata]
    lua_getfield(L, mt, "__instances");
    lua_rawgeti(L, -1, *((lua_Integer*)lua_touserdata(L, 1)));
    lua_pushvalue(L, 1);
    lua_rawget(L, -2);
    int inst = lua_gettop(L);

    lua_getfield(L, mt, "__fields");
    int fields = lua_gettop(L);
    lua_Integer n = (lua_Integer)lua_objlen(L, fields);

    lua_getfield(L, mt, "name");
    int name = lua_gettop(L);

    luaL_buffinit(L, &b);
    if (!lua_isnil(L, name)) {
        lua_pushvalue(L, name);
        luaL_addvalue(&b);
    }
    luaL_addstring(&b, "(");
    for (lua_Integer i = 1; i <= n; i++) {
        if (i > 1)
            luaL_addstring(&b, ", ");
        // tostring(fieldname)
        lua_getglobal(L, "tostring");
        lua_rawgeti(L, fields, i);
        lua_call(L, 1, 1);
        luaL_addvalue(&b);
        luaL_addstring(&b, "=");
        // tostring(value)
        lua_getglobal(L, "tostring");
        lua_rawgeti(L, inst, i);
        lua_call(L, 1, 1);
        luaL_addvalue(&b);
    }
    luaL_addstring(&b, ")");
    luaL_pushresult(&b);
    return 1;
}

/*
 * Iterator used by __ipairs: (udata, i) -> i+1, value_{i+1}.
 *
 * Walks field positions 1..#__fields strictly in order. Starts at 1 when the
 * control value is nil. Nil-valued fields are still yielded (the position is
 * non-nil, so a generic `for` keeps going). Returns nothing past the last
 * field, which ends the loop.
 */
static int Lnexti(lua_State *L)
{
    lua_Integer prev = lua_isnoneornil(L, 2) ? 0 : lua_tointeger(L, 2);

    lua_getmetatable(L, 1);
    // Stack: [mt]
    lua_getfield(L, -1, "__fields");
    // Stack: [mt, fields]
    lua_Integer n = (lua_Integer)lua_objlen(L, -1);
    if (prev < 0 || prev >= n)
        return 0;
    lua_Integer i = prev + 1;

    lua_getfield(L, -2, "__instances");
    // Stack: [mt, fields, buckets]
    lua_rawgeti(L, -1, *((lua_Integer*)lua_touserdata(L, 1)));
    // Stack: [mt, fields, buckets, bucket]
    lua_pushvalue(L, 1);
    lua_rawget(L, -2);
    // Stack: [mt, fields, buckets, bucket, inst_table]
    lua_pushinteger(L, i);
    lua_rawgeti(L, -2, i);
    // Stack: [..., inst_table, i, inst_table[i]]
    return 2;
}

/*
 * __ipairs metamethod: returns the generic-for triple
 * (Lnexti, instance, nil).
 */
static int Lipairs(lua_State *L)
{
    lua_settop(L, 1);                 // [inst]
    lua_pushcfunction(L, Lnexti);     // [inst, Lnexti]
    lua_insert(L, 1);                 // [Lnexti, inst]
    lua_pushnil(L);                   // [Lnexti, inst, nil]
    return 3;
}

/*
 * Iterator used by __pairs: (udata, fieldname) -> next_fieldname, value.
 *
 * Traverses the class's __indices table (fieldname -> position) with
 * lua_next, so the order is unspecified, as with pairs(). It then returns
 * the field name and this instance's value for it. Previously the position
 * was returned as the second value, so `for k, v in pairs(inst)` saw the
 * index instead of the value. Returns nothing when the traversal is done.
 */
static int Lnext(lua_State *L)
{
    lua_settop(L, 2);                       // 1: udata, 2: control key
    lua_getmetatable(L, 1);                 // 3: mt
    lua_getfield(L, 3, "__indices");        // 4: indices
    lua_pushvalue(L, 2);                    // 5: key
    if (lua_next(L, 4) == 0)
        return 0;
    // 5: next_fieldname, 6: position

    lua_getfield(L, 3, "__instances");      // 7: buckets
    lua_rawgeti(L, 7, *((lua_Integer*)lua_touserdata(L, 1)));  // 8: bucket
    lua_pushvalue(L, 1);
    lua_rawget(L, 8);                       // 9: inst_table

    lua_pushvalue(L, 5);                    // 10: next_fieldname
    lua_rawgeti(L, 9, lua_tointeger(L, 6)); // 11: value
    return 2;
}

/*
 * __pairs metamethod: returns the generic-for triple
 * (Lnext, instance, nil).
 */
static int Lpairs(lua_State *L)
{
    lua_settop(L, 1);                 // [inst]
    lua_pushcfunction(L, Lnext);      // [inst, Lnext]
    lua_insert(L, 1);                 // [Lnext, inst]
    lua_pushnil(L);                   // [Lnext, inst, nil]
    return 3;
}

/*
 * Default entries copied into every class table by Lmake_class. The class
 * table is also the instances' metatable, so the __* entries act as instance
 * metamethods. A methods table passed to immutable() may override them.
 */
static const luaL_Reg R[] = {
    {"__len",      Llen},
    {"__index",    Lindex},
    {"__tostring", Ltostring},
    {"__ipairs",   Lipairs},
    {"__pairs",    Lpairs},
    {"from_table", Lfrom_table},
    {NULL,         NULL},
};

/*
 * immutable(fields [, methods]) -> class
 *
 * fields  : sequence of field names, in constructor-argument order.
 * methods : optional table whose entries are copied into the class,
 *           overriding the defaults in R (e.g. a custom __tostring).
 *           nil is treated as absent.
 *
 * The returned class table carries __fields, __indices (fieldname ->
 * position) and __instances (the interning buckets). It doubles as the
 * metatable of its instances, and its own metatable's __call is
 * Lcreate_instance, so Class(v1, ..., vn) constructs instances.
 */
static int Lmake_class(lua_State *L)
{
    int n_args = lua_gettop(L);

    // Validate up front: lua_next/lua_objlen below require real tables, and
    // calling lua_next on a non-table is undefined behaviour in the C API.
    luaL_checktype(L, 1, LUA_TTABLE);
    int has_methods = n_args >= 2 && !lua_isnil(L, 2);
    if (has_methods)
        luaL_checktype(L, 2, LUA_TTABLE);

    // CLS = {} with the default metamethods (__len, __index, __pairs, ...)
    lua_newtable(L);
    int cls = lua_gettop(L);
    luaL_register(L, NULL, R);

    // Copy user methods over, overwriting defaults if desired.
    if (has_methods) {
        lua_pushnil(L);
        while (lua_next(L, 2) != 0) {
            // Stack: [..., method_name, method_value]
            lua_pushvalue(L, -2);
            lua_insert(L, -2);
            // Stack: [..., method_name, method_name, method_value]
            lua_settable(L, cls);
            // Stack: [..., method_name]
        }
    }

    // CLS.__instances = {} (hash -> weak-keyed bucket)
    lua_newtable(L);
    lua_setfield(L, cls, "__instances");

    // CLS.__fields = arg1
    lua_pushvalue(L, 1);
    lua_setfield(L, cls, "__fields");

    // CLS.__indices = {fieldname = position}
    lua_createtable(L, 0, (int)lua_objlen(L, 1));
    int indices = lua_gettop(L);
    lua_pushnil(L);
    while (lua_next(L, 1) != 0) {
        // Stack: [..., i, fieldname]
        lua_pushvalue(L, -2);
        lua_settable(L, indices);
        // Stack: [..., i]
    }
    lua_setfield(L, cls, "__indices");

    // setmetatable(CLS, {__call = Lcreate_instance})
    lua_createtable(L, 0, 1);
    lua_pushcfunction(L, Lcreate_instance);
    lua_setfield(L, -2, "__call");
    lua_setmetatable(L, cls);
    // Stack top: CLS

    return 1;
}

/*
 * Module entry point for require 'immutable'.
 * Stores the shared {__mode = "k"} bucket metatable in the registry under the
 * light-userdata key (void*)Lcreate_instance, then returns the class factory.
 */
LUALIB_API int luaopen_immutable(lua_State *L)
{
    // Build the value first: {__mode = "k"}
    lua_createtable(L, 0, 1);
    lua_pushstring(L, "k");
    lua_setfield(L, -2, "__mode");

    // registry[(void*)Lcreate_instance] = that table
    lua_pushlightuserdata(L, (void*)Lcreate_instance);
    lua_insert(L, -2);
    lua_settable(L, LUA_REGISTRYINDEX);

    lua_pushcfunction(L, Lmake_class);
    return 1;
}