-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
pspnet.py
282 lines (247 loc) · 10.7 KB
/
pspnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# pylint: disable=unused-argument
"""Pyramid Scene Parsing Network"""
from mxnet.gluon import nn
from mxnet.context import cpu
from mxnet.gluon.nn import HybridBlock
from .segbase import SegBaseModel
from .fcn import _FCNHead
# pylint: disable-all
__all__ = ['PSPNet', 'get_psp', 'get_psp_resnet101_coco', 'get_psp_resnet101_voc',
'get_psp_resnet50_ade', 'get_psp_resnet101_ade', 'get_psp_resnet101_citys']
class PSPNet(SegBaseModel):
    r"""Pyramid Scene Parsing Network

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : string
        Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
        'resnet101' or 'resnet152').
    aux : bool, default True
        If True, an auxiliary FCN head is built on the ``c3`` backbone feature and its
        upsampled prediction is returned as a second output of ``hybrid_forward``.
    ctx : Context, default CPU
        Context on which the PSP head (and auxiliary head) parameters are initialized;
        also forwarded to :class:`SegBaseModel`.
    pretrained_base : bool, default True
        Forwarded to :class:`SegBaseModel` (presumably loads an ImageNet-pretrained
        backbone -- confirm in segbase).
    base_size : int, default 520
        Forwarded to :class:`SegBaseModel`.
    crop_size : int, default 480
        Forwarded to :class:`SegBaseModel`.
    norm_layer : object
        Normalization layer, passed through ``**kwargs`` (the PSP head defaults to
        :class:`mxnet.gluon.nn.BatchNorm`; pass a synchronized cross-GPU BatchNorm
        layer for multi-GPU training).

    Reference:

        Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia.
        "Pyramid scene parsing network." *CVPR*, 2017
    """
    def __init__(self, nclass, backbone='resnet50', aux=True, ctx=cpu(), pretrained_base=True,
                 base_size=520, crop_size=480, **kwargs):
        super(PSPNet, self).__init__(nclass, aux, backbone, ctx=ctx, base_size=base_size,
                                     crop_size=crop_size, pretrained_base=pretrained_base, **kwargs)
        with self.name_scope():
            # The pyramid branches are upsampled to a fixed feature-map size in
            # hybrid_forward. The //8 assumes the dilated backbone has output
            # stride 8 -- TODO confirm against SegBaseModel.base_forward.
            self.head = _PSPHead(nclass, feature_map_height=self._up_kwargs['height'] // 8,
                                 feature_map_width=self._up_kwargs['width'] // 8, **kwargs)
            self.head.initialize(ctx=ctx)
            # Newly added (non-backbone) layers train with a 10x learning-rate multiplier.
            self.head.collect_params().setattr('lr_mult', 10)
            if self.aux:
                self.auxlayer = _FCNHead(1024, nclass, **kwargs)
                self.auxlayer.initialize(ctx=ctx)
                self.auxlayer.collect_params().setattr('lr_mult', 10)

    def hybrid_forward(self, F, x):
        """Run the network and resize predictions to ``self._up_kwargs``.

        Returns
        -------
        tuple
            ``(pred,)``, or ``(pred, aux_pred)`` when ``self.aux`` is True.
        """
        c3, c4 = self.base_forward(x)
        outputs = [F.contrib.BilinearResize2D(self.head(c4), **self._up_kwargs)]
        if self.aux:
            auxout = self.auxlayer(c3)
            outputs.append(F.contrib.BilinearResize2D(auxout, **self._up_kwargs))
        return tuple(outputs)

    def demo(self, x):
        """Alias of :meth:`predict`."""
        return self.predict(x)

    def predict(self, x):
        """Imperatively predict on an input of arbitrary spatial size.

        The prediction is resized back to the spatial size of ``x``. The
        resize target is kept in a local dict so ``self._up_kwargs`` -- which
        ``hybrid_forward`` relies on -- is not overwritten by this call.

        Parameters
        ----------
        x : NDArray
            Input batch; ``x.shape[2:]`` is unpacked as ``(height, width)``, so
            a 4-D NCHW layout is assumed.

        Returns
        -------
        NDArray
            Main-head prediction resized to ``(height, width)`` of ``x``.
        """
        h, w = x.shape[2:]
        up_kwargs = dict(self._up_kwargs, height=h, width=w)
        _, c4 = self.base_forward(x)
        pred = self.head.demo(c4)
        import mxnet.ndarray as F
        return F.contrib.BilinearResize2D(pred, **up_kwargs)
def _PSP1x1Conv(in_channels, out_channels, norm_layer, norm_kwargs):
    """Build a 1x1 convolution -> normalization -> ReLU projection block.

    Parameters
    ----------
    in_channels : int
        Channels of the incoming feature map.
    out_channels : int
        Channels produced by the bias-free 1x1 convolution.
    norm_layer : callable
        Normalization layer class, instantiated with ``in_channels=out_channels``.
    norm_kwargs : dict or None
        Extra keyword arguments for ``norm_layer``; ``None`` means none.

    Returns
    -------
    nn.HybridSequential
        The three-layer projection block.
    """
    extra_norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
    block = nn.HybridSequential()
    with block.name_scope():
        # Layers are instantiated left-to-right, preserving creation order.
        block.add(
            nn.Conv2D(in_channels=in_channels, channels=out_channels,
                      kernel_size=1, use_bias=False),
            norm_layer(in_channels=out_channels, **extra_norm_kwargs),
            nn.Activation('relu'),
        )
    return block
class _PyramidPooling(HybridBlock):
    """Pyramid pooling module.

    The input is adaptively average-pooled to 1x1, 2x2, 3x3 and 6x6 bins.
    Each pooled map is projected to ``in_channels / 4`` channels by a
    1x1 conv-norm-ReLU block and bilinearly resized back. The four results
    are concatenated with the input along axis 1, giving ``2 * in_channels``
    channels.

    Parameters
    ----------
    in_channels : int
        Channels of the input feature map.
    height, width : int, default 60
        Upsampling target used by ``hybrid_forward``; must equal the input's
        spatial size for the concatenation to succeed.
    **kwargs
        Forwarded to ``_PSP1x1Conv`` (``norm_layer``, ``norm_kwargs``).
    """
    # Output bin sizes of the four pyramid levels, in the order conv1..conv4.
    _BIN_SIZES = (1, 2, 3, 6)

    def __init__(self, in_channels, height=60, width=60, **kwargs):
        super(_PyramidPooling, self).__init__()
        out_channels = int(in_channels / 4)
        self._up_kwargs = {'height': height, 'width': width}
        with self.name_scope():
            self.conv1 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
            self.conv2 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
            self.conv3 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
            self.conv4 = _PSP1x1Conv(in_channels, out_channels, **kwargs)

    def pool(self, F, x, size):
        """Adaptive average pooling of ``x`` to ``size`` x ``size`` bins."""
        return F.contrib.AdaptiveAvgPooling2D(x, output_size=size)

    def upsample(self, F, x, up_kwargs=None):
        """Bilinearly resize ``x``; defaults to the size fixed at construction."""
        if up_kwargs is None:
            up_kwargs = self._up_kwargs
        return F.contrib.BilinearResize2D(x, **up_kwargs)

    def _pyramid(self, F, x, up_kwargs):
        # Shared pool -> project -> upsample -> concat pipeline for both paths.
        convs = (self.conv1, self.conv2, self.conv3, self.conv4)
        feats = [self.upsample(F, conv(self.pool(F, x, size)), up_kwargs)
                 for conv, size in zip(convs, self._BIN_SIZES)]
        return F.concat(x, *feats, dim=1)

    def hybrid_forward(self, F, x):
        return self._pyramid(F, x, self._up_kwargs)

    def demo(self, x):
        """Imperative forward pass for an input of arbitrary spatial size.

        The upsampling target is taken from ``x.shape[2:4]`` and held in a local
        dict. The configured ``self._up_kwargs`` is deliberately left unchanged:
        overwriting it made later non-hybridized ``hybrid_forward`` calls
        upsample to a stale size and fail at the concat.
        """
        import mxnet.ndarray as F
        up_kwargs = dict(self._up_kwargs, height=x.shape[2], width=x.shape[3])
        return self._pyramid(F, x, up_kwargs)
class _PSPHead(HybridBlock):
    """PSP segmentation head: pyramid pooling followed by a conv classifier.

    Parameters
    ----------
    nclass : int
        Number of output classes (channels of the final 1x1 conv).
    norm_layer : callable, default nn.BatchNorm
        Normalization layer class used in the head and the pyramid branches.
    norm_kwargs : dict or None
        Extra keyword arguments for ``norm_layer``.
    feature_map_height, feature_map_width : int, default 60
        Spatial size the pyramid branches are upsampled to in ``hybrid_forward``.
    **kwargs
        Accepted and ignored.
    """
    def __init__(self, nclass, norm_layer=nn.BatchNorm, norm_kwargs=None,
                 feature_map_height=60, feature_map_width=60, **kwargs):
        super(_PSPHead, self).__init__()
        # NOTE(review): built outside name_scope(), as in the original layout;
        # kept that way so parameter names do not change.
        self.psp = _PyramidPooling(2048, height=feature_map_height, width=feature_map_width,
                                   norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        bn_kwargs = {} if norm_kwargs is None else norm_kwargs
        with self.name_scope():
            self.block = nn.HybridSequential(prefix='')
            # 4096 input channels = 2048 backbone channels + 4 pyramid levels x 512.
            self.block.add(
                nn.Conv2D(in_channels=4096, channels=512,
                          kernel_size=3, padding=1, use_bias=False),
                norm_layer(in_channels=512, **bn_kwargs),
                nn.Activation('relu'),
                nn.Dropout(0.1),
                nn.Conv2D(in_channels=512, channels=nclass, kernel_size=1),
            )

    def hybrid_forward(self, F, x):
        return self.block(self.psp(x))

    def demo(self, x):
        """Imperative pass that lets the pyramid adapt to ``x``'s spatial size."""
        return self.block(self.psp.demo(x))
def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False,
            root='~/.mxnet/models', ctx=cpu(0), pretrained_base=True, **kwargs):
    r"""Pyramid Scene Parsing Network

    Parameters
    ----------
    dataset : str, default pascal_voc
        Dataset key; the number of classes and class names are read from the
        package's ``data.datasets`` registry. Pretrained weights are mapped for
        'pascal_voc', 'pascal_aug', 'ade20k', 'coco' and 'citys'.
    backbone : str, default 'resnet50'
        Backbone passed to :class:`PSPNet`; also part of the pretrained-weight name.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    pretrained_base : bool or str, default True
        This will load pretrained backbone network, that was trained on ImageNet.
    **kwargs
        Forwarded to :class:`PSPNet`.

    Returns
    -------
    PSPNet
        The constructed model, with ``classes`` set from the dataset.

    Examples
    --------
    >>> model = get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False)
    >>> print(model)
    """
    # Short dataset names used in the pretrained-weight file name.
    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data import datasets
    # infer number of classes
    dataset_info = datasets[dataset]
    model = PSPNet(dataset_info.NUM_CLASS, backbone=backbone,
                   pretrained_base=pretrained_base, ctx=ctx, **kwargs)
    model.classes = dataset_info.CLASSES
    if pretrained:
        from .model_store import get_model_file
        model.load_parameters(get_model_file('psp_%s_%s' % (backbone, acronyms[dataset]),
                                             tag=pretrained, root=root), ctx=ctx)
    return model
def get_psp_resnet101_coco(**kwargs):
    r"""PSPNet with a ResNet-101 backbone for the COCO dataset.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        All keyword arguments are forwarded unchanged to ``get_psp``.

    Examples
    --------
    >>> model = get_psp_resnet101_coco(pretrained=True)
    >>> print(model)
    """
    return get_psp(dataset='coco', backbone='resnet101', **kwargs)
def get_psp_resnet101_voc(**kwargs):
    r"""PSPNet with a ResNet-101 backbone for Pascal VOC.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        All keyword arguments are forwarded unchanged to ``get_psp``.

    Examples
    --------
    >>> model = get_psp_resnet101_voc(pretrained=True)
    >>> print(model)
    """
    return get_psp(dataset='pascal_voc', backbone='resnet101', **kwargs)
def get_psp_resnet50_ade(**kwargs):
    r"""PSPNet with a ResNet-50 backbone for ADE20K.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        All keyword arguments are forwarded unchanged to ``get_psp``.

    Examples
    --------
    >>> model = get_psp_resnet50_ade(pretrained=True)
    >>> print(model)
    """
    return get_psp(dataset='ade20k', backbone='resnet50', **kwargs)
def get_psp_resnet101_ade(**kwargs):
    r"""PSPNet with a ResNet-101 backbone for ADE20K.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        All keyword arguments are forwarded unchanged to ``get_psp``.

    Examples
    --------
    >>> model = get_psp_resnet101_ade(pretrained=True)
    >>> print(model)
    """
    return get_psp(dataset='ade20k', backbone='resnet101', **kwargs)
def get_psp_resnet101_citys(**kwargs):
    r"""PSPNet with a ResNet-101 backbone for Cityscapes ('citys').

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        All keyword arguments are forwarded unchanged to ``get_psp``.

    Examples
    --------
    >>> model = get_psp_resnet101_citys(pretrained=True)
    >>> print(model)
    """
    return get_psp('citys', 'resnet101', **kwargs)