-
Notifications
You must be signed in to change notification settings - Fork 40
/
layers.py
85 lines (70 loc) · 3.01 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import absolute_import, division
import torch
import torch.nn as nn
import numpy as np
from torch_deform_conv.deform_conv import th_batch_map_offsets, th_generate_grid
class ConvOffset2D(nn.Conv2d):
    """ConvOffset2D

    Convolutional layer responsible for learning the 2D offsets and outputting
    the deformed feature map using bilinear interpolation.

    The layer itself is a 3x3 convolution (padding 1, no bias) that maps the
    ``filters`` input channels to ``2 * filters`` offset channels; the input
    is then resampled with those offsets by ``th_batch_map_offsets``.

    Note that this layer does not perform convolution on the deformed feature
    map. See get_deform_cnn in cnn.py for usage.
    """

    def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
        """Init

        Parameters
        ----------
        filters : int
            Number of channels of the input feature map
        init_normal_stddev : float
            Standard deviation of the zero-mean normal distribution used to
            initialise the offset kernel
        **kwargs:
            Passed to the superclass. See the Conv2d layer in pytorch.
            ``padding`` and ``bias`` are already fixed by this layer, so
            supplying them here raises a ``TypeError``.
        """
        super(ConvOffset2D, self).__init__(
            filters, filters * 2, 3, padding=1, bias=False, **kwargs
        )
        # Plain attributes are assigned after nn.Module has been initialised.
        self.filters = filters
        # Cache for the sampling grid: ``_grid_param`` is the key describing
        # the input the cached ``_grid`` was generated for.
        self._grid_param = None
        self._grid = None
        # Overwrite the default Conv2d initialisation with small normal noise.
        self.weight.data.copy_(
            self._init_weights(self.weight, init_normal_stddev)
        )

    def forward(self, x):
        """Return the deformed feature map.

        Parameters
        ----------
        x : tensor of shape (b, c, h, w)

        Returns
        -------
        Tensor viewed as (-1, c, h, w) using the channel/height/width of ``x``.
        """
        x_shape = x.size()
        # offsets: (b, 2c, h, w) as produced by the 3x3 convolution
        offsets = super(ConvOffset2D, self).forward(x)

        # offsets: (b*c, h, w, 2)
        offsets = self._to_bc_h_w_2(offsets, x_shape)

        # x: (b*c, h, w)
        x = self._to_bc_h_w(x, x_shape)

        # x_offset: (b*c, h, w)
        x_offset = th_batch_map_offsets(x, offsets, grid=self._get_grid(x))

        # x_offset: (b, c, h, w)
        x_offset = self._to_b_c_h_w(x_offset, x_shape)

        return x_offset

    def _get_grid(self, x):
        """Return the sampling grid for ``x``, regenerating it only when needed.

        The grid is cached on the instance and keyed by the first three
        dimensions of ``x`` together with its tensor type string and CUDA
        flag; a call with the same key returns the cached object.
        """
        grid_key = (
            x.size(0), x.size(1), x.size(2), x.data.type(), x.data.is_cuda
        )
        if self._grid_param != grid_key:
            # Build the grid first and record the key afterwards, so a failure
            # in th_generate_grid cannot leave a key pointing at a stale grid.
            self._grid = th_generate_grid(*grid_key)
            self._grid_param = grid_key
        return self._grid

    @staticmethod
    def _init_weights(weights, std):
        """Return a float64 tensor shaped like ``weights`` drawn from N(0, std)."""
        shape = tuple(weights.size())
        return torch.from_numpy(np.random.normal(0.0, std, shape))

    @staticmethod
    def _to_bc_h_w_2(x, x_shape):
        """(b, 2c, h, w) -> (b*c, h, w, 2)"""
        # NOTE(review): this is a plain view (no permute), so the trailing
        # axis of size 2 is taken from adjacent elements in memory rather
        # than from a pair of offset channels - confirm this is the intended
        # layout before changing it, since trained weights depend on it.
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]), 2)
        return x

    @staticmethod
    def _to_bc_h_w(x, x_shape):
        """(b, c, h, w) -> (b*c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
        return x

    @staticmethod
    def _to_b_c_h_w(x, x_shape):
        """(b*c, h, w) -> (b, c, h, w)"""
        x = x.contiguous().view(
            -1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3])
        )
        return x