From b859e643fc7ec957b0f85d77997fbbfc81709306 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Mon, 18 Dec 2017 16:19:56 -0800 Subject: [PATCH] Optimized utils and cleaned up a bit. --- nomsu.moon | 22 +-- utils.lua | 509 ++++++++++++++++++++++++----------------------------- utils.moon | 178 ------------------- 3 files changed, 242 insertions(+), 467 deletions(-) delete mode 100644 utils.moon diff --git a/nomsu.moon b/nomsu.moon index af66123..ed9b6ca 100755 --- a/nomsu.moon +++ b/nomsu.moon @@ -13,8 +13,8 @@ re = require 're' lpeg = require 'lpeg' -utils = require 'utils' -repr = utils.repr +utils = require 'utils2' +{:repr, :stringify, :min, :max, :equivalent, :set, :is_list, :sum} = utils colors = setmetatable({}, {__index:->""}) colored = setmetatable({}, {__index:(_,color)-> ((msg)-> colors[color]..msg..colors.reset)}) {:insert, :remove, :concat} = table @@ -220,7 +220,7 @@ class NomsuCompiler @debug = false @utils = utils @repr = (...)=> repr(...) - @stringify = (...)=> utils.stringify(...) + @stringify = (...)=> stringify(...) if not parent @initialize_core! @@ -252,10 +252,10 @@ class NomsuCompiler for i=1,#arg_names-1 do for j=i+1,#arg_names if arg_names[i] == arg_names[j] then @error "Duplicate argument in function #{stub}: '#{arg_names[i]}'" if canonical_args - assert utils.equivalent(utils.set(arg_names), canonical_args), "Mismatched args" - else canonical_args = utils.set(arg_names) + assert equivalent(set(arg_names), canonical_args), "Mismatched args" + else canonical_args = set(arg_names) if canonical_escaped_args - assert utils.equivalent(escaped_args, canonical_escaped_args), "Mismatched escaped args" + assert equivalent(escaped_args, canonical_escaped_args), "Mismatched escaped args" else canonical_escaped_args = escaped_args def.escaped_args = escaped_args @@ -591,7 +591,7 @@ end);]])\format(concat(buffer, "\n")) when "number" return repr(value) when "table" - if utils.is_list(value) + if is_list(value) return "[#{concat [@value_to_nomsu(v) for v in *value], ", "}]" else return "(d{#{concat ["#{@value_to_nomsu(k)}=#{@value_to_nomsu(v)}" for k,v in pairs(value)], "; "}})" @@ -801,7 +801,7 @@ end)]])\format(concat(lua_bits, "\n")) x = x\gsub("%s+"," ")\gsub("^%s*","")\gsub("%s*$","") stub = x\gsub("%%%S+","%%")\gsub("\\","") arg_names = [arg for arg in x\gmatch("%%([^%s]*)")] - escaped_args = utils.set [arg for arg in x\gmatch("\\%%([^%s]*)")] + escaped_args = set [arg for arg in x\gmatch("\\%%([^%s]*)")] return stub, arg_names, escaped_args if type(x) != 'table' @error "Invalid type for getting stub: #{type(x)} for:\n#{repr x}" @@ -832,13 +832,13 @@ end)]])\format(concat(lua_bits, "\n")) if msg error_msg ..= "\n" .. (colored.bright colored.yellow colored.onred msg) error_msg ..= "\nCallstack:" - maxlen = utils.max([#c[2] for c in *@callstack when c != "#macro"]) + maxlen = max([#c[2] for c in *@callstack when c != "#macro"]) for i=#@callstack,1,-1 if @callstack[i] != "#macro" line_no = @callstack[i][2] if line_no nums = [tonumber(n) for n in line_no\gmatch(":([0-9]+)")] - line_no = line_no\gsub(":.*$", ":#{utils.sum(nums) - #nums + 1}") + line_no = line_no\gsub(":.*$", ":#{sum(nums) - #nums + 1}") error_msg ..= "\n #{"%-#{maxlen}s"\format line_no}| #{@callstack[i][1]}" error_msg ..= "\n " @callstack = {} @@ -914,7 +914,7 @@ if arg flag <- "-c" / "-i" / "-p" / "-O" / "--help" / "-h" input <- "-" / [^;]+ output <- "-" / [^;]+ - ]], {set: utils.set}) + ]], {:set}) args = concat(arg, ";")..";" args = parser\match(args) or {} if not args or not args.flags or args.flags["--help"] or args.flags["-h"] diff --git a/utils.lua b/utils.lua index 19ba20a..46e82f8 100644 --- a/utils.lua +++ b/utils.lua @@ -1,361 +1,314 @@ -local utils -utils = { - is_list = function(t) + +local function is_list(t) if type(t) ~= 'table' then - return false + return false end local i = 1 for _ in pairs(t) do - if t[i] == nil then - return false - end - i = i + 1 + if t[i] == nil then + return false + end + i = i + 1 end return true - end, - size = function(t) - do - local n = 0 - for _ in pairs(t) do +end + +local function size(t) + local n = 0 + for _ in pairs(t) do n = n + 1 - end - return n end - end, - repr = function(x) - local _exp_0 = type(x) - if 'table' == _exp_0 then - local mt = getmetatable(x) - if mt and mt.__tostring then - return mt.__tostring(x) - elseif utils.is_list(x) then - return "{" .. tostring(table.concat((function() - local _accum_0 = { } - local _len_0 = 1 - for _index_0 = 1, #x do - local i = x[_index_0] - _accum_0[_len_0] = utils.repr(i) - _len_0 = _len_0 + 1 - end - return _accum_0 - end)(), ", ")) .. "}" - else - return "{" .. tostring(table.concat((function() - local _accum_0 = { } - local _len_0 = 1 - for k, v in pairs(x) do - _accum_0[_len_0] = "[" .. tostring(utils.repr(k)) .. "]= " .. tostring(utils.repr(v)) - _len_0 = _len_0 + 1 - end - return _accum_0 - end)(), ", ")) .. "}" - end - elseif 'string' == _exp_0 then - if x == "\n" then - return "'\\n'" - elseif not x:find([["]]) and not x:find("\n") and not x:find("\\") then - return "\"" .. x .. "\"" - elseif not x:find([[']]) and not x:find("\n") and not x:find("\\") then - return "\'" .. x .. "\'" - else - for i = 0, math.huge do - local eq = ("="):rep(i) - if not x:find("%]" .. tostring(eq) .. "%]") and not x:match(".*]" .. tostring(eq) .. "$") then - if x:sub(1, 1) == "\n" then - return "[" .. tostring(eq) .. "[\n" .. x .. "]" .. tostring(eq) .. "]" - else - return "[" .. tostring(eq) .. "[" .. x .. "]" .. tostring(eq) .. "]" + return n +end + +local function repr(x) + local x_type = type(x) + if x_type == 'table' then + local mt = getmetatable(x) + if mt and mt.__tostring then + return mt.__tostring(x) + elseif is_list(x) then + local ret = {} + for i=1,#x do + ret[i] = repr(x[i]) + end + return "{"..table.concat(ret, ", ").."}" + else + local ret = {} + for k, v in pairs(x) do + ret[#ret+1] = "["..repr(k).."]= "..repr(v) + end + return "{"..table.concat(ret, ", ").."}" + end + elseif x_type == 'string' then + if x == "\n" then + return "'\\n'" + elseif not x:find([["]]) and not x:find("\n") and not x:find("\\") then + return "\"" .. x .. "\"" + elseif not x:find([[']]) and not x:find("\n") and not x:find("\\") then + return "\'" .. x .. "\'" + else + for i = 0, math.huge do + local eq = ("="):rep(i) + if not x:find("%]"..eq.."%]") and x:sub(-#eq-1, -1) ~= "]"..eq then + if x:sub(1, 1) == "\n" then + return "["..eq.."[\n"..x.."]"..eq.."]" + else + return "["..eq.."["..x.."]"..eq.."]" + end + end end - end end - end else - return tostring(x) + return tostring(x) end - end, - stringify = function(x) +end + +local function stringify(x) if type(x) == 'string' then - return x + return x else - return utils.repr(x) + return repr(x) end - end, - split = function(str, sep) +end + +local function split(str, sep) if sep == nil then - sep = "%s" + sep = "%s" end - local _accum_0 = { } - local _len_0 = 1 - for chunk in str:gmatch("[^" .. tostring(sep) .. "]+") do - _accum_0[_len_0] = chunk - _len_0 = _len_0 + 1 + local ret = {} + for chunk in str:gmatch("[^"..sep.."]+") do + ret[#ret+1] = chunk end - return _accum_0 - end, - remove_from_list = function(list, item) - for i, list_item in ipairs(list) do - if list_item == item then - table.remove(list, i) - return - end + return ret +end + +local function remove_from_list(list, item) + for i=1,#list do + if list[i] == item then + table.remove(list, i) + return + end end - end, - accumulate = function(glue, co) +end + +local function accumulate(glue, co) if co == nil then - glue, co = "", glue + glue, co = "", glue end local bits = { } for bit in coroutine.wrap(co) do - table.insert(bits, bit) + bits[#bits+1] = bit end return table.concat(bits, glue) - end, - range = function(start, stop, step) - if stop == nil then - start, stop, step = 1, start, 1 - elseif step == nil then - step = 1 - elseif step == 0 then - error("Range step cannot be zero.") - end - return setmetatable({ - start = start, - stop = stop, - step = step - }, { - __ipairs = function(self) - local iter - iter = function(self, i) - if i <= (self.stop - self.start) / self.step then - return i + 1, self.start + i * self.step - end - end - return iter, self, 0 - end, - __index = function(self, i) - if type(i) ~= "Number" then - return nil - end - if i % 1 ~= 0 then - return nil - end - if i <= 0 or i - 1 > (self.stop - self.start) / self.step then - return nil - end - return self.start + (i - 1) * self.step - end, - __len = function(self) - local len = (self.stop - self.start) / self.step - if len < 0 then - len = 0 - end - return len - end - }) - end, - nth_to_last = function(list, n) +end + +local function nth_to_last(list, n) return list[#list - n + 1] - end, - keys = function(t) - local _accum_0 = { } - local _len_0 = 1 +end + +local function keys(t) + local ret = {} for k in pairs(t) do - _accum_0[_len_0] = k - _len_0 = _len_0 + 1 + ret[#ret+1] = k end - return _accum_0 - end, - values = function(t) - local _accum_0 = { } - local _len_0 = 1 - for _, v in pairs(t) do - _accum_0[_len_0] = v - _len_0 = _len_0 + 1 + return ret +end + +local function values(t) + local ret = {} + for _,v in pairs(t) do + ret[#ret+1] = v end - return _accum_0 - end, - set = function(list) - local _tbl_0 = { } - for _index_0 = 1, #list do - local i = list[_index_0] - _tbl_0[i] = true + return ret +end + +local function set(list) + local ret = {} + for i=1,#list do + ret[list[i]] = true end - return _tbl_0 - end, - sum = function(t) - do - local tot = 0 - for _, x in pairs(t) do - tot = tot + x - end - return tot + return ret +end + +local function sum(t) + local tot = 0 + for i=1,#t do + tot = tot + t[i] end - end, - product = function(t) - do - local prod = 1 - for _, x in pairs(t) do - prod = prod * x - end - return prod + return tot +end + +local function product(t) + if #t > 5 and 0 < t[1] and t[1] < 1 then + local log, log_prod = math.log, 0 + for i=1,#t do + log_prod = log_prod + log(t[i]) + end + return math.exp(log_prod) + else + local prod = 1 + for i=1,#t do + prod = prod * t[i] + end + return prod end - end, - all = function(t) - for _, x in pairs(t) do - if not x then - return false - end +end + +local function all(t) + for i=1,#t do + if not t[i] then return false end end return true - end, - any = function(t) - for _, x in pairs(t) do - if x then - return true - end +end + +local function any(t) + for i=1,#t do + if t[i] then return true end end return false - end, - min = function(list, keyFn) +end + +local function min(list, keyFn) if keyFn == nil then - keyFn = (function(x) - return x - end) + keyFn = (function(x) + return x + end) end - assert(utils.is_list(list), "min() expects to be operating on a list") - do - local best = list[1] - if type(keyFn) == 'table' then + if type(keyFn) == 'table' then local keyTable = keyFn keyFn = function(k) - return keyTable[k] + return keyTable[k] end - end - for i = 2, #list do - if keyFn(list[i]) < keyFn(best) then - best = list[i] - end - end - return best end - end, - max = function(list, keyFn) + local best, bestKey = list[1], keyFn(best) + for i = 2, #list do + local key = keyFn(list[i]) + if key < bestKey then + best, bestKey = list[i], key + end + end + return best +end + +local function max(list, keyFn) if keyFn == nil then - keyFn = (function(x) - return x - end) + keyFn = (function(x) + return x + end) end - assert(utils.is_list(list), "min() expects to be operating on a list") - do - local best = list[1] - if type(keyFn) == 'table' then + if type(keyFn) == 'table' then local keyTable = keyFn keyFn = function(k) - return keyTable[k] + return keyTable[k] end - end - for i = 2, #list do - if keyFn(list[i]) > keyFn(best) then - best = list[i] - end - end - return best end - end, - sort = function(list, keyFn, reverse) + local best, bestKey = list[1], keyFn(best) + for i = 2, #list do + local key = keyFn(list[i]) + if key > bestKey then + best, bestKey = list[i], key + end + end + return best +end + +local function sort(list, keyFn, reverse) if keyFn == nil then - keyFn = (function(x) - return x - end) + keyFn = (function(x) + return x + end) end if reverse == nil then - reverse = false + reverse = false end - assert(utils.is_list(list), "min() expects to be operating on a list") if type(keyFn) == 'table' then - local keyTable = keyFn - keyFn = function(k) - return keyTable[k] - end - end - local comparison - if reverse then - comparison = (function(x, y) - return (keyFn(x) > keyFn(y)) - end) - else - comparison = (function(x, y) - return (keyFn(x) < keyFn(y)) - end) - end - return table.sort(list, comparison) - end, - equivalent = function(x, y, depth) - if depth == nil then - depth = 1 + local keyTable = keyFn + keyFn = function(k) + return keyTable[k] + end end + table.sort(list, reverse + and (function(x,y) return keyFn(x) > keyFn(y) end) + or (function(x,y) return keyFn(x) < keyFn(y) end)) + return list +end + +local function equivalent(x, y, depth) + depth = depth or 1 if x == y then - return true + return true end if type(x) ~= type(y) then - return false + return false end if type(x) ~= 'table' then - return false + return false end if depth == 0 then - return false - end - for k, v in pairs(x) do - if not (utils.equivalent(y[k], v, depth - 1)) then return false - end + end + local checked = {} + for k, v in pairs(x) do + if not equivalent(y[k], v, depth - 1) then + return false + end + checked[k] = true end for k, v in pairs(y) do - if not (utils.equivalent(x[k], v, depth - 1)) then - return false - end + if not checked[k] and not equivalent(x[k], v, depth - 1) then + return false + end end return true - end, - key_for = function(t, value) +end + +local function key_for(t, value) for k, v in pairs(t) do - if v == value then - return k - end + if v == value then + return k + end end return nil - end, - clamp = function(x, min, max) +end + +local function clamp(x, min, max) if x < min then - return min + return min elseif x > max then - return max + return max else - return x + return x end - end, - mix = function(min, max, amount) +end + +local function mix(min, max, amount) return (1 - amount) * min + amount * max - end, - sign = function(x) +end + +local function sign(x) if x == 0 then - return 0 + return 0 elseif x < 0 then - return -1 + return -1 else - return 1 + return 1 end - end, - round = function(x, increment) +end + +local function round(x, increment) if increment == nil then - increment = 1 + increment = 1 end if x >= 0 then - return math.floor(x / increment + .5) * increment + return math.floor(x / increment + .5) * increment else - return math.ceil(x / increment - .5) * increment + return math.ceil(x / increment - .5) * increment end - end -} -return utils +end + +return {is_list=is_list, size=size, repr=repr, stringify=stringify, split=split, + remove_from_list=remove_from_list, accumulate=accumulate, nth_to_last=nth_to_last, + keys=keys, values=values, set=set, sum=sum, product=product, all=all, any=any, + min=min, max=max, sort=sort, equivalent=equivalent, key_for=key_for, clamp=clamp, + mix=mix, round=round} diff --git a/utils.moon b/utils.moon deleted file mode 100644 index 7d1e39e..0000000 --- a/utils.moon +++ /dev/null @@ -1,178 +0,0 @@ -local utils -utils = { - is_list: (t)-> - if type(t) != 'table' then return false - i = 1 - for _ in pairs(t) - if t[i] == nil then return false - i += 1 - return true - - size: (t)-> - with n = 0 - for _ in pairs(t) do n += 1 - - repr: (x)-> - switch type(x) - when 'table' - mt = getmetatable(x) - if mt and mt.__tostring - mt.__tostring(x) - elseif utils.is_list x - "{#{table.concat([utils.repr(i) for i in *x], ", ")}}" - else - "{#{table.concat(["[#{utils.repr(k)}]= #{utils.repr(v)}" for k,v in pairs x], ", ")}}" - when 'string' - if x == "\n" - return "'\\n'" - elseif not x\find[["]] and not x\find"\n" and not x\find"\\" - "\""..x.."\"" - elseif not x\find[[']] and not x\find"\n" and not x\find"\\" - "\'"..x.."\'" - else - for i=0,math.huge - eq = ("=")\rep(i) - if not x\find"%]#{eq}%]" and not x\match(".*]#{eq}$") - -- Stupid bullshit add an extra newline because lua discards first one if it exists - if x\sub(1,1) == "\n" - return "[#{eq}[\n"..x.."]#{eq}]" - else - return "[#{eq}["..x.."]#{eq}]" - else - tostring(x) - - stringify: (x)-> - if type(x) == 'string' then x - else utils.repr(x) - - split: (str, sep="%s")-> - [chunk for chunk in str\gmatch("[^#{sep}]+")] - - remove_from_list: (list, item)-> - for i,list_item in ipairs(list) - if list_item == item - table.remove list, i - return - - accumulate: (glue, co)-> - if co == nil then glue, co = "", glue - bits = {} - for bit in coroutine.wrap(co) - table.insert(bits, bit) - return table.concat(bits, glue) - - range: (start,stop,step)-> - if stop == nil - start,stop,step = 1,start,1 - elseif step == nil - step = 1 - elseif step == 0 - error("Range step cannot be zero.") - return setmetatable({:start,:stop,:step}, { - __ipairs: => - iter = (i)=> - if i <= (@stop-@start)/@step - return i+1, @start+i*@step - return iter, @, 0 - __index: (i)=> - if type(i) != "Number" then return nil - if i % 1 != 0 then return nil - if i <= 0 or i-1 > (@stop-@start)/@step then return nil - return @start + (i-1)*@step - __len: => - len = (@stop-@start)/@step - if len < 0 then len = 0 - return len - - }) - - nth_to_last: (list, n) -> list[#list-n+1] - - keys: (t)-> [k for k in pairs(t)] - values: (t)-> [v for _,v in pairs(t)] - set: (list)-> {i,true for i in *list} - - sum: (t)-> - with tot = 0 - for _,x in pairs(t) do tot += x - - product: (t)-> - with prod = 1 - for _,x in pairs(t) do prod *= x - - all: (t)-> - for _,x in pairs t - if not x then return false - return true - - any: (t)-> - for _,x in pairs t - if x then return true - return false - - min: (list, keyFn=((x)->x))-> - assert utils.is_list(list), "min() expects to be operating on a list" - with best = list[1] - if type(keyFn) == 'table' - keyTable = keyFn - keyFn = (k)->keyTable[k] - for i=2,#list - if keyFn(list[i]) < keyFn(best) - best = list[i] - - max: (list, keyFn=((x)->x))-> - assert utils.is_list(list), "min() expects to be operating on a list" - with best = list[1] - if type(keyFn) == 'table' - keyTable = keyFn - keyFn = (k)->keyTable[k] - for i=2,#list - if keyFn(list[i]) > keyFn(best) - best = list[i] - - sort: (list, keyFn=((x)->x), reverse=false)-> - assert utils.is_list(list), "min() expects to be operating on a list" - if type(keyFn) == 'table' - keyTable = keyFn - keyFn = (k)->keyTable[k] - comparison = if reverse then ((x,y)->(keyFn(x)>keyFn(y))) else ((x,y)->(keyFn(x) - if x == y then return true - if type(x) != type(y) then return false - if type(x) != 'table' then return false - if depth == 0 then return false - for k,v in pairs(x) - unless utils.equivalent(y[k], v, depth-1) - return false - for k,v in pairs(y) - unless utils.equivalent(x[k], v, depth-1) - return false - return true - - key_for: (t, value)-> - for k,v in pairs(t) - if v == value - return k - return nil - - clamp: (x, min,max)-> - if x < min then min - elseif x > max then max - else x - - mix: (min,max, amount)-> - (1-amount)*min + amount*max - - sign: (x)-> - if x == 0 then 0 - elseif x < 0 then -1 - else 1 - - round: (x, increment=1)-> - if x >= 0 then math.floor(x/increment + .5)*increment - else math.ceil(x/increment - .5)*increment - -} -return utils