Skip to content

Commit

Permalink
add a field to support the evaluation interval (#849)
Browse files Browse the repository at this point in the history
  • Loading branch information
hellock authored Jun 22, 2019
1 parent e491713 commit d95727b
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
1 change: 1 addition & 0 deletions configs/mask_rcnn_r50_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
evaluation = dict(interval=1)
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
Expand Down
14 changes: 9 additions & 5 deletions mmdet/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def build_optimizer(model, optimizer_cfg):
paramwise_options = optimizer_cfg.pop('paramwise_options', None)
# if no paramwise option is specified, just use the global setting
if paramwise_options is None:
return obj_from_dict(
optimizer_cfg, torch.optim, dict(params=model.parameters()))
return obj_from_dict(optimizer_cfg, torch.optim,
dict(params=model.parameters()))
else:
assert isinstance(paramwise_options, dict)
# get base lr and weight decay
Expand Down Expand Up @@ -154,15 +154,19 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks
if validate:
val_dataset_cfg = cfg.data.val
eval_cfg = cfg.get('evaluation', {})
if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets
runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg))
runner.register_hook(
CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
else:
dataset_type = getattr(datasets, val_dataset_cfg.type)
if issubclass(dataset_type, datasets.CocoDataset):
runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg))
runner.register_hook(
CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
else:
runner.register_hook(DistEvalmAPHook(val_dataset_cfg))
runner.register_hook(
DistEvalmAPHook(val_dataset_cfg, **eval_cfg))

if cfg.resume_from:
runner.resume(cfg.resume_from)
Expand Down
4 changes: 3 additions & 1 deletion mmdet/core/evaluation/eval_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@ class CocoDistEvalRecallHook(DistEvalHook):

def __init__(self,
dataset,
interval=1,
proposal_nums=(100, 300, 1000),
iou_thrs=np.arange(0.5, 0.96, 0.05)):
super(CocoDistEvalRecallHook, self).__init__(dataset)
super(CocoDistEvalRecallHook, self).__init__(
dataset, interval=interval)
self.proposal_nums = np.array(proposal_nums, dtype=np.int32)
self.iou_thrs = np.array(iou_thrs, dtype=np.float32)

Expand Down

0 comments on commit d95727b

Please sign in to comment.