-- BilinearSamplerPerspective.lua
-- code adapted from github
-- implemented by Yijie Guo ([email protected]) and Xinchen Yan ([email protected])
local BilinearSamplerPerspective, parent = torch.class('nn.BilinearSamplerPerspective', 'nn.Module')
--[[
BilinearSamplerPerspective(focal_length) :
BilinearSamplerPerspective:updateOutput({inputImages, grids})
BilinearSamplerPerspective:updateGradInput({inputImages, grids}, gradOutput)

BilinearSamplerPerspective performs bilinear sampling of the input images according to the
normalized coordinates provided in the grids. The output has the same spatial size as the
grids, with as many feature channels as the input images.
- inputImages has to be in B x D x H x W x C layout (batch, depth, height, width, channels)
- grids have to be in B x D x H x W x 4 layout, where the last dimension holds the
  coordinates for each output location
- grids contain, for each sample (first dim), the normalized coordinates of the output
  with respect to the input sample (Z, Y, X coordinates)
- normalized coordinates: (-1, -1, -1) points to the front top left, (-1, -1, 1) points to
  the front top right
- if the normalized coordinates fall outside of the input volume, the output is filled
  with zeros
]]
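
--[[
Minimal usage sketch (illustrative only, not part of the module). The tensor sizes and the
focal_length value below are assumptions chosen to match the layouts described above, and it
is assumed that this package (including the C/CUDA backend providing
BilinearSamplerPerspective_updateOutput) has already been loaded, e.g. via require.

  local sampler = nn.BilinearSamplerPerspective(1.0)        -- hypothetical focal length
  local inputImages = torch.rand(2, 8, 16, 16, 3)           -- B x D x H x W x C
  local grids = torch.rand(2, 8, 16, 16, 4):mul(2):add(-1)  -- B x D x H x W x 4, coords in [-1, 1]
  local output = sampler:forward({inputImages, grids})
  -- output: 2 x 8 x 16 x 16 x 3 (same spatial size as grids, same channels as inputImages)
]]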
function BilinearSamplerPerspective:__init(focal_length)
   parent.__init(self)
   self.gradInput = {}
   self.focal_length = focal_length
end
function BilinearSamplerPerspective:check(input, gradOutput)
   local inputImages = input[1]
   local grids = input[2]
   assert(inputImages:isContiguous(), 'Input images have to be contiguous')
   assert(inputImages:nDimension() == 5)
   assert(grids:nDimension() == 5)
   assert(inputImages:size(1) == grids:size(1)) -- batch
   assert(grids:size(5) == 4) -- coordinates
   if gradOutput then
      assert(grids:size(1) == gradOutput:size(1)) -- batch size
      assert(grids:size(2) == gradOutput:size(2)) -- depth == dist
      assert(grids:size(3) == gradOutput:size(3)) -- height
      assert(grids:size(4) == gradOutput:size(4)) -- width
   end
end
-- Prepend a singleton batch dimension so that 4D (unbatched) inputs can be
-- processed by the batched 5D code path.
local function addOuterDim(t)
   local sizes = t:size()
   local newsizes = torch.LongStorage(sizes:size() + 1)
   newsizes[1] = 1
   for i = 1, sizes:size() do
      newsizes[i + 1] = sizes[i]
   end
   return t:view(newsizes)
end
function BilinearSamplerPerspective:updateOutput(input)
   local _inputImages = input[1]
   local _grids = input[2]

   local inputImages, grids
   if _inputImages:nDimension() == 4 then
      -- unbatched input: size(1) = depth, size(2) = height, size(3) = width, size(4) = channels
      inputImages = addOuterDim(_inputImages)
      grids = addOuterDim(_grids)
   else
      inputImages = _inputImages
      grids = _grids
   end

   local input = {inputImages, grids}
   self:check(input)

   self.output:resize(inputImages:size(1), grids:size(2), grids:size(3), grids:size(4), inputImages:size(5))
   inputImages.nn.BilinearSamplerPerspective_updateOutput(self, inputImages, grids, self.output, self.focal_length)

   if _inputImages:nDimension() == 4 then
      self.output = self.output:select(1, 1)
   end
   return self.output
end
function BilinearSamplerPerspective:updateGradInput(_input, _gradOutput)
   local _inputImages = _input[1]
   local _grids = _input[2]

   local inputImages, grids, gradOutput
   if _inputImages:nDimension() == 4 then
      inputImages = addOuterDim(_inputImages)
      grids = addOuterDim(_grids)
      gradOutput = addOuterDim(_gradOutput)
   else
      inputImages = _inputImages
      grids = _grids
      gradOutput = _gradOutput
   end

   local input = {inputImages, grids}
   self:check(input, gradOutput)

   for i = 1, #input do
      -- allocate each gradient buffer with the same tensor type as its input
      self.gradInput[i] = self.gradInput[i] or input[i].new()
      self.gradInput[i]:resizeAs(input[i]):zero()
   end

   local gradInputImages = self.gradInput[1]
   local gradGrids = self.gradInput[2]

   inputImages.nn.BilinearSamplerPerspective_updateGradInput(self, inputImages, grids, gradInputImages, gradGrids, gradOutput, self.focal_length)

   if _gradOutput:nDimension() == 4 then
      self.gradInput[1] = self.gradInput[1]:select(1, 1)
      self.gradInput[2] = self.gradInput[2]:select(1, 1)
   end
   return self.gradInput
end
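
--[[
Minimal backward-pass sketch (illustrative only; assumes the variables from the forward
example above, and that gradOutput has the same size as the module's output):

  local gradOutput = output:clone():uniform()
  local gradInput = sampler:backward({inputImages, grids}, gradOutput)
  -- gradInput[1]: gradient w.r.t. inputImages, gradInput[2]: gradient w.r.t. grids
]]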