
Commit

fixed projection bug
nkolot committed Sep 18, 2018
1 parent 29a7615 commit 4872a92
Showing 7 changed files with 88 additions and 36 deletions.
4 changes: 2 additions & 2 deletions neural_renderer/cuda/create_texture_image_cuda.cpp
@@ -10,8 +10,8 @@ at::Tensor create_texture_image_cuda(
 
 // C++ interface
 
-#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)


4 changes: 2 additions & 2 deletions neural_renderer/cuda/load_textures_cuda.cpp
@@ -10,8 +10,8 @@ at::Tensor load_textures_cuda(
 
 // C++ interface
 
-#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)


4 changes: 2 additions & 2 deletions neural_renderer/cuda/rasterize_cuda.cpp
@@ -63,8 +63,8 @@ at::Tensor backward_depth_map_cuda(
 
 // C++ interface
 
-#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
 std::vector<at::Tensor> forward_face_index_map(
2 changes: 2 additions & 0 deletions neural_renderer/look.py
@@ -26,6 +26,8 @@ def look(vertices, eye, direction=[0, 1, 0], up=None):
     elif torch.is_tensor(eye):
         eye = eye.to(device)
 
+    if up is None:
+        up = torch.cuda.FloatTensor([0, 1, 0])
     if eye.ndimension() == 1:
         eye = eye[None, :]
     if direction.ndimension() == 1:
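For context, a minimal sketch of what the added default enables: calling look() without an up vector now falls back to the +Y axis instead of leaving up as None. The shapes and values below are illustrative assumptions, and the sketch assumes look() is re-exported at package level as nr.look, as in the Chainer original.

    import torch
    import neural_renderer as nr

    vertices = torch.randn(1, 100, 3).cuda()        # batch of 1 mesh, 100 vertices (illustrative)
    eye = torch.cuda.FloatTensor([0., 0., -2.732])  # camera position (illustrative)

    # up is omitted, so the new branch supplies torch.cuda.FloatTensor([0, 1, 0])
    # instead of passing None into the downstream cross-product math.
    transformed = nr.look(vertices, eye)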
27 changes: 18 additions & 9 deletions neural_renderer/projection.py
@@ -3,18 +3,23 @@
 import torch
 
 
-def projection(vertices, P, dist_coeffs, orig_size):
+def projection(vertices, K, R, t, dist_coeffs, orig_size, eps=1e-9):
     '''
     Calculate projective transformation of vertices given a projection matrix
-    P: 3x4 projection matrix
+    Input parameters:
+    K: batch_size * 3 * 3 intrinsic camera matrix
+    R, t: batch_size * 3 * 3, batch_size * 1 * 3 extrinsic calibration parameters
+    dist_coeffs: vector of distortion coefficients
+    orig_size: original size of image captured by the camera
+    Returns: For each point [X,Y,Z] in world coordinates [u,v,z] where u,v are the coordinates of the projection in
+    pixels and z is the depth
     '''
-    vertices = torch.cat([vertices, torch.ones_like(vertices[:, :, None, 0])], dim=-1)
-    vertices = torch.bmm(vertices, P.transpose(2,1))
+
+    # instead of P*x we compute x'*P'
+    vertices = torch.matmul(vertices, R.transpose(2,1)) + t
     x, y, z = vertices[:, :, 0], vertices[:, :, 1], vertices[:, :, 2]
-    x_ = x / (z + 1e-5)
-    y_ = y / (z + 1e-5)
+    x_ = x / (z + eps)
+    y_ = y / (z + eps)
 
     # Get distortion coefficients from vector
     k1 = dist_coeffs[:, None, 0]
@@ -27,7 +32,11 @@ def projection(vertices, P, dist_coeffs, orig_size):
     r = torch.sqrt(x_ ** 2 + y_ ** 2)
     x__ = x_*(1 + k1*(r**2) + k2*(r**4) + k3*(r**6)) + 2*p1*x_*y_ + p2*(r**2 + 2*x_**2)
     y__ = y_*(1 + k1*(r**2) + k2*(r**4) + k3 *(r**6)) + p1*(r**2 + 2*y_**2) + 2*p2*x_*y_
-    x__ = 2 * (x__ - orig_size / 2.) / orig_size
-    y__ = 2 * (y__ - orig_size / 2.) / orig_size
-    vertices = torch.stack([x__,y__,z], dim=-1)
+    vertices = torch.stack([x__, y__, torch.ones_like(z)], dim=-1)
+    vertices = torch.matmul(vertices, K.transpose(1,2))
+    u, v = vertices[:, :, 0], vertices[:, :, 1]
+    # map u,v from [0, img_size] to [-1, 1] to use by the renderer
+    u = 2 * (u - orig_size / 2.) / orig_size
+    v = 2 * (v - orig_size / 2.) / orig_size
+    vertices = torch.stack([u, v, z], dim=-1)
     return vertices
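In summary, the rewritten path applies the extrinsics (R, t), a perspective division guarded by a configurable eps, Brown-Conrady style distortion, then the intrinsics K, and only at the end normalizes pixel coordinates to [-1, 1] for the rasterizer. A minimal usage sketch with an identity camera follows; every value in it is an illustrative assumption, including the OpenCV-style coefficient order [k1, k2, p1, p2, k3], and it assumes projection() is re-exported as nr.projection.

    import torch
    import neural_renderer as nr

    vertices = torch.randn(1, 100, 3).cuda()   # (batch_size, num_vertices, XYZ)

    # Pinhole intrinsics for a hypothetical 1024x1024 image: focal length 1024 px,
    # principal point at the image center.
    K = torch.cuda.FloatTensor([[[1024., 0., 512.],
                                 [0., 1024., 512.],
                                 [0., 0., 1.]]])
    R = torch.eye(3).cuda()[None, :, :]        # identity rotation
    t = torch.zeros(1, 1, 3).cuda()            # zero translation
    dist_coeffs = torch.cuda.FloatTensor([[0., 0., 0., 0., 0.]])

    # Each vertex comes back as [u, v, z]: u, v in [-1, 1], z the camera-space depth.
    projected = nr.projection(vertices, K, R, t, dist_coeffs, orig_size=1024)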
7 changes: 3 additions & 4 deletions neural_renderer/rasterize.py
@@ -46,6 +46,7 @@ def forward(ctx, faces, textures, image_size, near, far, eps, background_color,
             textures = torch.cuda.FloatTensor(1).fill_(0)
             ctx.texture_size = None
 
+
         face_index_map = torch.cuda.IntTensor(ctx.batch_size, ctx.image_size, ctx.image_size).fill_(-1)
         weight_map = torch.cuda.FloatTensor(ctx.batch_size, ctx.image_size, ctx.image_size, 3).fill_(0.0)
         depth_map = torch.cuda.FloatTensor(ctx.batch_size, ctx.image_size, ctx.image_size).fill_(ctx.far)
@@ -67,7 +68,6 @@ def forward(ctx, faces, textures, image_size, near, far, eps, background_color,
         else:
             face_inv_map = torch.cuda.FloatTensor(1).fill_(0)
 
-
         face_index_map, weight_map, depth_map, face_inv_map =\
             RasterizeFunction.forward_face_index_map(ctx, faces, face_index_map,
                                                      weight_map, depth_map,
@@ -108,10 +108,9 @@ def backward(ctx, grad_rgb_map, grad_alpha_map, grad_depth_map):
             ctx.saved_tensors
         # initialize output buffers
         # no need for explicit allocation of cuda.FloatTensor because zeros_like does it automatically
-        grad_faces = torch.zeros_like(faces, dtype=torch.float32).to(ctx.device).contiguous()
-        grad_faces = torch.zeros_like(faces, dtype=torch.float32).to(ctx.device).contiguous()
+        grad_faces = torch.zeros_like(faces, dtype=torch.float32)
         if ctx.return_rgb:
-            grad_textures = torch.zeros_like(textures, dtype=torch.float32).to(ctx.device).contiguous()
+            grad_textures = torch.zeros_like(textures, dtype=torch.float32)
         else:
             grad_textures = torch.cuda.FloatTensor(1).fill_(0.0)
 
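A short aside on the buffer change above: torch.zeros_like inherits device and dtype from its argument and returns a freshly allocated (hence contiguous) tensor, so the duplicated allocation and the explicit .to(ctx.device).contiguous() chain were redundant. A minimal sketch, with an illustrative tensor shape:

    import torch

    faces = torch.randn(2, 1280, 3, 3, device='cuda')          # (batch, faces, 3 verts, XYZ), illustrative
    grad_faces = torch.zeros_like(faces, dtype=torch.float32)  # lands on faces' device automatically

    assert grad_faces.device == faces.device
    assert grad_faces.is_contiguous()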
76 changes: 59 additions & 17 deletions neural_renderer/renderer.py
@@ -11,7 +11,7 @@
 class Renderer(nn.Module):
     def __init__(self, image_size=256, anti_aliasing=True, background_color=[0,0,0],
                  fill_back=True, camera_mode='projection',
-                 P=None, dist_coeffs=None, orig_size=1024,
+                 K=None, R=None, t=None, dist_coeffs=None, orig_size=1024,
                  perspective=True, viewing_angle=30, camera_direction=[0,0,1],
                  near=0.1, far=100,
                  light_intensity_ambient=0.5, light_intensity_directional=0.5,
@@ -27,14 +27,24 @@ def __init__(self, image_size=256, anti_aliasing=True, background_color=[0,0,0],
         # camera
         self.camera_mode = camera_mode
         if self.camera_mode == 'projection':
-            self.P = P
-            if isinstance(self.P, numpy.ndarray):
-                self.P = torch.from_numpy(self.P).cuda()
-            if self.P is None or P.ndimension() != 3 or self.P.shape[1] != 3 or self.P.shape[2] != 4:
-                raise ValueError('You need to provide a valid (batch_size)x3x4 projection matrix')
+            self.K = K
+            self.R = R
+            self.t = t
+            if isinstance(self.K, numpy.ndarray):
+                self.K = torch.cuda.FloatTensor(self.K)
+            if isinstance(self.R, numpy.ndarray):
+                self.R = torch.cuda.FloatTensor(self.R)
+            if isinstance(self.t, numpy.ndarray):
+                self.t = torch.cuda.FloatTensor(self.t)
+            # if self.K is None or self.K.ndimension() != 3 or self.K.shape[1] != 3 or self.K.shape[2] != 3:
+            #     raise ValueError('You need to provide a valid (batch_size)x3x3 intrinsic camera matrix')
+            # if self.R is None or self.R.ndimension() != 3 or self.R.shape[1] != 3 or self.R.shape[2] != 3:
+            #     raise ValueError('You need to provide a valid (batch_size)x3x3 rotation matrix')
+            # if self.t is None or self.t.ndimension() != 2 or self.t.shape[1] != 3:
+            #     raise ValueError('You need to provide a valid (batch_size)x3 translation vector matrix')
             self.dist_coeffs = dist_coeffs
             if dist_coeffs is None:
-                self.dist_coeffs = torch.cuda.FloatTensor([[0., 0., 0., 0., 0.]]).repeat(P.shape[0], 1)
+                self.dist_coeffs = torch.cuda.FloatTensor([[0., 0., 0., 0., 0.]])
             self.orig_size = orig_size
         elif self.camera_mode in ['look', 'look_at']:
             self.perspective = perspective
@@ -58,22 +68,23 @@ def __init__(self, image_size=256, anti_aliasing=True, background_color=[0,0,0],
         # rasterization
         self.rasterizer_eps = 1e-3
 
-    def forward(self, vertices, faces, textures=None, mode=None):
+    def forward(self, vertices, faces, textures=None, mode=None, K=None, R=None, t=None, dist_coeffs=None, orig_size=None):
         '''
         Implementation of forward rendering method
         The old API is preserved for back-compatibility with the Chainer implementation
         '''
 
         if mode is None:
-            return self.render(vertices, faces, textures)
+            return self.render(vertices, faces, textures, K, R, t, dist_coeffs, orig_size)
         elif mode == 'silhouettes':
-            return self.render_silhouettes(vertices, faces)
+            return self.render_silhouettes(vertices, faces, K, R, t, dist_coeffs, orig_size)
         elif mode == 'depth':
-            return self.render_depth(vertices, faces)
+            return self.render_depth(vertices, faces, K, R, t, dist_coeffs, orig_size)
         else:
             raise ValueError("mode should be one of None, 'silhouettes' or 'depth'")
 
-    def render_silhouettes(self, vertices, faces):
+    def render_silhouettes(self, vertices, faces, K=None, R=None, t=None, dist_coeffs=None, orig_size=None):
+
         # fill back
         if self.fill_back:
             faces = torch.cat((faces, faces[:, :, list(reversed(range(faces.shape[-1])))]), dim=1)
@@ -90,14 +101,25 @@ def render_silhouettes(self, vertices, faces):
             if self.perspective:
                 vertices = nr.perspective(vertices, angle=self.viewing_angle)
         elif self.camera_mode == 'projection':
-            vertices = nr.projection(vertices, self.P, self.dist_coeffs, self.orig_size)
+            if K is None:
+                K = self.K
+            if R is None:
+                R = self.R
+            if t is None:
+                t = self.t
+            if dist_coeffs is None:
+                dist_coeffs = self.dist_coeffs
+            if orig_size is None:
+                orig_size = self.orig_size
+            vertices = nr.projection(vertices, K, R, t, dist_coeffs, orig_size)
 
         # rasterization
         faces = nr.vertices_to_faces(vertices, faces)
         images = nr.rasterize_silhouettes(faces, self.image_size, self.anti_aliasing)
         return images
 
-    def render_depth(self, vertices, faces):
+    def render_depth(self, vertices, faces, K=None, R=None, t=None, dist_coeffs=None, orig_size=None):
+
         # fill back
         if self.fill_back:
             faces = torch.cat((faces, faces[:, :, list(reversed(range(faces.shape[-1])))]), dim=1).detach()
@@ -114,14 +136,24 @@ def render_depth(self, vertices, faces):
             if self.perspective:
                 vertices = nr.perspective(vertices, angle=self.viewing_angle)
         elif self.camera_mode == 'projection':
-            vertices = nr.projection(vertices, self.P, self.dist_coeffs, self.orig_size)
+            if K is None:
+                K = self.K
+            if R is None:
+                R = self.R
+            if t is None:
+                t = self.t
+            if dist_coeffs is None:
+                dist_coeffs = self.dist_coeffs
+            if orig_size is None:
+                orig_size = self.orig_size
+            vertices = nr.projection(vertices, K, R, t, dist_coeffs, orig_size)
 
         # rasterization
         faces = nr.vertices_to_faces(vertices, faces)
         images = nr.rasterize_depth(faces, self.image_size, self.anti_aliasing)
         return images
 
-    def render(self, vertices, faces, textures):
+    def render(self, vertices, faces, textures, K=None, R=None, t=None, dist_coeffs=None, orig_size=None):
         # fill back
         if self.fill_back:
             faces = torch.cat((faces, faces[:, :, list(reversed(range(faces.shape[-1])))]), dim=1).detach()
@@ -150,7 +182,17 @@ def render(self, vertices, faces, textures):
             if self.perspective:
                 vertices = nr.perspective(vertices, angle=self.viewing_angle)
         elif self.camera_mode == 'projection':
-            vertices = nr.projection(vertices, self.P, self.dist_coeffs, self.orig_size)
+            if K is None:
+                K = self.K
+            if R is None:
+                R = self.R
+            if t is None:
+                t = self.t
+            if dist_coeffs is None:
+                dist_coeffs = self.dist_coeffs
+            if orig_size is None:
+                orig_size = self.orig_size
+            vertices = nr.projection(vertices, K, R, t, dist_coeffs, orig_size)
 
         # rasterization
         faces = nr.vertices_to_faces(vertices, faces)
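Taken together, the renderer changes replace the single 3x4 P matrix with separate K, R, t and thread optional per-call overrides through forward() into render, render_silhouettes, and render_depth. A minimal sketch of both calling styles; every tensor below (camera, mesh, textures) is an illustrative assumption, not taken from this commit.

    import torch
    import neural_renderer as nr

    # Hypothetical camera for one image (batch_size = 1).
    K = torch.cuda.FloatTensor([[[1024., 0., 512.], [0., 1024., 512.], [0., 0., 1.]]])
    R = torch.eye(3).cuda()[None, :, :]
    t = torch.zeros(1, 1, 3).cuda()

    # Hypothetical mesh: 100 vertices, 50 triangles, tiny 2x2x2 texture per face.
    vertices = torch.randn(1, 100, 3).cuda()
    faces = torch.randint(0, 100, (1, 50, 3)).int().cuda()
    textures = torch.ones(1, 50, 2, 2, 2, 3).cuda()

    # Style 1: bind the camera at construction time, as the old P argument did.
    renderer = nr.Renderer(camera_mode='projection', K=K, R=R, t=t, orig_size=1024)
    images = renderer(vertices, faces, textures)

    # Style 2: override the camera per call via the new forward() keywords.
    silhouettes = renderer(vertices, faces, mode='silhouettes', K=K, R=R, t=t)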

