#!/usr/bin/env python from __future__ import absolute_import from __future__ import print_function from __future__ import division import math import torch from torch import nn from torch.autograd import Function import torch.nn.functional as F from torch.nn.modules.utils import _pair from torch.autograd.function import once_differentiable import _ext as _backend import time import numpy as np class _DCNv2(Function): @staticmethod def forward(ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups): ctx.stride = _pair(stride) ctx.padding = _pair(padding) ctx.dilation = _pair(dilation) ctx.kernel_size = _pair(weight.shape[2:4]) ctx.deformable_groups = deformable_groups output = _backend.dcn_v2_forward(input, weight, bias, offset, mask, ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1], ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1], ctx.deformable_groups) ctx.save_for_backward(input, offset, mask, weight, bias) return output @staticmethod @once_differentiable def backward(ctx, grad_output): input, offset, mask, weight, bias = ctx.saved_tensors grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \ _backend.dcn_v2_backward(input, weight, bias, offset, mask, grad_output, ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1], ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1], ctx.deformable_groups) return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\ None, None, None, None, dcn_v2_conv = _DCNv2.apply class DCNv2(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, deformable_groups=1): super(DCNv2, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) self.padding = _pair(padding) self.dilation = _pair(dilation) self.deformable_groups = deformable_groups self.weight = nn.Parameter(torch.Tensor( out_channels, in_channels, *self.kernel_size)) self.bias = nn.Parameter(torch.Tensor(out_channels)) self.reset_parameters() def reset_parameters(self): nn.init.kaiming_uniform_(self.weight, mode="fan_in", nonlinearity="relu") print("Set DCN Parameter.") self.bias.data.zero_() def forward(self, input, offset, mask): assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ offset.shape[1] assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ mask.shape[1] return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, self.deformable_groups) class DCNv1(DCNv2): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, deformable_groups=1): super(DCNv1, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups) channels_ = self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1] self.conv_offset_mask = nn.Sequential( nn.Conv2d(in_channels,channels_,kernel_size=3,padding=1,bias=True) ) def forward(self, x): o = x.detach() offset = self.conv_offset_mask(o) o1, o2 = torch.chunk(offset, 2, dim=1) mask = torch.ones_like(o1) return dcn_v2_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, self.deformable_groups)