Skip to content

Commit

Permalink
Code for CVPR 2019 Paper "Libra R-CNN: Towards Balanced Learning for …
Browse files Browse the repository at this point in the history
…Object Detection" (#687)

* Code for components of Libra R-CNN

* Configs and README for Libra R-CNN

* update bfp

* Update Model ZOO

* add comments in non-local

* fix shape

* update bfp

* update according to ck's comments

* update des

* update des

* fix loss

* fix according to ck's comments

* fix activation in non-local

* fix conv_mask in non-local

* fix conv_mask in non-local

* Remove outdated model urls

* refactoring for bfp

* change in_channels from list[int] to int

* refactoring for nonlocal

* udpate weight init of nonlocal

* minor fix

* update new model urls
  • Loading branch information
OceanPang authored and hellock committed May 25, 2019
1 parent 4c5cfae commit b581e19
Show file tree
Hide file tree
Showing 17 changed files with 1,281 additions and 47 deletions.
8 changes: 4 additions & 4 deletions MODEL_ZOO.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,14 @@ Please refer to [Weight Standardization](configs/gn+ws/README.md) for details.

Please refer to [Deformable Convolutional Networks](configs/dcn/README.md) for details.

### Libra R-CNN

Please refer to [Libra R-CNN](configs/libra_rcnn/README.md) for details.

### Guided Anchoring

Please refer to [Guided Anchoring](configs/guided_anchoring/README.md) for details.


## Comparison with Detectron and maskrcnn-benchmark

We compare mmdetection with [Detectron](https://github.com/facebookresearch/Detectron)
Expand Down Expand Up @@ -454,6 +457,3 @@ and the main advantage is PyTorch itself. We also perform some memory optimizati

Note that Caffe2 and PyTorch have different apis to obtain memory usage with different implementations.
For all codebases, `nvidia-smi` shows a larger memory usage than the reported number in the above table.



1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ Results and models are available in the [Model zoo](MODEL_ZOO.md).
| RetinaNet |||||
| Hybrid Task Cascade|||||
| FCOS |||||
| Libra R-CNN |||||

Other features
- [x] DCNv2
Expand Down
26 changes: 26 additions & 0 deletions configs/libra_rcnn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Libra R-CNN: Towards Balanced Learning for Object Detection

## Introduction

We provide config files to reproduce the results in the CVPR 2019 paper [Libra R-CNN](https://arxiv.org/pdf/1904.02701.pdf).

```
@inproceedings{pang2019libra,
title={Libra R-CNN: Towards Balanced Learning for Object Detection},
author={Pang, Jiangmiao and Chen, Kai and Shi, Jianping and Feng, Huajun and Ouyang, Wanli and Dahua Lin},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2019}
}
```

## Results and models

The results on COCO 2017val are shown in the below table. (results on test-dev are usually slightly higher than val)

| Architecture | Backbone | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
|:---------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:-------:|:--------:|
| Faster R-CNN | R-50-FPN | pytorch | 1x | 4.2 | 0.375 | 12.0 | 38.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_faster_rcnn_r50_fpn_1x_20190525-c8c06833.pth) |
| Fast R-CNN | R-50-FPN | pytorch | 1x | 3.7 | 0.272 | 16.3 | 38.5 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_fast_rcnn_r50_fpn_1x_20190525-a43f88b5.pth) |
| Faster R-CNN | R-101-FPN | pytorch | 1x | 6.0 | 0.495 | 10.4 | 40.3 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_faster_rcnn_r101_fpn_1x_20190525-94e94051.pth) |
| Faster R-CNN | X-101-64x4d-FPN | pytorch | 1x | 10.1 | 1.050 | 6.8 | 42.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x_20190525-359c134a.pth) |
| RetinaNet | R-50-FPN | pytorch | 1x | 3.7 | 0.328 | 11.8 | 37.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_retinanet_r50_fpn_1x_20190525-ead2a6bb.pth) |
144 changes: 144 additions & 0 deletions configs/libra_rcnn/libra_fast_rcnn_r50_fpn_1x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# model settings
model = dict(
type='FastRCNN',
pretrained='modelzoo://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=[
dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
dict(
type='BFP',
in_channels=256,
num_levels=5,
refine_level=2,
refine_type='non_local')
],
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(
type='BalancedL1Loss',
alpha=0.5,
gamma=1.5,
beta=1.0,
loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='CombinedSampler',
num=512,
pos_fraction=0.25,
add_gt_as_proposals=True,
pos_sampler=dict(type='InstanceBalancedPosSampler'),
neg_sampler=dict(
type='IoUBalancedNegSampler',
floor=-1,
floor_thr=0,
num_bins=3)),
pos_weight=-1,
debug=False))
test_cfg = dict(
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
imgs_per_gpu=2,
workers_per_gpu=0,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
proposal_file=data_root +
'libra_proposals/rpn_r50_fpn_1x_train2017.pkl',
flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
proposal_file=data_root + 'libra_proposals/rpn_r50_fpn_1x_val2017.pkl',
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
proposal_file=data_root + 'libra_proposals/rpn_r50_fpn_1x_val2017.pkl',
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/libra_fast_rcnn_r50_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
185 changes: 185 additions & 0 deletions configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# model settings
model = dict(
type='FasterRCNN',
pretrained='modelzoo://resnet101',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=[
dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
dict(
type='BFP',
in_channels=256,
num_levels=5,
refine_level=2,
refine_type='non_local')
],
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(
type='BalancedL1Loss',
alpha=0.5,
gamma=1.5,
beta=1.0,
loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=5,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='CombinedSampler',
num=512,
pos_fraction=0.25,
add_gt_as_proposals=True,
pos_sampler=dict(type='InstanceBalancedPosSampler'),
neg_sampler=dict(
type='IoUBalancedNegSampler',
floor=-1,
floor_thr=0,
num_bins=3)),
pos_weight=-1,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/libra_faster_rcnn_r101_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
Loading

0 comments on commit b581e19

Please sign in to comment.