aboutsummaryrefslogtreecommitdiff
path: root/nomsu.lua
diff options
context:
space:
mode:
authorBruce Hill <bitbucket@bruce-hill.com>2017-09-21 13:30:59 -0700
committerBruce Hill <bitbucket@bruce-hill.com>2017-09-21 13:30:59 -0700
commit79ad6b07c1aa5f32b7d8481f2937fa02680c2937 (patch)
tree82cce82a25da24846e691c9370974940f118a06d /nomsu.lua
parentc82e4f34097f7a27d9cb5d2d254ccb6916d63694 (diff)
Optimization and cleanup pass.
Diffstat (limited to 'nomsu.lua')
-rw-r--r--nomsu.lua124
1 files changed, 36 insertions, 88 deletions
diff --git a/nomsu.lua b/nomsu.lua
index 833c929..3e2a891 100644
--- a/nomsu.lua
+++ b/nomsu.lua
@@ -74,6 +74,11 @@ do
self:writeln("Defining rule: " .. tostring(spec))
end
local invocations, arg_names = self:get_invocations(spec)
+ for i = 2, #invocations do
+ if not utils.equivalent(utils.set(arg_names[invocations[1]]), utils.set(arg_names[invocations[i]])) then
+ self:error("Conflicting argument names " .. tostring(utils.repr(invocations[1])) .. " and " .. tostring(utils.repr(invocations[i])) .. " for " .. tostring(utils.repr(spec)))
+ end
+ end
local fn_info = {
fn = fn,
arg_names = arg_names,
@@ -88,7 +93,7 @@ do
end,
get_invocations_from_definition = function(self, def, vars)
if def.type == "String" then
- return self:tree_to_value(def, vars)
+ return self.__class:unescape_string(def.value)
end
if def.type ~= "List" then
self:error("Trying to get invocations from " .. tostring(def.type) .. ", but expected List or String.")
@@ -100,7 +105,7 @@ do
repeat
local item = _list_0[_index_0]
if item.type == "String" then
- insert(invocations, self:tree_to_value(item, vars))
+ insert(invocations, item.value)
_continue_0 = true
break
end
@@ -142,7 +147,6 @@ do
end
local invocations = { }
local arg_names = { }
- local prev_arg_names = nil
for _index_0 = 1, #text do
local _text = text[_index_0]
local invocation = _text:gsub("'", " '"):gsub("%%%S+", "%%"):gsub("%s+", " ")
@@ -157,13 +161,6 @@ do
_arg_names = _accum_0
end
insert(invocations, invocation)
- if prev_arg_names then
- if not utils.equivalent(utils.set(prev_arg_names), utils.set(_arg_names)) then
- self:error("Conflicting argument names " .. tostring(utils.repr(prev_arg_names)) .. " and " .. tostring(utils.repr(_arg_names)) .. " for " .. tostring(utils.repr(text)))
- end
- else
- prev_arg_names = _arg_names
- end
arg_names[invocation] = _arg_names
end
return invocations, arg_names
@@ -173,6 +170,11 @@ do
self:writeln("DEFINING MACRO: " .. tostring(spec) .. tostring(src or ""))
end
local invocations, arg_names = self:get_invocations(spec)
+ for i = 2, #invocations do
+ if not utils.equivalent(utils.set(arg_names[invocations[1]]), utils.set(arg_names[invocations[i]])) then
+ self:error("Conflicting argument names " .. tostring(utils.repr(invocations[1])) .. " and " .. tostring(utils.repr(invocations[i])) .. " for " .. tostring(utils.repr(spec)))
+ end
+ end
local fn_info = {
fn = lua_gen_fn,
arg_names = arg_names,
@@ -185,49 +187,6 @@ do
self.defs[invocation] = fn_info
end
end,
- serialize = function(self, obj)
- local _exp_0 = type(obj)
- if "function" == _exp_0 then
- error("Function serialization is not yet implemented.")
- return "assert(load(" .. utils.repr(string.dump(obj)) .. "))"
- elseif "table" == _exp_0 then
- if utils.is_list(obj) then
- return "{" .. tostring(table.concat((function()
- local _accum_0 = { }
- local _len_0 = 1
- for _index_0 = 1, #obj do
- local i = obj[_index_0]
- _accum_0[_len_0] = self:serialize(i)
- _len_0 = _len_0 + 1
- end
- return _accum_0
- end)(), ", ")) .. "}"
- else
- return "{" .. tostring(table.concat((function()
- local _accum_0 = { }
- local _len_0 = 1
- for k, v in pairs(obj) do
- _accum_0[_len_0] = "[" .. tostring(self:serialize(k)) .. "]= " .. tostring(self:serialize(v))
- _len_0 = _len_0 + 1
- end
- return _accum_0
- end)(), ", ")) .. "}"
- end
- elseif "number" == _exp_0 then
- return utils.repr(obj)
- elseif "string" == _exp_0 then
- return utils.repr(obj)
- else
- return error("Serialization not implemented for: " .. tostring(type(obj)))
- end
- end,
- deserialize = function(self, str)
- local lua_thunk, err = load("return (function(compiler,vars)\n return " .. str .. "\n end)")
- if not lua_thunk then
- error("Failed to compile generated code:\n" .. tostring(str) .. "\n\n" .. tostring(err))
- end
- return (lua_thunk())(self, { })
- end,
parse = function(self, str, filename)
if self.debug then
self:writeln("PARSING:\n" .. tostring(str))
@@ -477,10 +436,7 @@ do
return self.__class:comma_separated_items("compiler:call(", args, ")")
end
elseif "String" == _exp_0 then
- local unescaped = tree.value:gsub("\\(.)", (function(c)
- return STRING_ESCAPES[c] or c
- end))
- return utils.repr(unescaped)
+ return utils.repr(self.__class:unescape_string(tree.value))
elseif "Longstring" == _exp_0 then
local concat_parts = { }
local string_buffer = ""
@@ -751,24 +707,13 @@ do
end
end,
initialize_core = function(self)
- local as_lua_code
- as_lua_code = function(self, str, vars)
- local _exp_0 = str.type
- if "String" == _exp_0 then
- return self:tree_to_value(str, vars)
- elseif "Longstring" == _exp_0 then
- return self:tree_to_value(str, vars)
- else
- return self:tree_to_lua(str)
- end
- end
self:defmacro([[lua block %lua_code]], function(self, vars, kind)
if kind == "Expression" then
error("Expected to be in statement.")
end
local inner_vars = setmetatable({ }, {
__index = function(_, key)
- return "vars[" .. tostring(utils.repr(key)) .. "]"
+ return error("vars[" .. tostring(utils.repr(key)) .. "]")
end
})
return "do\n" .. self:tree_to_value(vars.lua_code, inner_vars) .. "\nend", true
@@ -777,7 +722,7 @@ do
local lua_code = vars.lua_code.value
local inner_vars = setmetatable({ }, {
__index = function(_, key)
- return "vars[" .. tostring(utils.repr(key)) .. "]"
+ return error("vars[" .. tostring(utils.repr(key)) .. "]")
end
})
return self:tree_to_value(vars.lua_code, inner_vars)
@@ -828,26 +773,29 @@ do
})
_base_0.__class = _class_0
local self = _class_0
+ self.unescape_string = function(self, str)
+ return str:gsub("\\(.)", (function(c)
+ return STRING_ESCAPES[c] or c
+ end))
+ end
self.comma_separated_items = function(self, open, items, close)
- return utils.accumulate("\n", function()
- local buffer = open
- local so_far = 0
- for i, item in ipairs(items) do
- if i < #items then
- item = item .. ", "
- end
- if so_far + #item >= 80 and #buffer > 0 then
- coroutine.yield(buffer)
- so_far = so_far - #buffer
- buffer = item
- else
- so_far = so_far + #item
- buffer = buffer .. item
- end
+ local bits = {
+ open
+ }
+ local so_far = 0
+ for i, item in ipairs(items) do
+ if i < #items then
+ item = item .. ", "
+ end
+ insert(bits, item)
+ so_far = so_far + #item
+ if so_far >= 80 then
+ insert(bits, "\n")
+ so_far = 0
end
- buffer = buffer .. close
- return coroutine.yield(buffer)
- end)
+ end
+ insert(bits, close)
+ return table.concat(bits)
end
NomsuCompiler = _class_0
end