Skip to content

Commit

Permalink
working with instances
Browse files Browse the repository at this point in the history
  • Loading branch information
Bernhard Kerbl committed Jun 19, 2023
1 parent 3851457 commit e41a365
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions diff_gaussian_rasterization/rasterizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from . import _C

def rasterize_gaussians(
instance,
means3D,
means2D,
sh,
Expand All @@ -16,6 +17,7 @@ def rasterize_gaussians(
rasterizer_state
):
return _RasterizeGaussians.apply(
instance,
means3D,
means2D,
sh,
Expand All @@ -32,6 +34,7 @@ class _RasterizeGaussians(torch.autograd.Function):
@staticmethod
def forward(
ctx,
instance,
means3D,
means2D,
sh,
Expand All @@ -46,6 +49,7 @@ def forward(

# Restructure arguments the way that the C++ lib expects them
args = (
instance,
raster_settings.bg,
means3D,
colors_precomp,
Expand All @@ -72,6 +76,7 @@ def forward(

# Keep relevant tensors for backward
ctx.raster_settings = raster_settings
ctx.instance = instance
ctx.rasterizer_state = rasterizer_state
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh)
return color, radii
Expand All @@ -80,12 +85,14 @@ def forward(
def backward(ctx, grad_out_color, _):

# Restore necessary values from context
instance = ctx.instance
rasterizer_state = ctx.rasterizer_state
raster_settings = ctx.raster_settings
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors

# Restructure args as C++ method expects them
args = (rasterizer_state,
args = (instance,
rasterizer_state,
raster_settings.bg,
means3D,
radii,
Expand All @@ -107,6 +114,7 @@ def backward(ctx, grad_out_color, _):
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)

grads = (
None,
grad_means3D,
grad_means2D,
grad_sh,
Expand Down Expand Up @@ -134,33 +142,32 @@ class GaussianRasterizationSettings(NamedTuple):
campos : torch.Tensor
prefiltered : bool

def createRasterizerState():
return _C.create_rasterizer_state()

def deleteRasterizerState(state):
return _C.delete_rasterize_state(state)

class GaussianRasterizer(nn.Module):
def __init__(self, raster_settings, rasterizer_state):
def __init__(self):
super().__init__()
self.raster_settings = raster_settings
self.rasterizer_state = rasterizer_state
self.instance = _C.create_rasterizer()

def __del__(self):
_C.delete_rasterizer(self.instance)

def markVisible(self, positions):
def createRasterizerState(self):
return _C.create_rasterizer_state(self.instance)

def deleteRasterizerState(self, state):
_C.delete_rasterizer_state(self.instance, state)

def markVisible(self, raster_settings, positions):
# Mark visible points (based on frustum culling for camera) with a boolean
with torch.no_grad():
raster_settings = self.raster_settings
visible = _C.mark_visible(
self.instance,
positions,
raster_settings.viewmatrix,
raster_settings.projmatrix)

return visible

def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
raster_settings = self.raster_settings
rasterize_state = self.rasterizer_state

def forward(self, rasterizer_state, raster_settings, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
raise Exception('Please provide excatly one of either SHs or precomputed colors!')

Expand All @@ -181,6 +188,7 @@ def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None

# Invoke C++/CUDA rasterization routine
return rasterize_gaussians(
self.instance,
means3D,
means2D,
shs,
Expand All @@ -190,6 +198,6 @@ def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None
rotations,
cov3D_precomp,
raster_settings,
rasterize_state
rasterizer_state
)

0 comments on commit e41a365

Please sign in to comment.