Commit

… into pytorch-1.0
yhcao6 committed Jan 21, 2019
2 parents f3a939f + e9cb6fa commit 69f40c9
Showing 8 changed files with 40 additions and 58 deletions.
6 changes: 3 additions & 3 deletions INSTALL.md
@@ -4,13 +4,13 @@

- Linux (tested on Ubuntu 16.04 and CentOS 7.2)
- Python 3.4+
- PyTorch 0.4.1
- PyTorch 1.0
- Cython
- [mmcv](https://github.com/open-mmlab/mmcv)
- [mmcv](https://github.com/open-mmlab/mmcv) >= 0.2.2

### Install mmdetection

a. Install PyTorch 0.4.1 and torchvision following the [official instructions](https://pytorch.org/).
a. Install PyTorch 1.0 and torchvision following the [official instructions](https://pytorch.org/).

b. Clone the mmdetection repository.

28 changes: 18 additions & 10 deletions mmdet/core/loss/losses.py
@@ -34,13 +34,21 @@ def sigmoid_focal_loss(pred,
weight,
gamma=2.0,
alpha=0.25,
reduction='elementwise_mean'):
reduction='mean'):
pred_sigmoid = pred.sigmoid()
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma)
return F.binary_cross_entropy_with_logits(
pred, target, weight, reduction=reduction)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * weight
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
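Note: PyTorch 1.0 renames the 'elementwise_mean' reduction to 'mean', and the focal weighting is now applied explicitly to a per-element ('none'-reduced) BCE before reducing by hand. A self-contained sketch of the same computation, using plain string comparison instead of the private F._Reduction.get_enum helper:

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss_sketch(pred, target, weight, gamma=2.0, alpha=0.25,
                              reduction='mean'):
    # same math as the updated function above, restated for illustration
    pred_sigmoid = pred.sigmoid()
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
    weight = weight * pt.pow(gamma)
    # per-element BCE, weighted explicitly, then reduced by hand
    loss = F.binary_cross_entropy_with_logits(pred, target,
                                              reduction='none') * weight
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return loss.mean()
    return loss.sum()

pred = torch.randn(8, 4)
target = torch.randint(0, 2, (8, 4)).float()
print(sigmoid_focal_loss_sketch(pred, target, weight=torch.ones(8, 4)).item())
```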


def weighted_sigmoid_focal_loss(pred,
@@ -62,22 +70,22 @@ def mask_cross_entropy(pred, target, label):
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, reduction='elementwise_mean')[None]
pred_slice, target, reduction='mean')[None]


def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'):
def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
assert beta > 0
assert pred.size() == target.size() and target.numel() > 0
diff = torch.abs(pred - target)
loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
diff - 0.5 * beta)
reduction = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction == 0:
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction == 1:
elif reduction_enum == 1:
return loss.sum() / pred.numel()
elif reduction == 2:
elif reduction_enum == 2:
return loss.sum()
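Note: the local variable is renamed to reduction_enum so it no longer shadows the reduction string, matching the updated comment (none: 0, mean: 1, sum: 2); the loss math itself is unchanged. A small sketch of the same smooth L1 with a worked value:

```python
import torch

def smooth_l1_sketch(pred, target, beta=1.0, reduction='mean'):
    # same element-wise loss as above; reduction resolved by string comparison
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return loss.sum() / pred.numel()  # element-wise mean, as in the diff
    return loss.sum()

pred = torch.tensor([0.0, 2.0])
target = torch.tensor([0.5, 0.0])
# |diff| = [0.5, 2.0] -> loss = [0.5 * 0.5^2, 2.0 - 0.5] = [0.125, 1.5]; mean = 0.8125
print(smooth_l1_sketch(pred, target))  # tensor(0.8125)
```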


12 changes: 6 additions & 6 deletions mmdet/ops/roi_align/functions/roi_align.py
@@ -1,4 +1,4 @@
from torch.autograd import Function, Variable
from torch.autograd import Function

from .. import roi_align_cuda

@@ -49,11 +49,11 @@ def backward(ctx, grad_output):

grad_input = grad_rois = None
if ctx.needs_input_grad[0]:
grad_input = Variable(
rois.new(batch_size, num_channels, data_height, data_width)
.zero_())
roi_align_cuda.backward(grad_output, rois, out_h, out_w,
spatial_scale, sample_num, grad_input)
grad_input = rois.new_zeros(batch_size, num_channels, data_height,
data_width)
roi_align_cuda.backward(grad_output.contiguous(), rois, out_h,
out_w, spatial_scale, sample_num,
grad_input)

return grad_input, grad_rois, None, None, None
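Note: with the tensor/Variable merge in PyTorch 0.4+, wrapping in Variable is no longer needed, and rois.new_zeros(...) replaces the new(...).zero_() idiom; grad_output is also made contiguous before being handed to the raw CUDA kernel, since autograd can deliver non-contiguous gradients. A small illustration, independent of mmdet:

```python
import torch

base = torch.empty(0)                    # any existing tensor acts as the "factory"
a = base.new(2, 3).zero_()               # old idiom: allocate, then zero in place
b = base.new_zeros(2, 3)                 # new idiom: same dtype/device, zero-filled
assert torch.equal(a, b)

# gradients coming out of autograd may be non-contiguous, e.g. after a transpose
g = torch.randn(4, 5).t()
print(g.is_contiguous())                 # False
print(g.contiguous().is_contiguous())    # True: safe to pass to a raw kernel
```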

2 changes: 1 addition & 1 deletion mmdet/ops/roi_align/src/roi_align_cuda.cpp
@@ -1,4 +1,4 @@
#include <torch/torch.h>
#include <torch/extension.h>

#include <cmath>
#include <vector>
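Note: <torch/extension.h> is the umbrella header for PyTorch 1.0 C++ extensions (it pulls in ATen and pybind11), replacing the <torch/torch.h> include used by 0.4-era extensions. As a sketch only, such an extension could also be JIT-compiled from Python with torch.utils.cpp_extension; this usage is an assumption for illustration, not the repository's actual build setup:

```python
# sketch: JIT-build the RoIAlign extension from the sources touched in this commit
from torch.utils.cpp_extension import load

roi_align_cuda = load(
    name='roi_align_cuda',
    sources=['mmdet/ops/roi_align/src/roi_align_cuda.cpp',
             'mmdet/ops/roi_align/src/roi_align_kernel.cu'],
    verbose=True)
```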
19 changes: 3 additions & 16 deletions mmdet/ops/roi_align/src/roi_align_kernel.cu
@@ -1,8 +1,6 @@
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848)

#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
@@ -144,12 +142,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
sample_num, channels, height, width, pooled_height,
pooled_width, top_data);
}));
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}

THCudaCheck(cudaGetLastError());
return 1;
}

@@ -280,8 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_height * pooled_width * channels;

// TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.type(), "ROIAlignLaucherBackward", ([&] {
const scalar_t *top_diff = top_grad.data<scalar_t>();
const scalar_t *rois_data = rois.data<scalar_t>();
@@ -297,11 +289,6 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
channels, height, width, pooled_height, pooled_width,
bottom_diff);
}));
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}

THCudaCheck(cudaGetLastError());
return 1;
}
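Note: the manual cudaGetLastError / fprintf / exit(-1) pattern is replaced by THCudaCheck, which reports launch failures through PyTorch's normal error handling instead of killing the process, and the backward launcher now dispatches with AT_DISPATCH_FLOATING_TYPES_AND_HALF, so half-precision gradients are covered as well. A hypothetical usage sketch; the import path and constructor argument order of the Python wrapper are assumptions, not taken from this diff:

```python
# hypothetical usage of the op whose kernels change above; the wrapper's import
# path and argument order (out_size, spatial_scale, sample_num) are assumptions
import torch
from mmdet.ops.roi_align import RoIAlign   # assumed import path

roi_align = RoIAlign(7, 1.0 / 16, 2)
feats = torch.randn(1, 256, 32, 32, device='cuda', requires_grad=True)
rois = torch.tensor([[0., 16., 16., 128., 128.]], device='cuda')
roi_align(feats, rois).sum().backward()    # exercises the backward launcher above
print(feats.grad.shape)
```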
11 changes: 5 additions & 6 deletions mmdet/ops/roi_pool/functions/roi_pool.py
@@ -24,9 +24,8 @@ def forward(ctx, features, rois, out_size, spatial_scale):
num_channels = features.size(1)
num_rois = rois.size(0)
out_size = (num_rois, num_channels, out_h, out_w)
output = features.new_zeros(*out_size)

argmax = features.new_zeros(*out_size, dtype=torch.int)
output = features.new_zeros(out_size)
argmax = features.new_zeros(out_size, dtype=torch.int)
roi_pool_cuda.forward(features, rois, out_h, out_w, spatial_scale,
output, argmax)
ctx.spatial_scale = spatial_scale
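Note: new_zeros accepts the size tuple directly, so out_size no longer needs to be unpacked with *. A tiny sketch:

```python
import torch

features = torch.randn(2, 3, 8, 8)
out_size = (5, 3, 7, 7)                                  # (num_rois, C, out_h, out_w)
argmax = features.new_zeros(out_size, dtype=torch.int)   # tuple accepted directly
print(argmax.shape, argmax.dtype)   # torch.Size([5, 3, 7, 7]) torch.int32
```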
@@ -46,9 +45,9 @@ def backward(ctx, grad_output):

grad_input = grad_rois = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.new(feature_size).zero_()
roi_pool_cuda.backward(grad_output, rois, argmax, spatial_scale,
grad_input)
grad_input = grad_output.new_zeros(feature_size)
roi_pool_cuda.backward(grad_output.contiguous(), rois, argmax,
spatial_scale, grad_input)

return grad_input, grad_rois, None, None

2 changes: 1 addition & 1 deletion mmdet/ops/roi_pool/src/roi_pool_cuda.cpp
@@ -1,4 +1,4 @@
#include <torch/torch.h>
#include <torch/extension.h>

#include <cmath>
#include <vector>
18 changes: 3 additions & 15 deletions mmdet/ops/roi_pool/src/roi_pool_kernel.cu
@@ -1,8 +1,6 @@
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848)

#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
@@ -100,11 +98,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
channels, height, width, pooled_h, pooled_w, top_data,
argmax_data);
}));
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
THCudaCheck(cudaGetLastError());
return 1;
}

@@ -139,8 +133,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
const int pooled_w, at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_h * pooled_w * channels;

// TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.type(), "ROIPoolLaucherBackward", ([&] {
const scalar_t *top_diff = top_grad.data<scalar_t>();
const scalar_t *rois_data = rois.data<scalar_t>();
@@ -158,11 +151,6 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
scalar_t(spatial_scale), channels, height, width, pooled_h,
pooled_w, bottom_diff);
}));
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}

THCudaCheck(cudaGetLastError());
return 1;
}
