Skip to content

Commit

Permalink
revise norm config
Browse files Browse the repository at this point in the history
  • Loading branch information
thangvubk committed Jan 10, 2019
1 parent 55a4feb commit 82e7545
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 33 deletions.
24 changes: 9 additions & 15 deletions configs/mask_rcnn_r50_fpn_gn_2x.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# model settings
normalize = dict(
type='GN',
num_groups=32,
frozen=False)

model = dict(
type='MaskRCNN',
pretrained='tools/resnet50-GN.path',
Expand All @@ -9,20 +14,13 @@
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch',
# Note: eval_mode and frozen are required args for backbone
normalize=dict(
type='GN',
num_groups=32,
eval_mode=False,
frozen=False)),
normalize=normalize),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5,
normalize=dict(
type='GN',
num_groups=32)),
normalize=normalize),
rpn_head=dict(
type='RPNHead',
in_channels=256,
Expand Down Expand Up @@ -50,9 +48,7 @@
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False,
normalize=dict(
type='GN',
num_groups=32)),
normalize=normalize),
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
Expand All @@ -64,9 +60,7 @@
in_channels=256,
conv_out_channels=256,
num_classes=81,
normalize=dict(
type='GN',
num_groups=32)))
normalize=normalize))

# model training and testing settings
train_cfg = dict(
Expand Down
16 changes: 8 additions & 8 deletions mmdet/models/backbones/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,14 @@ class ResNet(nn.Module):
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
normalize (dict): dictionary to construct norm layer. Additionally,
eval mode and gradent freezing are controlled by
eval (bool) and frozen (bool) respectively.
normalize (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
"""

arch_settings = {
Expand All @@ -261,8 +264,8 @@ def __init__(self,
frozen_stages=-1,
normalize=dict(
type='BN',
eval_mode=True,
frozen=False),
norm_eval=True,
with_cp=False,
zero_init_residual=True):
super(ResNet, self).__init__()
Expand All @@ -278,11 +281,9 @@ def __init__(self,
assert max(out_indices) < num_stages
self.style = style
self.frozen_stages = frozen_stages
assert (isinstance(normalize, dict) and 'eval_mode' in normalize
and 'frozen' in normalize)
self.norm_eval = normalize.pop('eval_mode')
self.normalize = normalize
self.with_cp = with_cp
self.norm_eval = norm_eval
self.zero_init_residual = zero_init_residual
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
Expand Down Expand Up @@ -350,7 +351,6 @@ def init_weights(self, pretrained=None):
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
constant_init(m, 1)

# zero init for last norm layer https://arxiv.org/abs/1706.02677
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
Expand Down
9 changes: 6 additions & 3 deletions mmdet/models/backbones/resnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,14 @@ class ResNeXt(ResNet):
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
normalize (dict): dictionary to construct norm layer. Additionally,
eval mode and gradent freezing are controlled by
eval (bool) and frozen (bool) respectively.
normalize (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
"""

arch_settings = {
Expand Down
7 changes: 0 additions & 7 deletions mmdet/models/utils/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@ def build_norm_layer(cfg, num_features, postfix=''):
assert isinstance(cfg, dict) and 'type' in cfg
cfg_ = cfg.copy()

# eval_mode is supported and popped out for processing in module
# having pretrained weight only (e.g. backbone)
# raise an exception if eval_mode is in here
if 'eval_mode' in cfg:
raise Exception('eval_mode for modules without pretrained weights '
'is not supported')

layer_type = cfg_.pop('type')
if layer_type not in norm_cfg:
raise KeyError('Unrecognized norm type {}'.format(layer_type))
Expand Down

0 comments on commit 82e7545

Please sign in to comment.