Skip to content

Treesitter polling instead of parsing on each snippet trigger #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions lua/luasnip-latex-snippets/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ local no_backslash = utils.no_backslash

local M = {}

local default_opts = {
local opts = {
use_treesitter = false,
allow_on_markdown = true,
markdown_use_polling = false,
}

M.setup = function(opts)
opts = vim.tbl_deep_extend("force", default_opts, opts or {})
M.setup = function(override_opts)
override_opts = override_opts or {}
opts = vim.tbl_deep_extend("force", opts, override_opts)

local augroup = vim.api.nvim_create_augroup("luasnip-latex-snippets", {})
vim.api.nvim_create_autocmd("FileType", {
Expand All @@ -30,7 +32,7 @@ M.setup = function(opts)
group = augroup,
once = true,
callback = function()
M.setup_markdown()
M.setup_markdown(augroup)
end,
})
end
Expand Down Expand Up @@ -91,12 +93,52 @@ M.setup_tex = function(is_math, not_math)
})
end

M.setup_markdown = function()
---@param augroup integer
M.setup_markdown = function(augroup)
local ls = require("luasnip")

local is_math = utils.with_opts(utils.is_math, true)
local not_math = utils.with_opts(utils.not_math, true)

if opts.markdown_use_polling then
local p = require("luasnip-latex-snippets.util.ts_utils").polling
vim.api.nvim_create_autocmd(
{"BufEnter", "BufWinEnter"},
{
pattern = "*.md",
group = augroup,
callback = function(args)
if not p.is_buf_tracked(args.buf) then
p.init_buf(args.buf)
end
end
}
)

vim.api.nvim_create_autocmd(
"BufDelete",
{
pattern = "*.md",
group = augroup,
callback = function(args)
if p.is_buf_tracked(args.buf) then
p.deinit_buf(args.buf)
end
end
}
)

is_math = function()
local buf = vim.api.nvim_get_current_buf()
return p.tracked_bufs[buf].in_math
end

not_math = function()
local buf = vim.api.nvim_get_current_buf()
return p.tracked_bufs[buf].in_text
end
end

local math_i = require("luasnip-latex-snippets/math_i").retrieve(is_math)
ls.add_snippets("markdown", math_i, { default_priority = 0 })

Expand Down
133 changes: 131 additions & 2 deletions lua/luasnip-latex-snippets/util/ts_utils.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
local M = {}
local M = { polling = {} }

local MATH_NODES = {
displayed_equation = true,
Expand All @@ -22,7 +22,7 @@ local function get_node_at_cursor()
return
end

local root_tree = parser:parse({ row, col, row, col })[1]
local root_tree = parser:parse()[1]
local root = root_tree and root_tree:root()
if not root then
return
Expand Down Expand Up @@ -65,4 +65,133 @@ function M.in_mathzone()
return false
end

--- @alias node_stack string[]
--- @class tracked_buffer
--- @field timer uv_timer_t
--- @field parser LanguageTree
--- @field tree TSTree
--- @field node_stack node_stack
--- @field in_math boolean
--- @field in_text boolean
--- @type table<buffer, tracked_buffer>
local tracked_bufs = {}

local function get_node()
local row, col = unpack(vim.api.nvim_win_get_cursor(0))
row = row - 1
col = col - 1
local buf = vim.api.nvim_get_current_buf()
local root_node = tracked_bufs[buf].tree:root()
return root_node:named_descendant_for_range(row, col, row, col)
end

--- @return node_stack ns A list of types of nodes. The first element is the type of deepest node.
local function get_node_stack()
local node = get_node()
local stack = {}
while node do
table.insert(stack, node:type())
node = node:parent()
end
return stack
end

--- @param node_stack node_stack
--- @return boolean
--- Check for a given node stack whether we are in math
local function stack_is_in_math(node_stack)
for _, node_type in ipairs(node_stack) do
if TEXT_NODES[node_type] then
return false
end
if MATH_NODES[node_type] then
return true
end
end
return false
end

--- @param node_stack node_stack
--- @param check_parent boolean
--- @return boolean
--- Check for a given node stack whether we are in text
local function stack_is_in_text(node_stack, check_parent)
for _, node_type in ipairs(node_stack) do
if TEXT_NODES[node_type] and not check_parent then
return true
end
if MATH_NODES[node_type] then
return false
end
end
return true
end

--- @param buf buffer buffer handle
--- Create a callback function which will track whether user is currently
--- editing math
--- @return function
local function mk_callback(buf)
return function()
local current_buf = vim.api.nvim_get_current_buf()
if buf ~= current_buf then
return
end

local parser = tracked_bufs[buf].parser
tracked_bufs[buf].tree = parser:parse()[1]
local node_stack = get_node_stack()
tracked_bufs[buf].node_stack = node_stack
tracked_bufs[buf].in_math = stack_is_in_math(node_stack)
tracked_bufs[buf].in_text = stack_is_in_text(node_stack, true)
end
end

--- @param buf buffer Buffer handle
--- Start tracking a buffer.
--- Assert that given buffer was not tracked before.
local function init_buf(buf)
assert(not tracked_bufs[buf], "Attempt to initialize already tracked buffer")
local success, parser = pcall(vim.treesitter.get_parser, buf, "latex")
if not success then
vim.notify("Could not load latex treesitter parser. Is it installed?\n TSInstall latex", vim.log.levels.ERROR)
error(parser, 2)
end
local timer = vim.loop.new_timer()
timer:start(0, 100, vim.schedule_wrap(mk_callback(buf)))

tracked_bufs[buf] = {
timer = timer,
parser = parser,
tree = nil,
node_stack = {},
in_math = false,
in_text = true,
}
end

--- @param buf buffer Buffer handle
--- Stop tracking a buffer.
--- Assert that given buffer was tracked before.
local function deinit_buf(buf)
assert(tracked_bufs[buf], "Attempt to deinitialize buffer that is not tracked")
tracked_bufs[buf].timer:stop()
tracked_bufs[buf].timer:close()
tracked_bufs[buf].parser:destroy()
tracked_bufs[buf] = nil
end

--- @param buf buffer
--- @return boolean b Whether buffer is tracked by the plugin
local function is_buf_tracked(buf)
return tracked_bufs[buf] ~= nil
end

M.polling = {
init_buf = init_buf,
deinit_buf = deinit_buf,
is_buf_tracked = is_buf_tracked,
tracked_bufs = tracked_bufs,
}

return M