From b859e643fc7ec957b0f85d77997fbbfc81709306 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Mon, 18 Dec 2017 16:19:56 -0800 Subject: Optimized utils and cleaned up a bit. --- utils.lua | 523 ++++++++++++++++++++++++++++---------------------------------- 1 file changed, 238 insertions(+), 285 deletions(-) (limited to 'utils.lua') diff --git a/utils.lua b/utils.lua index 19ba20a..46e82f8 100644 --- a/utils.lua +++ b/utils.lua @@ -1,361 +1,314 @@ -local utils -utils = { - is_list = function(t) + +local function is_list(t) if type(t) ~= 'table' then - return false + return false end local i = 1 for _ in pairs(t) do - if t[i] == nil then - return false - end - i = i + 1 + if t[i] == nil then + return false + end + i = i + 1 end return true - end, - size = function(t) - do - local n = 0 - for _ in pairs(t) do +end + +local function size(t) + local n = 0 + for _ in pairs(t) do n = n + 1 - end - return n - end - end, - repr = function(x) - local _exp_0 = type(x) - if 'table' == _exp_0 then - local mt = getmetatable(x) - if mt and mt.__tostring then - return mt.__tostring(x) - elseif utils.is_list(x) then - return "{" .. tostring(table.concat((function() - local _accum_0 = { } - local _len_0 = 1 - for _index_0 = 1, #x do - local i = x[_index_0] - _accum_0[_len_0] = utils.repr(i) - _len_0 = _len_0 + 1 - end - return _accum_0 - end)(), ", ")) .. "}" - else - return "{" .. tostring(table.concat((function() - local _accum_0 = { } - local _len_0 = 1 - for k, v in pairs(x) do - _accum_0[_len_0] = "[" .. tostring(utils.repr(k)) .. "]= " .. tostring(utils.repr(v)) - _len_0 = _len_0 + 1 - end - return _accum_0 - end)(), ", ")) .. "}" - end - elseif 'string' == _exp_0 then - if x == "\n" then - return "'\\n'" - elseif not x:find([["]]) and not x:find("\n") and not x:find("\\") then - return "\"" .. x .. "\"" - elseif not x:find([[']]) and not x:find("\n") and not x:find("\\") then - return "\'" .. x .. "\'" - else - for i = 0, math.huge do - local eq = ("="):rep(i) - if not x:find("%]" .. tostring(eq) .. "%]") and not x:match(".*]" .. tostring(eq) .. "$") then - if x:sub(1, 1) == "\n" then - return "[" .. tostring(eq) .. "[\n" .. x .. "]" .. tostring(eq) .. "]" - else - return "[" .. tostring(eq) .. "[" .. x .. "]" .. tostring(eq) .. "]" + end + return n +end + +local function repr(x) + local x_type = type(x) + if x_type == 'table' then + local mt = getmetatable(x) + if mt and mt.__tostring then + return mt.__tostring(x) + elseif is_list(x) then + local ret = {} + for i=1,#x do + ret[i] = repr(x[i]) + end + return "{"..table.concat(ret, ", ").."}" + else + local ret = {} + for k, v in pairs(x) do + ret[#ret+1] = "["..repr(k).."]= "..repr(v) + end + return "{"..table.concat(ret, ", ").."}" + end + elseif x_type == 'string' then + if x == "\n" then + return "'\\n'" + elseif not x:find([["]]) and not x:find("\n") and not x:find("\\") then + return "\"" .. x .. "\"" + elseif not x:find([[']]) and not x:find("\n") and not x:find("\\") then + return "\'" .. x .. "\'" + else + for i = 0, math.huge do + local eq = ("="):rep(i) + if not x:find("%]"..eq.."%]") and x:sub(-#eq-1, -1) ~= "]"..eq then + if x:sub(1, 1) == "\n" then + return "["..eq.."[\n"..x.."]"..eq.."]" + else + return "["..eq.."["..x.."]"..eq.."]" + end + end end - end end - end else - return tostring(x) + return tostring(x) end - end, - stringify = function(x) +end + +local function stringify(x) if type(x) == 'string' then - return x + return x else - return utils.repr(x) + return repr(x) end - end, - split = function(str, sep) +end + +local function split(str, sep) if sep == nil then - sep = "%s" - end - local _accum_0 = { } - local _len_0 = 1 - for chunk in str:gmatch("[^" .. tostring(sep) .. "]+") do - _accum_0[_len_0] = chunk - _len_0 = _len_0 + 1 - end - return _accum_0 - end, - remove_from_list = function(list, item) - for i, list_item in ipairs(list) do - if list_item == item then - table.remove(list, i) - return - end - end - end, - accumulate = function(glue, co) + sep = "%s" + end + local ret = {} + for chunk in str:gmatch("[^"..sep.."]+") do + ret[#ret+1] = chunk + end + return ret +end + +local function remove_from_list(list, item) + for i=1,#list do + if list[i] == item then + table.remove(list, i) + return + end + end +end + +local function accumulate(glue, co) if co == nil then - glue, co = "", glue + glue, co = "", glue end local bits = { } for bit in coroutine.wrap(co) do - table.insert(bits, bit) + bits[#bits+1] = bit end return table.concat(bits, glue) - end, - range = function(start, stop, step) - if stop == nil then - start, stop, step = 1, start, 1 - elseif step == nil then - step = 1 - elseif step == 0 then - error("Range step cannot be zero.") - end - return setmetatable({ - start = start, - stop = stop, - step = step - }, { - __ipairs = function(self) - local iter - iter = function(self, i) - if i <= (self.stop - self.start) / self.step then - return i + 1, self.start + i * self.step - end - end - return iter, self, 0 - end, - __index = function(self, i) - if type(i) ~= "Number" then - return nil - end - if i % 1 ~= 0 then - return nil - end - if i <= 0 or i - 1 > (self.stop - self.start) / self.step then - return nil - end - return self.start + (i - 1) * self.step - end, - __len = function(self) - local len = (self.stop - self.start) / self.step - if len < 0 then - len = 0 - end - return len - end - }) - end, - nth_to_last = function(list, n) +end + +local function nth_to_last(list, n) return list[#list - n + 1] - end, - keys = function(t) - local _accum_0 = { } - local _len_0 = 1 +end + +local function keys(t) + local ret = {} for k in pairs(t) do - _accum_0[_len_0] = k - _len_0 = _len_0 + 1 - end - return _accum_0 - end, - values = function(t) - local _accum_0 = { } - local _len_0 = 1 - for _, v in pairs(t) do - _accum_0[_len_0] = v - _len_0 = _len_0 + 1 - end - return _accum_0 - end, - set = function(list) - local _tbl_0 = { } - for _index_0 = 1, #list do - local i = list[_index_0] - _tbl_0[i] = true - end - return _tbl_0 - end, - sum = function(t) - do - local tot = 0 - for _, x in pairs(t) do - tot = tot + x - end - return tot - end - end, - product = function(t) - do - local prod = 1 - for _, x in pairs(t) do - prod = prod * x - end - return prod - end - end, - all = function(t) - for _, x in pairs(t) do - if not x then - return false - end + ret[#ret+1] = k + end + return ret +end + +local function values(t) + local ret = {} + for _,v in pairs(t) do + ret[#ret+1] = v + end + return ret +end + +local function set(list) + local ret = {} + for i=1,#list do + ret[list[i]] = true + end + return ret +end + +local function sum(t) + local tot = 0 + for i=1,#t do + tot = tot + t[i] + end + return tot +end + +local function product(t) + if #t > 5 and 0 < t[1] and t[1] < 1 then + local log, log_prod = math.log, 0 + for i=1,#t do + log_prod = log_prod + log(t[i]) + end + return math.exp(log_prod) + else + local prod = 1 + for i=1,#t do + prod = prod * t[i] + end + return prod + end +end + +local function all(t) + for i=1,#t do + if not t[i] then return false end end return true - end, - any = function(t) - for _, x in pairs(t) do - if x then - return true - end +end + +local function any(t) + for i=1,#t do + if t[i] then return true end end return false - end, - min = function(list, keyFn) +end + +local function min(list, keyFn) if keyFn == nil then - keyFn = (function(x) - return x - end) + keyFn = (function(x) + return x + end) end - assert(utils.is_list(list), "min() expects to be operating on a list") - do - local best = list[1] - if type(keyFn) == 'table' then + if type(keyFn) == 'table' then local keyTable = keyFn keyFn = function(k) - return keyTable[k] + return keyTable[k] end - end - for i = 2, #list do - if keyFn(list[i]) < keyFn(best) then - best = list[i] + end + local best, bestKey = list[1], keyFn(best) + for i = 2, #list do + local key = keyFn(list[i]) + if key < bestKey then + best, bestKey = list[i], key end - end - return best end - end, - max = function(list, keyFn) + return best +end + +local function max(list, keyFn) if keyFn == nil then - keyFn = (function(x) - return x - end) + keyFn = (function(x) + return x + end) end - assert(utils.is_list(list), "min() expects to be operating on a list") - do - local best = list[1] - if type(keyFn) == 'table' then + if type(keyFn) == 'table' then local keyTable = keyFn keyFn = function(k) - return keyTable[k] + return keyTable[k] end - end - for i = 2, #list do - if keyFn(list[i]) > keyFn(best) then - best = list[i] + end + local best, bestKey = list[1], keyFn(best) + for i = 2, #list do + local key = keyFn(list[i]) + if key > bestKey then + best, bestKey = list[i], key end - end - return best end - end, - sort = function(list, keyFn, reverse) + return best +end + +local function sort(list, keyFn, reverse) if keyFn == nil then - keyFn = (function(x) - return x - end) + keyFn = (function(x) + return x + end) end if reverse == nil then - reverse = false + reverse = false end - assert(utils.is_list(list), "min() expects to be operating on a list") if type(keyFn) == 'table' then - local keyTable = keyFn - keyFn = function(k) - return keyTable[k] - end - end - local comparison - if reverse then - comparison = (function(x, y) - return (keyFn(x) > keyFn(y)) - end) - else - comparison = (function(x, y) - return (keyFn(x) < keyFn(y)) - end) - end - return table.sort(list, comparison) - end, - equivalent = function(x, y, depth) - if depth == nil then - depth = 1 + local keyTable = keyFn + keyFn = function(k) + return keyTable[k] + end end + table.sort(list, reverse + and (function(x,y) return keyFn(x) > keyFn(y) end) + or (function(x,y) return keyFn(x) < keyFn(y) end)) + return list +end + +local function equivalent(x, y, depth) + depth = depth or 1 if x == y then - return true + return true end if type(x) ~= type(y) then - return false + return false end if type(x) ~= 'table' then - return false + return false end if depth == 0 then - return false + return false end + local checked = {} for k, v in pairs(x) do - if not (utils.equivalent(y[k], v, depth - 1)) then - return false - end + if not equivalent(y[k], v, depth - 1) then + return false + end + checked[k] = true end for k, v in pairs(y) do - if not (utils.equivalent(x[k], v, depth - 1)) then - return false - end + if not checked[k] and not equivalent(x[k], v, depth - 1) then + return false + end end return true - end, - key_for = function(t, value) +end + +local function key_for(t, value) for k, v in pairs(t) do - if v == value then - return k - end + if v == value then + return k + end end return nil - end, - clamp = function(x, min, max) +end + +local function clamp(x, min, max) if x < min then - return min + return min elseif x > max then - return max + return max else - return x + return x end - end, - mix = function(min, max, amount) +end + +local function mix(min, max, amount) return (1 - amount) * min + amount * max - end, - sign = function(x) +end + +local function sign(x) if x == 0 then - return 0 + return 0 elseif x < 0 then - return -1 + return -1 else - return 1 + return 1 end - end, - round = function(x, increment) +end + +local function round(x, increment) if increment == nil then - increment = 1 + increment = 1 end if x >= 0 then - return math.floor(x / increment + .5) * increment + return math.floor(x / increment + .5) * increment else - return math.ceil(x / increment - .5) * increment - end - end -} -return utils + return math.ceil(x / increment - .5) * increment + end +end + +return {is_list=is_list, size=size, repr=repr, stringify=stringify, split=split, + remove_from_list=remove_from_list, accumulate=accumulate, nth_to_last=nth_to_last, + keys=keys, values=values, set=set, sum=sum, product=product, all=all, any=any, + min=min, max=max, sort=sort, equivalent=equivalent, key_for=key_for, clamp=clamp, + mix=mix, round=round} -- cgit v1.2.3