Assert that weights and gradWeights line up in getParameters
Adds a check that the parameters have the same offset as their gradients
after getParameters is called. If they do not line up, then methods such
as torch/optim will not work. This could happen if the sharing of
weights and gradWeights does not match, or if, due to a bug in the
implementation of getParameters, the storages of the weights and
gradWeights do not correspond closely.

Fix getParameters tests to always share gradWeights when sharing
weights.
colesbury committed Mar 4, 2016
1 parent 2adf0ef commit 70f10ac
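As an illustration of the failure mode described above (a hypothetical sketch, not part of this commit): if a clone shares 'weight' and 'bias' but not 'gradWeight' and 'gradBias', the flattened parameters end up with 110 elements while the flattened gradParameters have 220, and the new assertion fires instead of silently returning misaligned tensors.

-- Hypothetical sketch of the mismatch the new assertion is meant to catch.
require 'nn'

local n = nn.Sequential()
n:add(nn.Linear(10,10))
n:add(n.modules[1]:clone('weight','bias'))  -- gradWeight/gradBias not shared

-- Expected after this commit: the call errors out with
-- 'check that you are sharing parameters and gradParameters'.
local ok, err = pcall(function() return n:getParameters() end)
print(ok, err)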
Showing 2 changed files with 43 additions and 34 deletions.
9 changes: 8 additions & 1 deletion Module.lua
@@ -285,7 +285,14 @@ end
function Module:getParameters()
   -- get parameters
   local parameters,gradParameters = self:parameters()
-  return Module.flatten(parameters), Module.flatten(gradParameters)
+  local p, g = Module.flatten(parameters), Module.flatten(gradParameters)
+  assert(p:nElement() == g:nElement(),
+     'check that you are sharing parameters and gradParameters')
+  for i=1,#parameters do
+     assert(parameters[i]:storageOffset() == gradParameters[i]:storageOffset(),
+        'misaligned parameter at ' .. tostring(i))
+  end
+  return p, g
end

function Module:__call__(input, gradOutput)
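The per-tensor offset check above matters because optimizers treat the two flat tensors as parallel arrays: they update the flat parameter vector in place using the flat gradient vector at the same indices. A rough usage sketch of that assumption (assuming the torch/optim package is installed; not part of this diff):

require 'nn'
require 'optim'  -- assumed available (torch/optim)

-- optim.sgd updates p in place using g, index for index, which is only
-- safe when each parameters[i] sits at the same storage offset as its
-- gradParameters[i], the property asserted in getParameters above.
local model = nn.Linear(10, 2)
local criterion = nn.MSECriterion()
local p, g = model:getParameters()

local x, y = torch.randn(10), torch.randn(2)
local feval = function()
   g:zero()
   local out = model:forward(x)
   local loss = criterion:forward(out, y)
   model:backward(x, criterion:backward(out, y))
   return loss, g
end
optim.sgd(feval, p, {learningRate = 0.01})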
68 changes: 35 additions & 33 deletions test.lua
@@ -3736,7 +3736,7 @@ end
function nntest.Module_getParameters_5()
   local n = nn.Sequential()
   n:add( nn.Linear(10,10) )
-  n:add( n.modules[1]:clone('weight','bias') )
+  n:add( n.modules[1]:clone('weight','bias','gradWeight','gradBias') )
   local p = n:getParameters()

   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing')
@@ -3756,7 +3756,7 @@ end
function nntest.Module_getParameters_6()
   local n = nn.Sequential()
   n:add( nn.Linear(10,10) )
-  n:add( n.modules[1]:clone('weight','bias') )
+  n:add( n.modules[1]:clone('weight','bias','gradWeight','gradBias') )
   local _ = n:getParameters()

   n:add(nn.Linear(10,10))
@@ -3777,7 +3777,7 @@ end
function nntest.Module_getParameters_7()
   local n = nn.Sequential()
   n:add( nn.Linear(10,10) )
-  n:add( n.modules[1]:clone('weight','bias') )
+  n:add( n.modules[1]:clone('weight','bias','gradWeight','gradBias') )
   local _ = n:getParameters()

   n:add(nn.Linear(10,10))
@@ -3842,42 +3842,44 @@ function nntest.Module_getParameters_8()
end

function nntest.Module_getParameters_10()
   -- tensors are non-contiguous but compact; they can be gathered
   local L = nn.Linear(10,10)
   L.weight = torch.Tensor(10,10):t():fill(1)
   local tmp = torch.Tensor(10,10):fill(2)
   L.bias = tmp:select(1,2)
   local P = L:getParameters()
   mytester:asserteq(L.weight:mean(), 1)
   mytester:asserteq(L.bias:mean(), 2)
   mytester:asserteq(L.weight:storage(), L.bias:storage())
   mytester:asserteq(P:nElement(), 110)
   mytester:asserteq(P:storage():size(), 110)
   mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
end

function nntest.Module_getParameters_11()
   -- tensors are non-compact; they can't be gathered
   local L = nn.Linear(10,10)
   local tmp = torch.Tensor(10,10):fill(2)
   L.bias = tmp:select(2,2)
   local ok, err = pcall(L.getParameters, L)
   mytester:assert(not ok)
end

function nntest.Module_getParameters_12()
   -- tensors are expanded (i.e. have dimension 0)
   local L = nn.Linear(10,10)
   L.weight = torch.Tensor(10, 1):fill(1)
   torch.expand(L.weight, 10, 10)
+  L.gradWeight = torch.Tensor(10, 1):fill(1)
+  torch.expand(L.gradWeight, 10, 10)
   L.bias = torch.Tensor(10):fill(2)
   local P = L:getParameters()
   mytester:asserteq(L.weight:mean(), 1)
   mytester:asserteq(L.bias:mean(), 2)
   mytester:asserteq(L.weight:storage(), L.bias:storage())
   mytester:asserteq(P:nElement(), 20)
   mytester:asserteq(P:storage():size(), 20)
   mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
end

function nntest.Module_listModules()
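For reference, a small sketch mirroring the updated tests (hypothetical usage, not part of the diff): cloning the gradient buffers along with the parameters keeps the flattened tensors the same size and aligned, so getParameters succeeds.

require 'nn'

-- Sharing gradWeight/gradBias together with weight/bias keeps the
-- flattened parameter and gradient tensors aligned.
local n = nn.Sequential()
n:add(nn.Linear(10,10))
n:add(n.modules[1]:clone('weight','bias','gradWeight','gradBias'))

local p, g = n:getParameters()
assert(p:nElement() == g:nElement())  -- both 110: 10x10 weight + 10 bias, shared once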
