-- A collection of helper utility functions
--
local match, gmatch, gsub = string.match, string.gmatch, string.gsub
local function is_list(t)
    -- Report whether `t` is a table whose keys are exactly the integers 1..n (n may be 0).
    -- Anything that is not a table is never a list.
    if type(t) ~= 'table' then return false end
    -- Visit every entry once and use the running entry count as the index to probe:
    -- if the table holds n entries and t[1]..t[n] are all present, no other key can exist.
    local count = 0
    for _ in pairs(t) do
        count = count + 1
        if t[count] == nil then return false end
    end
    return true
end

local function size(t)
    -- Count every entry that pairs() visits in `t` (array-like and hash-like keys alike).
    local count = 0
    for _key in pairs(t) do
        count = 1 + count
    end
    return count
end

-- Words that can never be written as a bare `name = value` key in a table constructor.
-- `goto` is only reserved from Lua 5.2 on, but quoting it is valid in every version.
local repr_reserved_words = {
    ["and"]=true, ["break"]=true, ["do"]=true, ["else"]=true, ["elseif"]=true,
    ["end"]=true, ["false"]=true, ["for"]=true, ["function"]=true, ["goto"]=true,
    ["if"]=true, ["in"]=true, ["local"]=true, ["nil"]=true, ["not"]=true, ["or"]=true,
    ["repeat"]=true, ["return"]=true, ["then"]=true, ["true"]=true, ["until"]=true,
    ["while"]=true,
}

