-
Notifications
You must be signed in to change notification settings - Fork 165
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
50 changed files
with
1,911 additions
and
311 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
### MMCV | ||
- [ ] Implement the attr 'get' of 'Config' | ||
- [ ] Config bugs: None type to '{}' with addict | ||
- [ ] Default logger should only log on gpu0 | ||
- [ ] Unit Test: mmcv and mmcv.torchpack | ||
|
||
|
||
### MMDetection | ||
|
||
#### Basic | ||
- [ ] Implement training function without distributed | ||
- [ ] Verify nccl/nccl2/gloo | ||
- [ ] Replace UGLY code: params plug in 'args' to reach a global flow | ||
- [ ] Replace 'print' by 'logger' | ||
|
||
|
||
#### Testing | ||
- [ ] Implement distributed testing | ||
- [ ] Implement single gpu testing | ||
|
||
|
||
#### Refactor | ||
- [ ] Re-consider params names | ||
- [ ] Refactor functions in 'core' | ||
- [ ] Merge single test & aug test into one function, and remove other similar redundancy | ||
|
||
#### New features | ||
- [ ] Plug loss params into Config | ||
- [ ] Multi-head communication |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
from .anchor_generator import * | ||
from .train_engine import * | ||
from .test_engine import * | ||
from .rpn_ops import * | ||
from .bbox_ops import * | ||
from .mask_ops import * | ||
from .losses import * | ||
from .eval import * | ||
from .nn import * | ||
from .targets import * | ||
from .post_processing import * | ||
from .utils import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,16 @@ | ||
from .geometry import bbox_overlaps | ||
from .sampling import (random_choice, bbox_assign, bbox_assign_via_overlaps, | ||
bbox_sampling, sample_positives, sample_negatives) | ||
bbox_sampling, sample_positives, sample_negatives, | ||
sample_proposals) | ||
from .transforms import (bbox_transform, bbox_transform_inv, bbox_flip, | ||
bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox) | ||
bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox, | ||
bbox2result) | ||
from .bbox_target import bbox_target | ||
|
||
__all__ = [ | ||
'bbox_overlaps', 'random_choice', 'bbox_assign', | ||
'bbox_assign_via_overlaps', 'bbox_sampling', 'sample_positives', | ||
'sample_negatives', 'bbox_transform', 'bbox_transform_inv', 'bbox_flip', | ||
'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox' | ||
'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result', | ||
'bbox_target', 'sample_proposals' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import mmcv | ||
import torch | ||
|
||
from .geometry import bbox_overlaps | ||
from .transforms import bbox_transform, bbox_transform_inv | ||
|
||
|
||
def bbox_target(pos_proposals_list,
                neg_proposals_list,
                pos_gt_bboxes_list,
                pos_gt_labels_list,
                cfg,
                reg_num_classes=1,
                target_means=[.0, .0, .0, .0],
                target_stds=[1.0, 1.0, 1.0, 1.0],
                return_list=False):
    """Build classification/regression targets for every image in a batch.

    Each per-image entry of the four input lists is forwarded to
    ``proposal_target_single``; the per-image results are either returned
    as lists or concatenated along dim 0.

    Args:
        pos_proposals_list: per-image positive proposals.
        neg_proposals_list: per-image negative proposals.
        pos_gt_bboxes_list: per-image gt boxes matched to the positives.
        pos_gt_labels_list: per-image gt labels matched to the positives.
        cfg: config object forwarded to ``proposal_target_single``.
        reg_num_classes: forwarded to ``proposal_target_single``.
        target_means: forwarded to ``proposal_target_single``.
        target_stds: forwarded to ``proposal_target_single``.
        return_list: if True, return four per-image lists instead of
            concatenated tensors.

    Returns:
        tuple: ``(labels, label_weights, bbox_targets, bbox_weights)``,
        each a list (``return_list=True``) or a tensor concatenated over
        images.
    """
    # NOTE: the number of images is taken from pos_proposals_list; the other
    # lists are indexed with the same image ids.
    img_per_gpu = len(pos_proposals_list)
    all_labels = []
    all_label_weights = []
    all_bbox_targets = []
    all_bbox_weights = []
    for img_id in range(img_per_gpu):
        # A former `debug_imgs[img_id]` lookup was dropped here: `debug_imgs`
        # was never defined (NameError when cfg.debug was set) and its result
        # was unused.
        labels, label_weights, bbox_targets, bbox_weights = \
            proposal_target_single(
                pos_proposals_list[img_id], neg_proposals_list[img_id],
                pos_gt_bboxes_list[img_id], pos_gt_labels_list[img_id],
                reg_num_classes, cfg, target_means, target_stds)
        all_labels.append(labels)
        all_label_weights.append(label_weights)
        all_bbox_targets.append(bbox_targets)
        all_bbox_weights.append(bbox_weights)

    if return_list:
        return (all_labels, all_label_weights, all_bbox_targets,
                all_bbox_weights)

    labels = torch.cat(all_labels, 0)
    label_weights = torch.cat(all_label_weights, 0)
    bbox_targets = torch.cat(all_bbox_targets, 0)
    bbox_weights = torch.cat(all_bbox_weights, 0)
    return labels, label_weights, bbox_targets, bbox_weights
|
||
|
||
def proposal_target_single(pos_proposals,
                           neg_proposals,
                           pos_gt_bboxes,
                           pos_gt_labels,
                           reg_num_classes,
                           cfg,
                           target_means=[.0, .0, .0, .0],
                           target_stds=[1.0, 1.0, 1.0, 1.0]):
    """Compute labels, label weights and box targets for one image.

    Positive samples occupy the first rows of every output, negatives the
    remaining rows. Negatives keep label 0 and zero box targets/weights.

    Args:
        pos_proposals: positive proposals; its first dimension gives the
            positive count and it is the template (dtype/device) for outputs.
        neg_proposals: negative proposals; only its first dimension is used.
        pos_gt_bboxes: gt boxes paired with ``pos_proposals``.
        pos_gt_labels: gt labels paired with ``pos_proposals``.
        reg_num_classes: when > 1, box targets/weights are expanded per
            class via ``expand_target``.
        cfg: config providing ``pos_weight`` (values <= 0 mean weight 1.0).
        target_means: forwarded to ``bbox_transform``.
        target_stds: forwarded to ``bbox_transform``.

    Returns:
        tuple: ``(labels, label_weights, bbox_targets, bbox_weights)``.
    """
    pos_count = pos_proposals.size(0)
    neg_count = neg_proposals.size(0)
    total = pos_count + neg_count

    # Every output starts as zeros created from pos_proposals.
    labels = pos_proposals.new_zeros(total, dtype=torch.long)
    label_weights = pos_proposals.new_zeros(total)
    bbox_targets = pos_proposals.new_zeros(total, 4)
    bbox_weights = pos_proposals.new_zeros(total, 4)

    if pos_count > 0:
        labels[:pos_count] = pos_gt_labels
        if cfg.pos_weight <= 0:
            pos_weight = 1.0
        else:
            pos_weight = cfg.pos_weight
        label_weights[:pos_count] = pos_weight
        bbox_targets[:pos_count, :] = bbox_transform(
            pos_proposals, pos_gt_bboxes, target_means, target_stds)
        bbox_weights[:pos_count, :] = 1

    if neg_count > 0:
        # Negatives sit at the tail and always get unit classification weight.
        label_weights[-neg_count:] = 1.0

    if reg_num_classes > 1:
        bbox_targets, bbox_weights = expand_target(
            bbox_targets, bbox_weights, labels, reg_num_classes)

    return labels, label_weights, bbox_targets, bbox_weights
|
||
|
||
def expand_target(bbox_targets, bbox_weights, labels, num_classes):
    """Spread 4-column box targets/weights into per-class column slots.

    For each row whose label is > 0, the row's 4 values are copied into
    columns ``[label * 4, label * 4 + 4)`` of a zero tensor with
    ``4 * num_classes`` columns. Rows with label <= 0 stay all-zero.

    Args:
        bbox_targets: 2-D tensor of box targets, indexed as ``[row, :]``.
        bbox_weights: 2-D tensor of box weights, indexed as ``[row, :]``.
        labels: integer label per row.
        num_classes: number of class slots in the expanded layout.

    Returns:
        tuple: ``(bbox_targets_expand, bbox_weights_expand)``.
    """
    width = 4 * num_classes
    expanded_targets = bbox_targets.new_zeros((bbox_targets.size(0), width))
    expanded_weights = bbox_weights.new_zeros((bbox_weights.size(0), width))

    foreground_rows = torch.nonzero(labels > 0).squeeze(-1)
    for row in foreground_rows:
        first_col = labels[row] * 4
        last_col = first_col + 4
        expanded_targets[row, first_col:last_col] = bbox_targets[row, :]
        expanded_weights[row, first_col:last_col] = bbox_weights[row, :]
    return expanded_targets, expanded_weights
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from .losses import ( | ||
weighted_nll_loss, weighted_cross_entropy, weighted_binary_cross_entropy, | ||
sigmoid_focal_loss, weighted_sigmoid_focal_loss, mask_cross_entropy, | ||
weighted_mask_cross_entropy, smooth_l1_loss, weighted_smoothl1, accuracy) | ||
|
||
__all__ = [ | ||
'weighted_nll_loss', 'weighted_cross_entropy', | ||
'weighted_binary_cross_entropy', 'sigmoid_focal_loss', | ||
'weighted_sigmoid_focal_loss', 'mask_cross_entropy', | ||
'weighted_mask_cross_entropy', 'smooth_l1_loss', 'weighted_smoothl1', | ||
'accuracy' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# TODO merge naive and weighted loss to one function. | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
from ..bbox_ops import bbox_transform_inv, bbox_overlaps | ||
|
||
|
||
def weighted_nll_loss(pred, label, weight, ave_factor=None):
    """Weighted negative log-likelihood loss normalised by ``ave_factor``.

    Args:
        pred: input passed to ``F.nll_loss``.
        label: target class indices.
        weight: per-sample weight multiplied onto the unreduced loss.
        ave_factor: divisor for the weighted sum; defaults to the number of
            entries with ``weight > 0``, but at least 1.

    Returns:
        Tensor of shape ``(1,)`` holding the normalised loss.
    """
    if ave_factor is None:
        num_weighted = torch.sum(weight > 0).float().item()
        ave_factor = max(num_weighted, 1.)
    # Keep the loss unreduced so each sample can be weighted individually.
    per_sample = F.nll_loss(pred, label, size_average=False, reduce=False)
    weighted_total = torch.sum(per_sample * weight)
    return weighted_total[None] / ave_factor
|
||
|
||
def weighted_cross_entropy(pred, label, weight, ave_factor=None):
    """Weighted cross-entropy loss normalised by ``ave_factor``.

    Args:
        pred: logits passed to ``F.cross_entropy``.
        label: target class indices.
        weight: per-sample weight multiplied onto the unreduced loss.
        ave_factor: divisor for the weighted sum; defaults to the number of
            entries with ``weight > 0``, but at least 1.

    Returns:
        Tensor of shape ``(1,)`` holding the normalised loss.
    """
    if ave_factor is None:
        num_weighted = torch.sum(weight > 0).float().item()
        ave_factor = max(num_weighted, 1.)
    # Unreduced loss: one value per sample, weighted below.
    per_sample = F.cross_entropy(pred, label, size_average=False, reduce=False)
    weighted_total = torch.sum(per_sample * weight)
    return weighted_total[None] / ave_factor
|
||
|
||
def weighted_binary_cross_entropy(pred, label, weight, ave_factor=None):
    """Weighted sigmoid binary cross-entropy normalised by ``ave_factor``.

    Args:
        pred: logits.
        label: binary targets; cast to float before use.
        weight: element-wise weights; cast to float before use.
        ave_factor: divisor for the summed loss; defaults to the number of
            entries with ``weight > 0``, but at least 1.

    Returns:
        Tensor of shape ``(1,)`` holding the normalised loss.
    """
    if ave_factor is None:
        num_weighted = torch.sum(weight > 0).float().item()
        ave_factor = max(num_weighted, 1.)
    # size_average=False -> the weighted element losses are summed.
    summed = F.binary_cross_entropy_with_logits(
        pred, label.float(), weight.float(), size_average=False)
    return summed[None] / ave_factor
|
||
|
||
def sigmoid_focal_loss(pred,
                       target,
                       weight,
                       gamma=2.0,
                       alpha=0.25,
                       size_average=True):
    """Sigmoid focal loss built on binary cross-entropy with logits.

    The per-element BCE weight is
    ``(alpha * t + (1 - alpha) * (1 - t)) * weight * pt ** gamma`` where
    ``pt = (1 - sigmoid(pred)) * t + sigmoid(pred) * (1 - t)``.

    Args:
        pred: logits.
        target: binary targets, used arithmetically (same shape as pred).
        weight: element-wise weight multiplied into the focal weight.
        gamma: exponent of the modulating factor.
        alpha: balance factor between positive and negative targets.
        size_average: forwarded to ``F.binary_cross_entropy_with_logits``.

    Returns:
        The reduced loss returned by ``F.binary_cross_entropy_with_logits``.
    """
    prob = pred.sigmoid()
    # pt is the probability mass assigned to the *wrong* outcome.
    pt = (1 - prob) * target + prob * (1 - target)
    alpha_weight = alpha * target + (1 - alpha) * (1 - target)
    focal_weight = alpha_weight * weight
    focal_weight = focal_weight * pt.pow(gamma)
    return F.binary_cross_entropy_with_logits(
        pred, target, focal_weight, size_average=size_average)
|
||
|
||
def weighted_sigmoid_focal_loss(pred,
                                target,
                                weight,
                                gamma=2.0,
                                alpha=0.25,
                                ave_factor=None,
                                num_classes=80):
    """Summed sigmoid focal loss normalised by ``ave_factor``.

    Args:
        pred: logits forwarded to ``sigmoid_focal_loss``.
        target: targets forwarded to ``sigmoid_focal_loss``.
        weight: element-wise weights forwarded to ``sigmoid_focal_loss``.
        gamma: focal exponent.
        alpha: focal balance factor.
        ave_factor: divisor for the summed loss; defaults to
            ``count(weight > 0) / num_classes + 1e-6``.
        num_classes: only used to derive the default ``ave_factor``.

    Returns:
        Tensor of shape ``(1,)`` holding the normalised loss.
    """
    if ave_factor is None:
        positive_entries = torch.sum(weight > 0).float().item()
        # The epsilon keeps the divisor non-zero when no weight is positive.
        ave_factor = positive_entries / num_classes + 1e-6
    summed = sigmoid_focal_loss(
        pred, target, weight, gamma=gamma, alpha=alpha, size_average=False)
    return summed[None] / ave_factor
|
||
|
||
def mask_cross_entropy(pred, target, label):
    """Mean binary cross-entropy on the mask channel chosen by ``label``.

    Args:
        pred: mask logits; indexed as ``pred[i, label[i]]`` for every row i.
        target: mask targets matching the selected slices.
        label: per-row channel index used to select the slice of ``pred``.

    Returns:
        Tensor of shape ``(1,)`` holding the averaged loss.
    """
    roi_inds = torch.arange(
        0, pred.size()[0], dtype=torch.long, device=pred.device)
    # Pick, for every RoI, only the channel that belongs to its label.
    selected = pred[roi_inds, label].squeeze(1)
    loss = F.binary_cross_entropy_with_logits(
        selected, target, size_average=True)
    return loss[None]
|
||
|
||
def weighted_mask_cross_entropy(pred, target, weight, label):
    """Weighted binary cross-entropy on the mask channel chosen by ``label``.

    The weighted element losses are summed and divided by the number of
    entries with ``weight > 0`` (plus a 1e-6 epsilon).

    Args:
        pred: mask logits; indexed as ``pred[i, label[i]]`` for every row i.
        target: mask targets matching the selected slices.
        weight: element-wise weights for the selected slices.
        label: per-row channel index used to select the slice of ``pred``.

    Returns:
        Tensor of shape ``(1,)`` holding the normalised loss.

    Raises:
        AssertionError: if no entry of ``weight`` is positive.
    """
    num_rois = pred.size()[0]
    num_samples = torch.sum(weight > 0).float().item() + 1e-6
    assert num_samples >= 1
    # Build the row index on pred's own device (previously hard-coded to
    # CUDA, which broke CPU execution and differed from mask_cross_entropy).
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight, size_average=False)[None] / num_samples
|
||
|
||
def smooth_l1_loss(pred, target, beta=1.0, size_average=True, reduce=True):
    """Smooth L1 loss with transition point ``beta``.

    Element-wise: ``0.5 * d * d / beta`` where ``d = |pred - target| < beta``,
    otherwise ``d - 0.5 * beta``.

    Args:
        pred: predictions.
        target: targets; must have the same size as ``pred`` and be non-empty.
        beta: positive threshold between the quadratic and linear regimes.
        size_average: if True, divide the element losses by ``pred.numel()``.
        reduce: if True, return the sum of the element losses; otherwise the
            element-wise tensor.

    Returns:
        Scalar tensor when ``reduce`` is True, else a tensor shaped like
        ``pred``.
    """
    assert beta > 0
    assert pred.size() == target.size() and target.numel() > 0
    abs_err = torch.abs(pred - target)
    quadratic = 0.5 * abs_err * abs_err / beta
    linear = abs_err - 0.5 * beta
    loss = torch.where(abs_err < beta, quadratic, linear)
    if size_average:
        loss /= pred.numel()
    return loss.sum() if reduce else loss
|
||
|
||
def weighted_smoothl1(pred, target, weight, beta=1.0, ave_factor=None):
    """Weighted smooth L1 loss normalised by ``ave_factor``.

    Args:
        pred: predictions forwarded to ``smooth_l1_loss``.
        target: targets forwarded to ``smooth_l1_loss``.
        weight: element-wise weights multiplied onto the unreduced loss.
        beta: smooth L1 transition point.
        ave_factor: divisor for the weighted sum; defaults to
            ``count(weight > 0) / 4 + 1e-6``.

    Returns:
        Tensor of shape ``(1,)`` holding the normalised loss.
    """
    if ave_factor is None:
        positive_entries = torch.sum(weight > 0).float().item()
        # The epsilon keeps the divisor non-zero when no weight is positive.
        ave_factor = positive_entries / 4 + 1e-6
    elementwise = smooth_l1_loss(
        pred, target, beta, size_average=False, reduce=False)
    weighted_total = torch.sum(elementwise * weight)
    return weighted_total[None] / ave_factor
|
||
|
||
def accuracy(pred, target, topk=1):
    """Compute top-k accuracy (in percent) of ``pred`` against ``target``.

    Args:
        pred: score tensor; ``topk`` is taken along dim 1 and ``pred.size(0)``
            is used as the number of samples.
        target: ground-truth class index per sample.
        topk: an int, or an iterable of ints, giving the k values to report.

    Returns:
        A tensor of shape ``(1,)`` when ``topk`` is an int; otherwise a list
        with one such tensor per requested k, in the given order.
    """
    # Previously return_single was only bound in the int branch, so passing
    # a tuple/list of k values raised UnboundLocalError at the return.
    return_single = isinstance(topk, int)
    if return_single:
        topk = (topk, )

    maxk = max(topk)
    _, pred_label = pred.topk(maxk, 1, True, True)
    pred_label = pred_label.t()
    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))

    res = []
    for k in topk:
        # reshape (not view): the slice of the transposed comparison result
        # is not guaranteed to be contiguous.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / pred.size(0)))
    return res[0] if return_single else res
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import torch | ||
import numpy as np | ||
|
||
from .segms import polys_to_mask_wrt_box | ||
|
||
|
||
def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_polys_list,
                img_meta, cfg):
    """Compute mask targets for all images and concatenate them on dim 0.

    Args:
        pos_proposals_list: per-image positive proposals.
        pos_assigned_gt_inds_list: per-image gt indices of the positives.
        gt_polys_list: per-image gt polygons.
        img_meta: image meta info; the same object is passed for every image.
        cfg: config object; the same object is passed for every image.

    Returns:
        Tensor with the per-image results of ``mask_target_single``
        concatenated along dim 0.
    """
    per_image_targets = tuple(
        mask_target_single(proposals, gt_inds, polys, img_meta, cfg)
        for proposals, gt_inds, polys in zip(
            pos_proposals_list, pos_assigned_gt_inds_list, gt_polys_list))
    return torch.cat(per_image_targets, dim=0)
