Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ PRs are welcomed for other git host websites!
2. New Features:
- Windows (+wsl2) support.
- Blame support.
- Full [git protocols](https://git-scm.com/book/en/v2/Git-on-the-Server-The-Protocols) support.
- Respect ssh host alias.
- Add `?plain=1` for markdown files.
- Fully customizable [git url](https://git-scm.com/book/en/v2/Git-on-the-Server-The-Protocols) generation.
3. Improvements:
- Use git `stderr` output as error message.
- Async child process IO via coroutine and `uv.spawn`.
Expand Down
236 changes: 71 additions & 165 deletions lua/gitlinker/commons/async.lua
Original file line number Diff line number Diff line change
@@ -1,189 +1,95 @@
---@diagnostic disable
--- Small async library for Neovim plugins

local function validate_callback(func, callback)
if callback and type(callback) ~= 'function' then
local info = debug.getinfo(func, 'nS')
error(
string.format(
'Callback is not a function for %s, got: %s',
info.short_src .. ':' .. info.linedefined,
vim.inspect(callback)
)
)
end
-- Copied from: <https://github.com/neovim/neovim/issues/19624#issuecomment-1202405058>

local co = coroutine

local async_thread = {
threads = {},
}

local function threadtostring(x)
if jit then
return string.format('%p', x)
else
return tostring(x):match('thread: (.*)')
end
end

-- Coroutine.running() was changed between Lua 5.1 and 5.2:
-- - 5.1: Returns the running coroutine, or nil when called by the main thread.
-- - 5.2: Returns the running coroutine plus a boolean, true when the running
-- coroutine is the main one.
--
-- For LuaJIT, 5.2 behaviour is enabled with LUAJIT_ENABLE_LUA52COMPAT
--
-- We need to handle both.
local _main_co_or_nil = coroutine.running()

--- Executes a future with a callback when it is done
--- @param func function
--- @param callback function?
--- @param ... any
local function run(func, callback, ...)
validate_callback(func, callback)
function async_thread.running()
local thread = co.running()
local id = threadtostring(thread)
return async_thread.threads[id]
end

local co = coroutine.create(func)
function async_thread.create(fn)
local thread = co.create(fn)
local id = threadtostring(thread)
async_thread.threads[id] = true
return thread
end

local function step(...)
local ret = { coroutine.resume(co, ...) }
local stat = ret[1]
function async_thread.finished(x)
if co.status(x) == 'dead' then
local id = threadtostring(x)
async_thread.threads[id] = nil
return true
end
return false
end

if not stat then
local err = ret[2] --[[@as string]]
error(
string.format('The coroutine failed with this message: %s\n%s', err, debug.traceback(co))
)
end
--- @param async_fn function
--- @param ... any
local function execute(async_fn, ...)
local thread = async_thread.create(async_fn)

local function step(...)
local ret = { co.resume(thread, ...) }
local stat, err_or_fn, nargs = unpack(ret)

if not stat then
error(string.format("The coroutine failed with this message: %s\n%s",
err_or_fn, debug.traceback(thread)))
end

if coroutine.status(co) == 'dead' then
if callback then
callback(unpack(ret, 2, table.maxn(ret)))
if async_thread.finished(thread) then
return
end
return
end

--- @type integer, fun(...: any): any
local nargs, fn = ret[2], ret[3]
assert(type(fn) == 'function', 'type error :: expected func')
assert(type(err_or_fn) == "function", "The 1st parameter must be a lua function")

--- @type any[]
local args = { unpack(ret, 4, table.maxn(ret)) }
args[nargs] = step
fn(unpack(args, 1, nargs))
end
local ret_fn = err_or_fn
local args = { select(4, unpack(ret)) }
args[nargs] = step
ret_fn(unpack(args, 1, nargs --[[@as integer]]))
end

step(...)
step(...)
end

local M = {}

---Use this to create a function which executes in an async context but
---called from a non-async context. Inherently this cannot return anything
---since it is non-blocking
--- @generic F: function
--- @param argc integer
--- @param func async F
--- @return F
function M.sync(argc, func)
return function(...)
assert(not coroutine.running())
local callback = select(argc + 1, ...)
run(func, callback, unpack({ ... }, 1, argc))
end
end

--- @param argc integer
--- @param func function
--- @param ... any
--- @return any ...
function M.wait(argc, func, ...)
-- Always run the wrapped functions in xpcall and re-raise the error in the
-- coroutine. This makes pcall work as normal.
local function pfunc(...)
local args = { ... } --- @type any[]
local cb = args[argc]
args[argc] = function(...)
cb(true, ...)
end
xpcall(func, function(err)
cb(false, err, debug.traceback())
end, unpack(args, 1, argc))
end

local ret = { coroutine.yield(argc, pfunc, ...) }

local ok = ret[1]
if not ok then
--- @type string, string
local err, traceback = ret[2], ret[3]
error(string.format('Wrapped function failed: %s\n%s', err, traceback))
end

return unpack(ret, 2, table.maxn(ret))
end

function M.run(func, ...)
return run(func, nil, ...)
end

--- Creates an async function with a callback style function.
--- @param argc integer
--- @param func function
--- @return function
function M.wrap(argc, func)
assert(type(argc) == 'number')
assert(type(func) == 'function')
return function(...)
return M.wait(argc, func, ...)
end
end

--- @generic R
--- @param n integer Mx number of jobs to run concurrently
--- @param thunks (fun(cb: function): R)[]
--- @param interrupt_check fun()?
--- @param callback fun(ret: R[][])
M.join = M.wrap(4, function(n, thunks, interrupt_check, callback)
n = math.min(n, #thunks)

local ret = {} --- @type any[][]

if #thunks == 0 then
callback(ret)
return
end

local remaining = { unpack(thunks, n + 1) }
local to_go = #thunks

local function cb(...)
ret[#ret + 1] = { ... }
to_go = to_go - 1
if to_go == 0 then
callback(ret)
elseif not interrupt_check or not interrupt_check() then
if #remaining > 0 then
local next_thunk = table.remove(remaining, 1)
next_thunk(cb)
M.wrap = function(func, argc)
return function(...)
if not async_thread.running() then
return func(...)
end
end
end

for i = 1, n do
thunks[i](cb)
end
end)
return co.yield(func, argc, ...)
end
end

---Useful for partially applying arguments to an async function
--- @param fn function
--- @param ... any
--- @param func function
--- @return function
function M.curry(fn, ...)
--- @type integer, any[]
local nargs, args = select('#', ...), { ... }

return function(...)
local other = { ... }
for i = 1, select('#', ...) do
args[nargs + i] = other[i]
end
return fn(unpack(args))
end
M.void = function(func)
return function(...)
if async_thread.running() then
return func(...)
end
execute(func, ...)
end
end

if vim.schedule then
--- An async function that when called will yield to the Neovim scheduler to be
--- able to call the API.
M.schedule = M.wrap(1, vim.schedule)
end
M.schedule = M.wrap(vim.schedule, 1)

return M
2 changes: 1 addition & 1 deletion lua/gitlinker/commons/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
26.0.0
27.0.0
58 changes: 58 additions & 0 deletions lua/gitlinker/git.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local logging = require("gitlinker.commons.logging")
local spawn = require("gitlinker.commons.spawn")
local uv = require("gitlinker.commons.uv")
local str = require("gitlinker.commons.str")

local async = require("gitlinker.async")

Expand Down Expand Up @@ -358,12 +359,65 @@
return result.stdout[1]
end

-- --- NOTE: async functions for `vim.ui.select`.
-- local _run_select = async.wrap(function(remotes, callback)
-- vim.ui.select(remotes, {
-- prompt = "Detect multiple git remotes:",
-- format_item = function(item)
-- return item
-- end,
-- }, function(choice)
-- callback(choice)
-- end)
-- end, 2)
--
-- -- wrap the select function.
-- --- @package
-- --- @type fun(remotes:string[]):string?
-- local function run_select(remotes)
-- return _run_select(remotes)
-- end

--- @package
--- @param remotes string[]
--- @param cwd string?
--- @return string?
local function _select_remotes(remotes, cwd)
local logger = logging.get("gitlinker")

Check warning on line 386 in lua/gitlinker/git.lua

View check run for this annotation

Codecov / codecov/patch

lua/gitlinker/git.lua#L386

Added line #L386 was not covered by tests
-- local result = run_select(remotes)

local formatted_remotes = { "Please select remote index:" }
for i, remote in ipairs(remotes) do
local remote_url = get_remote_url(remote, cwd)
table.insert(formatted_remotes, string.format("%d. %s (%s)", i, remote, remote_url))

Check warning on line 392 in lua/gitlinker/git.lua

View check run for this annotation

Codecov / codecov/patch

lua/gitlinker/git.lua#L389-L392

Added lines #L389 - L392 were not covered by tests
end

async.scheduler()
local result = vim.fn.inputlist(formatted_remotes)

Check warning on line 396 in lua/gitlinker/git.lua

View check run for this annotation

Codecov / codecov/patch

lua/gitlinker/git.lua#L395-L396

Added lines #L395 - L396 were not covered by tests
-- logger:debug(string.format("inputlist:%s(%s)", vim.inspect(result), vim.inspect(type(result))))

if type(result) ~= "number" or result < 1 or result > #remotes then
logger:err("fatal: user cancelled multiple git remotes")
return nil

Check warning on line 401 in lua/gitlinker/git.lua

View check run for this annotation

Codecov / codecov/patch

lua/gitlinker/git.lua#L399-L401

Added lines #L399 - L401 were not covered by tests
end

for i, remote in ipairs(remotes) do
if result == i then
return remote

Check warning on line 406 in lua/gitlinker/git.lua

View check run for this annotation

Codecov / codecov/patch

lua/gitlinker/git.lua#L404-L406

Added lines #L404 - L406 were not covered by tests
end
end

logger:err("fatal: user cancelled multiple git remotes, please select an index")
return nil

Check warning on line 411 in lua/gitlinker/git.lua

View check run for this annotation

Codecov / codecov/patch

lua/gitlinker/git.lua#L410-L411

Added lines #L410 - L411 were not covered by tests
end

--- @param cwd string?
--- @return string?
local function get_branch_remote(cwd)
local logger = logging.get("gitlinker")
-- origin/upstream
local remotes = _get_remote(cwd)
logger:debug(string.format("git remotes:%s", vim.inspect(remotes)))
if not remotes then
return nil
end
Expand All @@ -372,6 +426,10 @@
return remotes[1]
end

if #remotes > 1 then
return _select_remotes(remotes, cwd)

Check warning on line 430 in lua/gitlinker/git.lua

View check run for this annotation

Codecov / codecov/patch

lua/gitlinker/git.lua#L429-L430

Added lines #L429 - L430 were not covered by tests
end

-- origin/linrongbin16/add-rule2
local upstream_branch = _get_rev_name("@{u}", cwd)
if not upstream_branch then
Expand Down
Loading