Skip to content

Commit

Permalink
SelectTable fix (torch#761)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard authored and soumith committed Apr 11, 2016
1 parent ee1646f commit a195ac2
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
16 changes: 6 additions & 10 deletions SelectTable.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ local function zeroTableCopy(t1, t2)
if not t1[k] then
t1[k] = v:clone():zero()
else
local tensor = t1[k]
t1[k]:resizeAs(v)
t1[k]:zero()
end
Expand All @@ -40,16 +39,13 @@ local function zeroTableCopy(t1, t2)
end

function SelectTable:updateGradInput(input, gradOutput)
if self.index < 0 then
self.gradInput[#input + self.index + 1] = gradOutput
else
self.gradInput[self.index] = gradOutput
end
-- make gradInput a zeroed copy of input
zeroTableCopy(self.gradInput, input)

for i=#input+1, #self.gradInput do
self.gradInput[i] = nil
end
-- handle negative indices
local index = self.index < 0 and #input + self.index + 1 or self.index
-- copy into gradInput[index] (necessary for variable sized inputs)
assert(self.gradInput[index])
nn.utils.recursiveCopy(self.gradInput[index], gradOutput)

return self.gradInput
end
Expand Down
5 changes: 3 additions & 2 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4514,8 +4514,9 @@ function nntest.SelectTable()
local module1 = nn.SelectTable(-1)
local output1 = module1:forward(input1):clone()
local output2 = module1:forward(input2)
local gradInput1 = module1:backward(input1, gradOutput1)
for k,v in ipairs(gradInput1) do gradInput1[k] = v:clone() end
local gradInput_ = module1:backward(input1, gradOutput1)
local gradInput1 = {}
for k,v in ipairs(gradInput_) do gradInput1[k] = v:clone() end
local gradInput2 = module1:backward(input2, gradOutput2)

local module3 = nn.SelectTable(-1)
Expand Down
16 changes: 16 additions & 0 deletions utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,22 @@ function nn.utils.recursiveAdd(t1, val, t2)
return t1, t2
end

function nn.utils.recursiveCopy(t1,t2)
if torch.type(t2) == 'table' then
t1 = (torch.type(t1) == 'table') and t1 or {t1}
for key,_ in pairs(t2) do
t1[key], t2[key] = nn.utils.recursiveCopy(t1[key], t2[key])
end
elseif torch.isTensor(t2) then
t1 = torch.isTensor(t1) and t1 or t2.new()
t1:resizeAs(t2):copy(t2)
else
error("expecting nested tensors or tables. Got "..
torch.type(t1).." and "..torch.type(t2).." instead")
end
return t1, t2
end

function nn.utils.addSingletonDimension(...)
local view, t, dim
if select('#',...) < 3 then
Expand Down

0 comments on commit a195ac2

Please sign in to comment.