Skip to content

Commit

Permalink
revise fast test & fix aug test bug
Browse files Browse the repository at this point in the history
  • Loading branch information
OceanPang committed Oct 10, 2018
1 parent 35cec76 commit d774325
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 12 deletions.
39 changes: 30 additions & 9 deletions mmdet/models/detectors/fast_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,33 @@ def __init__(self,
mask_head=None,
pretrained=None):
super(FastRCNN, self).__init__(
backbone=backbone,
neck=neck,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
mask_roi_extractor=mask_roi_extractor,
mask_head=mask_head,
pretrained=pretrained)
backbone=backbone,
neck=neck,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
mask_roi_extractor=mask_roi_extractor,
mask_head=mask_head,
pretrained=pretrained)

def forward_test(self, imgs, img_metas, proposals, **kwargs):
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))

num_augs = len(imgs)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(imgs), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared
imgs_per_gpu = imgs[0].size(0)
assert imgs_per_gpu == 1

if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], proposals[0],
**kwargs)
else:
return self.aug_test(imgs, img_metas, proposals, **kwargs)
9 changes: 7 additions & 2 deletions mmdet/models/detectors/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):

ori_shape = img_metas[0][0]['ori_shape']
segm_result = self.mask_head.get_seg_masks(
merged_masks, det_bboxes, det_labels, self.test_cfg.rcnn,
ori_shape)
merged_masks,
det_bboxes,
det_labels,
self.test_cfg.rcnn,
ori_shape,
scale_factor=1.0,
rescale=False)
return segm_result
2 changes: 1 addition & 1 deletion mmdet/models/detectors/two_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def simple_test(self, img, img_meta, proposals=None, rescale=False):

proposal_list = self.simple_test_rpn(
x, img_meta,
self.test_cfg.rpn) if proposals is None else proposals[0]
self.test_cfg.rpn) if proposals is None else proposals

det_bboxes, det_labels = self.simple_test_bboxes(
x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
Expand Down

0 comments on commit d774325

Please sign in to comment.