local function repr(x, depth)
    -- Create a string representation of the object that is close to the lua code that will
    -- reproduce the object (similar to Python's "repr" function)
    --   x:     the value to represent
    --   depth: how many levels of nested tables to expand (defaults to 10); anything
    --          nested deeper than that is rendered as "..."
    -- Returns a string. Tables that have a metatable are rendered with tostring().
    depth = depth or 10
    if depth == 0 then return "..." end
    depth = depth - 1
    local x_type = type(x)
    if x_type == 'table' then
        if getmetatable(x) then
            -- If this object has a weird metatable, then don't pretend like it's a regular table
            return tostring(x)
        else
            local ret = {}
            local i = 1
            for k, v in pairs(x) do
                if k == i then
                    -- Keys that arrive as 1, 2, 3, ... are written positionally
                    ret[#ret+1] = repr(x[i], depth)
                    i = i + 1
                elseif type(k) == 'string' and match(k, "^[_a-zA-Z][_a-zA-Z0-9]*$")
                        and not repr_reserved_words[k] then
                    -- The pattern is anchored at both ends so that only keys which are
                    -- identifiers in their entirety get the bare `name= value` form
                    ret[#ret+1] = k.."= "..repr(v,depth)
                else
                    ret[#ret+1] = "["..repr(k,depth).."]= "..repr(v,depth)
                end
            end
            return "{"..table.concat(ret, ", ").."}"
        end
    elseif x_type == 'string' then
        -- Backslashes must be escaped first so the escapes added below are not doubled
        local escaped = gsub(x, "\\", "\\\\")
        escaped = gsub(escaped, "\n", "\\n")
        escaped = gsub(escaped, '"', '\\"')
        -- Always emit three digits: a shorter decimal escape would absorb any digit
        -- characters that follow it when the result is read back as lua code
        escaped = gsub(escaped, "[%c%z]", function(c)
            return string.format("\\%03d", c:byte())
        end)
        return '"'..escaped..'"'
    else
        return tostring(x)
    end
end

local function stringify(x)
    -- Turn any value into a string: strings are returned untouched (no added quotes or
    -- escaping), every other value is rendered through repr().
    if type(x) ~= 'string' then
        return repr(x)
    end
    return x
end

local function split(str, sep)
    -- Break `str` into the maximal runs of characters that are not separators.
    --   str: the string to split
    --   sep: contents of a lua pattern character class naming the separator characters
    --        (defaults to "%s", i.e. whitespace). It is spliced into "[^...]+" unescaped.
    -- Returns a list of the non-empty pieces, in order; consecutive separators do not
    -- produce empty pieces.
    if sep == nil then sep = "%s" end
    local pattern = "[^"..sep.."]+"
    local pieces = {}
    local next_piece = gmatch(str, pattern)
    local piece = next_piece()
    while piece do
        pieces[#pieces+1] = piece
        piece = next_piece()
    end
    return pieces
end

local function remove_from_list(list, item)
    -- Delete every element equal to `item` from `list`, in place, keeping the order of
    -- the remaining elements. Nothing is returned.
    local length = #list
    -- `kept` is the number of survivors written so far, and so the slot of the last one
    local kept = 0
    for read = 1, length do
        local value = list[read]
        if value ~= item then
            kept = kept + 1
            list[kept] = value
        end
    end
    -- Erase the stale entries left behind at the end of the list
    for stale = kept + 1, length do
        list[stale] = nil
    end
end

local function accumulate(co)
    -- Run the function `co` as a coroutine and gather the first value of each yield
    -- into a list, stopping at the first nil (which is also what the end of the
    -- coroutine looks like to the loop).
    local collected, count = {}, 0
    for value in coroutine.wrap(co) do
        count = count + 1
        collected[count] = value
    end
    return collected
end

local function nth_to_last(list, n)
    -- Fetch the element `n` places from the end of `list`: n = 1 is the last element,
    -- n = 2 the one before it, and so on.
    local index = #list - n + 1
    return list[index]
end

local function keys(t)
    -- Collect every key of `t` into a new list, in the order pairs() visits them.
    local found, count = {}, 0
    for key in pairs(t) do
        count = count + 1
        found[count] = key
    end
    return found
end

local function values(t)
    -- Collect every value of `t` into a new list, in the order pairs() visits them.
    local found, count = {}, 0
    for _, value in pairs(t) do
        count = count + 1
        found[count] = value
    end
    return found
end

local function set(list)
    -- Build a lookup table from a list: each element of `list` becomes a key whose
    -- value is true, so membership can be tested with `result[item]`.
    local members = {}
    for index = 1, #list do
        local item = list[index]
        members[item] = true
    end
    return members
end

local function deduplicate(list)
    -- Remove repeated elements from `list`, in place, keeping the first occurrence of
    -- each element and the relative order of the survivors. Nothing is returned.
    local seen, kept, total = {}, 0, 0
    for i, item in ipairs(list) do
        total = i
        if not seen[item] then
            seen[item] = true
            kept = kept + 1
            list[kept] = item
        end
    end
    -- The survivors now occupy 1..kept; clear the leftover slots so the list actually
    -- shrinks instead of ending with stale copies of its old tail
    for i = kept + 1, total do
        list[i] = nil
    end
end

local function sum(t)
    -- Add up the elements t[1]..t[#t]; an empty list sums to 0.
    local total = 0
    for index = 1, #t do
        local term = t[index]
        total = total + term
    end
    return total
end

local function product(t)
    -- Multiply together the elements t[1]..t[#t]; an empty list has product 1.
    -- For lists of more than 5 elements whose first element lies strictly between 0 and
    -- 1, the product is computed as exp(sum of logs) instead of by repeated
    -- multiplication (presumably to be gentler on long chains of small factors).
    local n = #t
    if n > 5 and 0 < t[1] and t[1] < 1 then
        local log, log_prod, all_positive = math.log, 0, true
        for i=1,n do
            local factor = t[i]
            -- log() is only meaningful for positive factors; a zero or negative one
            -- anywhere in the list means the answer must come from plain multiplication
            if not (factor > 0) then
                all_positive = false
                break
            end
            log_prod = log_prod + log(factor)
        end
        if all_positive then
            return math.exp(log_prod)
        end
    end
    local prod = 1
    for i=1,n do
        prod = prod * t[i]
    end
    return prod
end

local function all(t)
    -- True when every element t[1]..t[#t] is truthy (vacuously true for an empty list).
    for index = 1, #t do
        local item = t[index]
        if not item then
            return false
        end
    end
    return true
end

local function any(t)
    -- True when at least one element t[1]..t[#t] is truthy (false for an empty list).
    for index = 1, #t do
        local item = t[index]
        if item then
            return true
        end
    end
    return false
end

local function min(list, keyFn)
    -- Return the element of `list` with the smallest key.
    --   list:  a list of elements
    --   keyFn: how to obtain an element's key. nil compares the elements themselves,
    --          a function is called with the element, and a table is indexed by it.
    -- Returns the first element with the smallest key, or nil when the list is empty.
    if keyFn == nil then
        keyFn = function(x)
            return x
        end
    elseif type(keyFn) == 'table' then
        local keyTable = keyFn
        keyFn = function(k)
            return keyTable[k]
        end
    end
    local best = list[1]
    -- Nothing to compare: bail out before handing nil to the caller's key function
    if best == nil then
        return nil
    end
    local bestKey = keyFn(best)
    for i = 2, #list do
        local key = keyFn(list[i])
        -- Strict comparison, so on ties the earliest element wins
        if key < bestKey then
            best, bestKey = list[i], key
        end
    end
    return best
end

local function max(list, keyFn)
    -- Return the element of `list` with the largest key.
    --   list:  a list of elements
    --   keyFn: how to obtain an element's key. nil compares the elements themselves,
    --          a function is called with the element, and a table is indexed by it.
    -- Returns the first element with the largest key, or nil when the list is empty.
    if keyFn == nil then
        keyFn = function(x)
            return x
        end
    elseif type(keyFn) == 'table' then
        local keyTable = keyFn
        keyFn = function(k)
            return keyTable[k]
        end
    end
    local best = list[1]
    -- Nothing to compare: bail out before handing nil to the caller's key function
    if best == nil then
        return nil
    end
    local bestKey = keyFn(best)
    for i = 2, #list do
        local key = keyFn(list[i])
        -- Strict comparison, so on ties the earliest element wins
        if key > bestKey then
            best, bestKey = list[i], key
        end
    end
    return best
end

local function sort(list, keyFn, reverse)
    -- Sort `list` in place by key and return that same list.
    --   list:    the list to sort
    --   keyFn:   how to obtain an element's key. nil compares the elements themselves,
    --            a function is called with the element, and a table is indexed by it.
    --   reverse: when truthy the order is descending, otherwise ascending
    local key_of
    if type(keyFn) == 'table' then
        key_of = function(item)
            return keyFn[item]
        end
    elseif keyFn == nil then
        key_of = function(item)
            return item
        end
    else
        key_of = keyFn
    end

    local comparator
    if reverse then
        comparator = function(a, b)
            return key_of(a) > key_of(b)
        end
    else
        comparator = function(a, b)
            return key_of(a) < key_of(b)
        end
    end
    table.sort(list, comparator)
    return list
end

local function equivalent(x, y, depth)
    -- Deep structural comparison of two values.
    --   x, y:  the values to compare
    --   depth: current nesting level, used by the recursive calls (defaults to 0)
    -- Values that are rawequal are equivalent. Otherwise both must be tables sharing the
    -- same metatable, with the same set of keys and pairwise equivalent values.
    -- Raises an error once the nesting level reaches 99.
    if rawequal(x, y) then return true end
    local x_type = type(x)
    -- Different types can never match, and distinct non-table values of the same type
    -- have already failed the rawequal test above
    if x_type ~= type(y) or x_type ~= 'table' then return false end
    if getmetatable(x) ~= getmetatable(y) then return false end
    depth = depth or 0
    if depth >= 99 then
        error("Exceeded maximum comparison depth")
    end
    local deeper = depth + 1
    -- Keys of x already compared, so the second pass only looks at keys unique to y
    local visited = {}
    for key, x_value in pairs(x) do
        if not equivalent(y[key], x_value, deeper) then
            return false
        end
        visited[key] = true
    end
    for key, y_value in pairs(y) do
        if not visited[key] then
            if not equivalent(x[key], y_value, deeper) then
                return false
            end
        end
    end
    return true
end

local function key_for(t, value)
    -- Reverse lookup: return a key of `t` whose value equals `value`, or nil when there
    -- is none. If several keys qualify, whichever pairs() reaches first is returned.
    for key, candidate in pairs(t) do
        if candidate == value then return key end
    end
    return nil
end

local function clamp(x, min, max)
    -- Restrict `x` to the range [min, max]: values below `min` become `min`, values
    -- above `max` become `max`, and everything else is returned unchanged.
    if x < min then return min end
    if x > max then return max end
    return x
end

local function mix(min, max, amount)
    -- Linear interpolation between `min` and `max`: amount = 0 gives `min`, amount = 1
    -- gives `max`, and values in between blend the two proportionally.
    local from_min = (1 - amount) * min
    local from_max = amount * max
    return from_min + from_max
end

local function sign(x)
    -- Sign of a number: -1 for negative values, 0 for zero, and 1 otherwise.
    if x < 0 then
        return -1
    end
    if x == 0 then
        return 0
    end
    return 1
end

local function round(x, increment)
    -- Round `x` to the nearest multiple of `increment` (defaults to 1). Values exactly
    -- halfway between two multiples are rounded away from zero.
    if increment == nil then increment = 1 end
    local steps = x / increment
    if x >= 0 then
        return math.floor(steps + .5) * increment
    end
    return math.ceil(steps - .5) * increment
end

-- Public interface of the module: every helper defined above, keyed by its own name.
return {is_list=is_list, size=size, repr=repr, stringify=stringify, split=split,
    remove_from_list=remove_from_list, accumulate=accumulate, nth_to_last=nth_to_last,
    keys=keys, values=values, set=set, deduplicate=deduplicate, sum=sum, product=product,
    all=all, any=any, min=min, max=max, sort=sort, equivalent=equivalent, key_for=key_for,
    clamp=clamp, mix=mix, sign=sign, round=round}