Merge pull request daniilidis-group#2 from daniilidis-group/devel
Add Volta support
nkolot authored Jul 14, 2018
2 parents ffc2762 + 3ece420 commit 661e4cc
Showing 4 changed files with 147 additions and 166 deletions.
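In brief (our reading of the diff below, not part of the commit message): the per-pixel spin lock previously used during z-buffering (an atomicCAS loop over a lock tensor followed by atomicExch writes) is replaced with a lock-free two-kernel pass. Kernel 1 precomputes each face's inverse coordinate matrix into a new faces_inv buffer; kernel 2 then has each pixel scan all faces serially, keep the nearest depth in local variables, and write its z-buffer entry to global memory once. Intra-warp spin locks are fragile across GPU generations (they can deadlock or race depending on how the scheduler interleaves divergent threads), and the locked path reportedly misbehaved under Volta's independent thread scheduling, hence the rewrite.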
8 changes: 4 additions & 4 deletions neural_renderer/cuda/rasterize_cuda.cpp
@@ -10,7 +10,7 @@ std::vector<at::Tensor> forward_face_index_map_cuda(
at::Tensor weight_map,
at::Tensor depth_map,
at::Tensor face_inv_map,
at::Tensor lock,
at::Tensor faces_inv,
int image_size,
float near,
float far,
@@ -73,7 +73,7 @@ std::vector<at::Tensor> forward_face_index_map(
at::Tensor weight_map,
at::Tensor depth_map,
at::Tensor face_inv_map,
at::Tensor lock,
at::Tensor faces_inv,
int image_size,
float near,
float far,
@@ -86,10 +86,10 @@
CHECK_INPUT(weight_map);
CHECK_INPUT(depth_map);
CHECK_INPUT(face_inv_map);
CHECK_INPUT(lock);
CHECK_INPUT(faces_inv);

return forward_face_index_map_cuda(faces, face_index_map, weight_map,
depth_map, face_inv_map, lock,
depth_map, face_inv_map, faces_inv,
image_size, near, far,
return_rgb, return_alpha, return_depth);
}
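CHECK_INPUT is defined outside this hunk. A sketch of the conventional PyTorch-extension guards it usually expands to (the exact spelling varies across PyTorch versions; this uses today's TORCH_CHECK, so treat it as an assumption rather than this file's literal text):

    #include <torch/extension.h>

    // conventional guards for CUDA extensions: tensors must live on the GPU
    // and be contiguous so the raw-pointer indexing in the kernels is valid
    #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
    #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
    #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)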
292 changes: 140 additions & 152 deletions neural_renderer/cuda/rasterize_cuda_kernel.cu
@@ -19,176 +19,151 @@ static __inline__ __device__ double atomicAdd(double* address, double val) {
}
#endif

// implementation of atomicExch for double input
// adapted from https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
__device__ double atomicExch(double* address, double val) {
unsigned long long int* address_as_ull =
(unsigned long long int*)address;
unsigned long long int old = *address_as_ull, assumed;

do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val));
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
} while (assumed != old);
return __longlong_as_double(old);
}
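The same compare-and-swap loop generalizes to other double-precision atomics that CUDA lacks natively. As an illustration only (not part of this commit; the name atomicMinDouble is ours), an atomicMin for double:

    // minimal sketch: atomically keep the smaller of *address and val,
    // using the same CAS-on-reinterpreted-bits pattern as above
    __device__ double atomicMinDouble(double* address, double val) {
        unsigned long long int* address_as_ull = (unsigned long long int*)address;
        unsigned long long int old = *address_as_ull, assumed;
        while (__longlong_as_double(old) > val) {
            assumed = old;
            old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
            if (old == assumed) break;  // our swap won; done
        }
        return __longlong_as_double(old);
    }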

namespace{
template <typename scalar_t>
__global__ void forward_face_index_map_cuda_kernel(
__global__ void forward_face_index_map_cuda_kernel_1(
const scalar_t* __restrict__ faces,
int32_t* face_index_map,
scalar_t* weight_map,
scalar_t* depth_map,
scalar_t* face_inv_map,
int32_t* lock,
size_t batch_size,
size_t num_faces,
int image_size,
scalar_t near,
scalar_t far,
int return_rgb,
int return_alpha,
int return_depth) {
scalar_t* faces_inv,
int batch_size,
int num_faces,
int image_size) {
/* batch number, face number, image size, face[v012][RGB] */
const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= num_faces * batch_size) {
if (i >= batch_size * num_faces) {
return;
}
const int bn = i / num_faces;
const int fn = i % num_faces;
const int is = image_size;
const scalar_t* face = &faces[i * 9];

scalar_t* face_inv_g = &faces_inv[i * 9];

/* return if backside */
if ((face[7] - face[1]) * (face[3] - face[0]) < (face[4] - face[1]) * (face[6] - face[0]))
return;

/* pi[0], pi[1], pi[2] = leftmost, middle, rightmost points */
int pi[3];
if (face[0] < face[3]) {
if (face[6] < face[0])
pi[0] = 2;
else pi[0] = 0;
if (face[3] < face[6])
pi[2] = 2;
else pi[2] = 1;
} else {
if (face[6] < face[3])
pi[0] = 2;
else pi[0] = 1;
if (face[0] < face[6])
pi[2] = 2;
else pi[2] = 0;
}
for (int k = 0; k < 3; k++)
if (pi[0] != k && pi[2] != k)
pi[1] = k;

/* p[num][xyz]: x, y is normalized from [-1, 1] to [0, is - 1]. */
scalar_t p[3][3];

/* p[num][xy]: x, y is normalized from [-1, 1] to [0, is - 1]. */
float p[3][2];
for (int num = 0; num < 3; num++) {
for (int dim = 0; dim < 3; dim++) {
if (dim != 2) {
p[num][dim] = 0.5 * (face[3 * pi[num] + dim] * is + is - 1);
} else {
p[num][dim] = face[3 * pi[num] + dim];
}
for (int dim = 0; dim < 2; dim++) {
p[num][dim] = 0.5 * (face[3 * num + dim] * is + is - 1);
}
}
if (p[0][0] == p[2][0])
return; // line, not triangle


/* compute face_inv */
scalar_t face_inv[9] = {
float face_inv[9] = {
p[1][1] - p[2][1], p[2][0] - p[1][0], p[1][0] * p[2][1] - p[2][0] * p[1][1],
p[2][1] - p[0][1], p[0][0] - p[2][0], p[2][0] * p[0][1] - p[0][0] * p[2][1],
p[0][1] - p[1][1], p[1][0] - p[0][0], p[0][0] * p[1][1] - p[1][0] * p[0][1]};
scalar_t face_inv_denominator = (
float face_inv_denominator = (
p[2][0] * (p[0][1] - p[1][1]) +
p[0][0] * (p[1][1] - p[2][1]) +
p[1][0] * (p[2][1] - p[0][1]));
for (int k = 0; k < 9; k++)
for (int k = 0; k < 9; k++) {
face_inv[k] /= face_inv_denominator;
}
/* set to global memory */
for (int k = 0; k < 9; k++) {
face_inv_g[k] = face_inv[k];
}
}
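Kernel 1 writes nine scalars per face (face_inv_g = &faces_inv[i * 9]), so the caller must supply faces_inv with the same per-face layout as faces. A minimal host-side sketch (hypothetical; the actual allocation happens in the Python wrapper, outside this diff):

    // assumed allocation: faces_inv mirrors faces' [batch_size, num_faces, 9]
    // layout (or its [.., 3, 3] equivalent) and is filled by kernel 1
    at::Tensor faces_inv = at::zeros_like(faces);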

template <typename scalar_t>
__global__ void forward_face_index_map_cuda_kernel_2(
const scalar_t* faces,
scalar_t* faces_inv,
int32_t* face_index_map,
scalar_t* weight_map,
scalar_t* depth_map,
scalar_t* face_inv_map,
int batch_size,
int num_faces,
int image_size,
scalar_t near,
scalar_t far,
int return_rgb,
int return_alpha,
int return_depth) {

const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= batch_size * image_size * image_size) {
return;
}
const int is = image_size;
const int nf = num_faces;
const int bn = i / (is * is);
const int pn = i % (is * is);
const int yi = pn / is;
const int xi = pn % is;
const scalar_t yp = (2. * yi + 1 - is) / is;
const scalar_t xp = (2. * xi + 1 - is) / is;

/* from left to right */
// const int xi_min = min(max(ceil(p[0][0]), 0.), is - 1.);
// const int xi_max = max(min(p[2][0], is - 1.), 0.);
const int xi_min = max(ceil(p[0][0]), 0.);
const int xi_max = min(p[2][0], is - 1.);
for (int xi = xi_min; xi <= xi_max; xi++) {
/* compute yi_min and yi_max */
scalar_t yi1, yi2;
if (xi <= p[1][0]) {
if (p[1][0] - p[0][0] != 0) {
yi1 = (p[1][1] - p[0][1]) / (p[1][0] - p[0][0]) * (xi - p[0][0]) + p[0][1];
}
else {
yi1 = p[1][1];
}
}
else {
if (p[2][0] - p[1][0] != 0) {
yi1 = (p[2][1] - p[1][1]) / (p[2][0] - p[1][0]) * (xi - p[1][0]) + p[1][1];
}
else {
yi1 = p[1][1];
}
}
yi2 = (p[2][1] - p[0][1]) / (p[2][0] - p[0][0]) * (xi - p[0][0]) + p[0][1];
const scalar_t* face = &faces[bn * nf * 9] - 9;
scalar_t* face_inv = &faces_inv[bn * nf * 9] - 9;
scalar_t depth_min = far;
int face_index_min = -1;
scalar_t weight_min[3];
scalar_t face_inv_min[9];
for (int fn = 0; fn < nf; fn++) {
/* go to next face */
face += 9;
face_inv += 9;

/* from up to bottom */
int yi_min = max(0., ceil(min(yi1, yi2)));
int yi_max = min(max(yi1, yi2), is - 1.);
for (int yi = yi_min; yi <= yi_max; yi++) {
/* index in output buffers */
int index = bn * is * is + yi * is + xi;
// remove it after debugging
if (index > batch_size * is * is -1)
continue;
/* return if backside */
if ((face[7] - face[1]) * (face[3] - face[0]) < (face[4] - face[1]) * (face[6] - face[0]))
continue;

/* compute w = face_inv * p */
scalar_t w[3];
for (int k = 0; k < 3; k++)
w[k] = face_inv[3 * k + 0] * xi + face_inv[3 * k + 1] * yi + face_inv[3 * k + 2];
/* check [py, px] is inside the face */
if (((yp - face[1]) * (face[3] - face[0]) < (xp - face[0]) * (face[4] - face[1])) ||
((yp - face[4]) * (face[6] - face[3]) < (xp - face[3]) * (face[7] - face[4])) ||
((yp - face[7]) * (face[0] - face[6]) < (xp - face[6]) * (face[1] - face[7])))
continue;

/* sum(w) -> 1, 0 < w < 1 */
scalar_t w_sum = 0;
for (int k = 0; k < 3; k++) {
w[k] = min(max(w[k], 0.), 1.);
w_sum += w[k];
}
for (int k = 0; k < 3; k++)
w[k] /= w_sum;
/* compute w = face_inv * p */
scalar_t w[3];
w[0] = face_inv[3 * 0 + 0] * xi + face_inv[3 * 0 + 1] * yi + face_inv[3 * 0 + 2];
w[1] = face_inv[3 * 1 + 0] * xi + face_inv[3 * 1 + 1] * yi + face_inv[3 * 1 + 2];
w[2] = face_inv[3 * 2 + 0] * xi + face_inv[3 * 2 + 1] * yi + face_inv[3 * 2 + 2];

/* compute 1 / zp = sum(w / z) */
const scalar_t zp = 1. / (w[0] / p[0][2] + w[1] / p[1][2] + w[2] / p[2][2]);
// index = 2;
if (zp <= near || far <= zp)
continue;
/* sum(w) -> 1, 0 < w < 1 */
scalar_t w_sum = 0;
for (int k = 0; k < 3; k++) {
w[k] = min(max(w[k], 0.), 1.);
w_sum += w[k];
}
for (int k = 0; k < 3; k++) {
w[k] /= w_sum;
}
/* compute 1 / zp = sum(w / z) */
const scalar_t zp = 1. / (w[0] / face[2] + w[1] / face[5] + w[2] / face[8]);
if (zp <= near || far <= zp) {
continue;
}

/* lock and update */
bool locked = false;
do {
if ( locked = atomicCAS(&lock[index], 0, 1) == 0) {
if (zp <= depth_map[index]) {
depth_map[index] = zp;
face_index_map[index] = fn;
for (int k = 0; k < 3; k++)
atomicExch(&weight_map[3 * index + pi[k]], w[k]);
if (return_depth) {
for (int k = 0; k < 3; k++)
for (int l = 0; l < 3; l++)
atomicExch(
&face_inv_map[9 * index + 3 * pi[l] + k], face_inv[3 * l + k]);
}
}
atomicExch(&lock[index], 0);
/* check z-buffer */
if (zp < depth_min) {
depth_min = zp;
face_index_min = fn;
for (int k = 0; k < 3; k++) {
weight_min[k] = w[k];
}
if (return_depth) {
for (int k = 0; k < 9; k++) {
face_inv_min[k] = face_inv[k];
}
} while (!locked);
}
}
}

/* set to global memory */
if (0 <= face_index_min) {
depth_map[i] = depth_min;
face_index_map[i] = face_index_min;
for (int k = 0; k < 3; k++) {
weight_map[3 * i + k] = weight_min[k];
}
if (return_depth) {
for (int k = 0; k < 9; k++) {
face_inv_map[9 * i + k] = face_inv_min[k];
}
}
}
}
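The zp line above implements perspective-correct depth interpolation: with barycentric weights w[k] and per-vertex depths z[k], 1/zp = w[0]/z[0] + w[1]/z[1] + w[2]/z[2]. As a self-contained illustration (not in the commit):

    // perspective-correct depth from barycentric weights and vertex depths
    template <typename scalar_t>
    __device__ scalar_t interpolate_depth(const scalar_t w[3], const scalar_t z[3]) {
        return scalar_t(1) / (w[0] / z[0] + w[1] / z[1] + w[2] / z[2]);
    }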
@@ -623,7 +598,7 @@ std::vector<at::Tensor> forward_face_index_map_cuda(
at::Tensor weight_map,
at::Tensor depth_map,
at::Tensor face_inv_map,
at::Tensor lock,
at::Tensor faces_inv,
int image_size,
float near,
float far,
@@ -634,34 +609,47 @@
const auto batch_size = faces.size(0);
const auto num_faces = faces.size(1);
const int threads = 512;
const dim3 blocks ((batch_size * num_faces - 1) / threads +1);
const dim3 blocks_1 ((batch_size * num_faces - 1) / threads +1);

AT_DISPATCH_FLOATING_TYPES(faces.type(), "forward_face_index_map_cuda_1", ([&] {
forward_face_index_map_cuda_kernel_1<scalar_t><<<blocks_1, threads>>>(
faces.data<scalar_t>(),
faces_inv.data<scalar_t>(),
batch_size,
num_faces,
image_size);
}));

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
printf("Error in forward_face_index_map_1: %s\n", cudaGetErrorString(err));

AT_DISPATCH_FLOATING_TYPES(faces.type(), "forward_face_index_map_cuda", ([&] {
forward_face_index_map_cuda_kernel<scalar_t><<<blocks, threads>>>(
const dim3 blocks_2 ((batch_size * image_size * image_size - 1) / threads +1);
AT_DISPATCH_FLOATING_TYPES(faces.type(), "forward_face_index_map_cuda_2", ([&] {
forward_face_index_map_cuda_kernel_2<scalar_t><<<blocks_2, threads>>>(
faces.data<scalar_t>(),
faces_inv.data<scalar_t>(),
face_index_map.data<int32_t>(),
weight_map.data<scalar_t>(),
depth_map.data<scalar_t>(),
face_inv_map.data<scalar_t>(),
lock.data<int32_t>(),
batch_size,
num_faces,
image_size,
(int) batch_size,
(int) num_faces,
(int) image_size,
(scalar_t) near,
(scalar_t) far,
return_rgb,
return_alpha,
return_depth);
}));

cudaError_t err = cudaGetLastError();
err = cudaGetLastError();
if (err != cudaSuccess)
printf("Error in forward_face_index_map: %s\n", cudaGetErrorString(err));
printf("Error in forward_face_index_map_2: %s\n", cudaGetErrorString(err));
return {face_index_map, weight_map, depth_map, face_inv_map};
}
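Both launches size their grids with the usual ceiling division (n - 1) / threads + 1 so every element gets a thread. As a tiny illustration (helper name ours, not in the source):

    // ceil(n / threads) in integer arithmetic, as used for blocks_1 and blocks_2
    inline int grid_size_for(int n, int threads) {
        return (n - 1) / threads + 1;
    }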

std::vector<at::Tensor> forward_texture_sampling_cuda(
at::Tensor faces,
std::vector<at::Tensor> forward_texture_sampling_cuda( at::Tensor faces,
at::Tensor textures,
at::Tensor face_index_map,
at::Tensor weight_map,
(diff for the remaining changed files not loaded)