|
||
|
||
def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_polys,
                       img_meta, cfg):
    """Rasterise the gt polygons of each positive proposal into a mask.

    Args:
        pos_proposals: positive proposals; row i is divided by the image
            scale factor before being used as the crop box.
        pos_assigned_gt_inds: index into ``gt_polys`` for every proposal.
        gt_polys: gt polygons, indexed by the assigned gt index.
        img_meta: mapping whose ``'scale_factor'`` entry is indexed with
            ``[0]`` and converted to numpy.
        cfg: config providing ``mask_size``.

    Returns:
        Tensor of shape ``(num_pos, mask_size, mask_size)`` created from
        ``pos_proposals``; entries are 1 where the rasterised mask is > 0.
    """
    size = cfg.mask_size
    num_pos = pos_proposals.size(0)
    mask_targets = pos_proposals.new_zeros((num_pos, size, size))
    if not num_pos > 0:
        return mask_targets

    # Polygon rasterisation runs on numpy arrays, so move inputs to the CPU.
    proposals_np = pos_proposals.cpu().numpy()
    gt_inds_np = pos_assigned_gt_inds.cpu().numpy()
    scale_factor = img_meta['scale_factor'][0].cpu().numpy()
    for i in range(num_pos):
        bbox = proposals_np[i, :] / scale_factor
        polys = gt_polys[gt_inds_np[i]]
        raw_mask = polys_to_mask_wrt_box(polys, bbox, size)
        binary_mask = np.array(raw_mask > 0, dtype=np.float32)
        mask_targets[i, ...] = torch.from_numpy(binary_mask).to(
            mask_targets.device)
    return mask_targets
Oops, something went wrong.