Skip to content

Commit

Permalink
fix: couple bugs with multiple results and varargs (...) (nvim-lua#515)
Browse files Browse the repository at this point in the history
* fix(vararg.rotate): edge cases and nil arguments

- zero argument rotation returned one value (a global by the name A0)
- the generic fallback dropped the first argument and trailing nils

* fix: functional.partial & fun.bind

These only worked for binding exactly one parameter.

* Include generated rotate.lua file in linting
  • Loading branch information
juntuu authored Sep 10, 2023
1 parent e739a2e commit 23deb47
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 30 deletions.
1 change: 0 additions & 1 deletion .luacheckrc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ read_globals = {
exclude_files = {
"lua/plenary/profile/lua_profiler.lua",
"lua/plenary/profile/memory_profiler.lua",
"lua/plenary/vararg/rotate.lua",
"lua/plenary/async_lib/*.lua",
}

Expand Down
16 changes: 1 addition & 15 deletions lua/plenary/fun.lua
Original file line number Diff line number Diff line change
@@ -1,20 +1,6 @@
local tbl = require "plenary.tbl"

local M = {}

function M.bind(fn, ...)
if select("#", ...) == 1 then
local arg = ...
return function(...)
fn(arg, ...)
end
end

local args = tbl.pack(...)
return function(...)
fn(tbl.unpack(args), ...)
end
end
M.bind = require("plenary.functional").partial

function M.arify(fn, argc)
return function(...)
Expand Down
14 changes: 10 additions & 4 deletions lua/plenary/functional.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@ function f.join(array, sep)
return table.concat(vim.tbl_map(tostring, array), sep)
end

function f.partial(fun, ...)
local args = { ... }
return function(...)
return fun(unpack(args), ...)
local function bind_n(fn, n, a, ...)
if n == 0 then
return fn
end
return bind_n(function(...)
return fn(a, ...)
end, n - 1, ...)
end

function f.partial(fun, ...)
return bind_n(fun, select("#", ...), ...)
end

function f.any(fun, iterable)
Expand Down
12 changes: 6 additions & 6 deletions lua/plenary/vararg/rotate.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ local tbl = require "plenary.tbl"

local rotate_lookup = {}

rotate_lookup[0] = function()
return A0
end

rotate_lookup[1] = function(A0)
return A0
end
Expand Down Expand Up @@ -71,12 +67,16 @@ rotate_lookup[15] = function(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A
end

local function rotate_n(first, ...)
local n = select("#", ...) + 1
local args = tbl.pack(...)
args[#args + 1] = first
return tbl.unpack(args)
args[n] = first
return tbl.unpack(args, 1, n)
end

local function rotate(nargs, ...)
if nargs == nil or nargs < 1 then
return
end
return (rotate_lookup[nargs] or rotate_n)(...)
end

Expand Down
10 changes: 7 additions & 3 deletions scripts/vararg/rotate.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,23 @@ local tbl = require('plenary.tbl')

local rotate_lookup = {}

{% for n in range(0, amount) %}
{% for n in range(1, amount) %}
rotate_lookup[{{n}}] = function ({% for n in range(n) %} A{{n}} {{ ", " if not loop.last else "" }} {% endfor %})
return {% for n in range(1, n) %} A{{n}}, {% endfor %} A0
end
{% endfor %}

local function rotate_n(first, ...)
local n = select("#", ...) + 1
local args = tbl.pack(...)
args[#args+1] = first
return tbl.unpack(args)
args[n] = first
return tbl.unpack(args, 1, n)
end

local function rotate(nargs, ...)
if nargs == nil or nargs < 1 then
return
end
return (rotate_lookup[nargs] or rotate_n)(...)
end

Expand Down
18 changes: 18 additions & 0 deletions tests/plenary/functional_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
local f = require "plenary.functional"

describe("functional", function()
describe("partial", function()
local function args(...)
assert.is.equal(4, select("#", ...))
return table.concat({ ... }, ",")
end
it("should bind correct parameters", function()
local expected = args(1, 2, 3, 4)
assert.is.equal(expected, f.partial(args)(1, 2, 3, 4))
assert.is.equal(expected, f.partial(args, 1)(2, 3, 4))
assert.is.equal(expected, f.partial(args, 1, 2)(3, 4))
assert.is.equal(expected, f.partial(args, 1, 2, 3)(4))
assert.is.equal(expected, f.partial(args, 1, 2, 3, 4)())
end)
end)
end)
15 changes: 14 additions & 1 deletion tests/plenary/rotate_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,26 @@ local eq = function(a, b)
end

describe("rotate", function()
it("should return as many values, as the first argument", function()
local args = {}
for _ = 0, 20 do
local n = select("#", unpack(args))
assert.is.equal(n, select("#", rotate(n, unpack(args))))
args[#args + 1] = n
end
end)

it("should rotate varargs", function()
eq({ rotate(3, 1, 2, 3) }, { 2, 3, 1 })
eq({ rotate(9, 1, 2, 3, 4, 5, 6, 7, 8, 9) }, { 2, 3, 4, 5, 6, 7, 8, 9, 1 })
end)

it("should rotate zero", function()
assert.is.equal(0, select("#", rotate(0)))
end)

it("should rotate none", function()
eq({ rotate() }, {})
assert.is.equal(0, select("#", rotate()))
end)

it("should rotate one", function()
Expand Down

0 comments on commit 23deb47

Please sign in to comment.