Skip to content

Commit

Permalink
Allow different tv weights for sigma/color in background
Browse files Browse the repository at this point in the history
  • Loading branch information
sxyu committed Nov 7, 2021
1 parent b3b0649 commit 89cb978
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 87 deletions.
23 changes: 16 additions & 7 deletions opt/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
group.add_argument('--lr_basis_begin_step', type=int, default=0)#4 * 12800)
group.add_argument('--lr_basis_delay_mult', type=float, default=1e-2)

group.add_argument('--rms_beta', type=float, default=0.95, help="RMSProp exponential averaging factor")

group.add_argument('--n_iters', type=int, default=20 * 12800, help='number of iters to optimize for')
group.add_argument('--print_every', type=int, default=20, help='print every')
Expand Down Expand Up @@ -138,7 +139,8 @@
group.add_argument('--tune_mode', action='store_true', default=False,
help='hypertuning mode (do not save, for speed)')

group.add_argument('--rms_beta', type=float, default=0.95)
group = parser.add_argument_group("misc experiments")
# Foreground TV
group.add_argument('--lambda_tv', type=float, default=1e-5)
group.add_argument('--tv_sparsity', type=float, default=0.01)
group.add_argument('--tv_logalpha', action='store_true', default=False,
Expand All @@ -148,11 +150,17 @@
group.add_argument('--tv_sh_sparsity', type=float, default=0.01)

group.add_argument('--lambda_l2_sh', type=float, default=0.0)#1e-4)
# End Foreground TV


# Background TV
group.add_argument('--lambda_tv_background_sigma', type=float, default=1e-4)
group.add_argument('--lambda_tv_background_color', type=float, default=1e-4)

group.add_argument('--lambda_tv_background', type=float, default=1e-3)
group.add_argument('--tv_background_sparsity', type=float, default=0.01)
# End Background TV

group.add_argument('--lambda_tv_basis', type=float, default=0.0)
group.add_argument('--lambda_tv_basis', type=float, default=0.0, help='Learned basis total variation loss')

group.add_argument('--weight_decay_sigma', type=float, default=1.0)
group.add_argument('--weight_decay_sh', type=float, default=1.0)
Expand Down Expand Up @@ -202,8 +210,8 @@
basis_reso=args.basis_reso,
basis_type=svox2.__dict__['BASIS_TYPE_' + args.basis_type.upper()],
mlp_posenc_size=4,
background_nlayers=16,
background_reso=256)#1024)
background_nlayers=32,
background_reso=512)#1024)

grid.opt.last_sample_opaque = dset.last_sample_opaque

Expand Down Expand Up @@ -448,9 +456,10 @@ def train_step():
if args.lambda_l2_sh > 0.0:
grid.inplace_l2_color_grad(grid.sh_data.grad,
scaling=args.lambda_l2_sh)
if args.lambda_tv_background > 0.0:
if args.lambda_tv_background_sigma > 0.0 or args.lambda_tv_background_color > 0.0:
grid.inplace_tv_background_grad(grid.background_cubemap.grad,
scaling=args.lambda_tv_background,
scaling=args.lambda_tv_background_color,
scaling_density=args.lambda_tv_background_sigma,
sparse_frac=args.tv_background_sparsity)
if args.lambda_tv_basis > 0.0:
tv_basis = grid.tv_basis()
Expand Down
9 changes: 8 additions & 1 deletion svox2/csrc/loss_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,11 @@ __global__ void msi_tv_grad_sparse_kernel(
const torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits> cubemap,
const int32_t* __restrict__ rand_cells,
float scale,
float scale_last,
size_t Q,
// Output
torch::PackedTensorAccessor32<bool, 4, torch::RestrictPtrTraits> cubemap_mask,
torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits> grad_cubemap) {
torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits> grad_cubemap) {
CUDA_GET_THREAD_ID_U64(tid, Q);
const int channel_id = tid % cubemap.size(4);
const int msi_idx = rand_cells[tid / cubemap.size(4)];
Expand All @@ -430,6 +431,10 @@ __global__ void msi_tv_grad_sparse_kernel(
const float v10 = cubemap[layer_id][face_id][u + 1][v][channel_id];
const float v_nxl = cubemap[layer_id + 1][face_id][u][v][channel_id];

if (channel_id == cubemap.size(4) - 1) {
scale = scale_last;
}

float dx = (v10 - v00);
float dy = (v01 - v00);
float dz = (v_nxl - v00);
Expand Down Expand Up @@ -641,6 +646,7 @@ void msi_tv_grad_sparse(torch::Tensor cubemap,
torch::Tensor rand_cells,
torch::Tensor mask_out,
float scale,
float scale_last,
torch::Tensor grad_cubemap) {
DEVICE_GUARD(cubemap);
CHECK_INPUT(cubemap);
Expand All @@ -662,6 +668,7 @@ void msi_tv_grad_sparse(torch::Tensor cubemap,
cubemap.packed_accessor32<float, 5, torch::RestrictPtrTraits>(),
rand_cells.data_ptr<int32_t>(),
scale / nl,
scale_last / nl,
Q,
// Output
mask_out.packed_accessor32<bool, 4, torch::RestrictPtrTraits>(),
Expand Down
159 changes: 80 additions & 79 deletions svox2/csrc/svox2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void tv_grad(Tensor, Tensor, int, int, float, bool, float, bool, float, float,
Tensor);
void tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, int, int, float, bool,
float, bool, float, float, Tensor);
void msi_tv_grad_sparse(Tensor, Tensor, Tensor, float, Tensor);
void msi_tv_grad_sparse(Tensor, Tensor, Tensor, float, float, Tensor);

// Optim
void rmsprop_step(Tensor, Tensor, Tensor, Tensor, float, float, float, float,
Expand All @@ -50,84 +50,85 @@ void sgd_step(Tensor, Tensor, Tensor, float, float);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#define _REG_FUNC(funname) m.def(#funname, &funname)
_REG_FUNC(sample_grid);
_REG_FUNC(sample_grid_backward);
_REG_FUNC(volume_render_cuvol);
_REG_FUNC(volume_render_cuvol_backward);
_REG_FUNC(volume_render_cuvol_fused);
// _REG_FUNC(volume_render_cuvol_image);
// _REG_FUNC(volume_render_cuvol_image_backward);

// Loss
_REG_FUNC(tv);
_REG_FUNC(tv_grad);
_REG_FUNC(tv_grad_sparse);
_REG_FUNC(msi_tv_grad_sparse);

// Misc
_REG_FUNC(dilate);
_REG_FUNC(accel_dist_prop);
_REG_FUNC(grid_weight_render);
_REG_FUNC(sample_cubemap);

// Optimizer
_REG_FUNC(rmsprop_step);
_REG_FUNC(sgd_step);
_REG_FUNC(sample_grid);
_REG_FUNC(sample_grid_backward);
_REG_FUNC(volume_render_cuvol);
_REG_FUNC(volume_render_cuvol_backward);
_REG_FUNC(volume_render_cuvol_fused);
// _REG_FUNC(volume_render_cuvol_image);
// _REG_FUNC(volume_render_cuvol_image_backward);

// Loss
_REG_FUNC(tv);
_REG_FUNC(tv_grad);
_REG_FUNC(tv_grad_sparse);
_REG_FUNC(msi_tv_grad_sparse);

// Misc
_REG_FUNC(dilate);
_REG_FUNC(accel_dist_prop);
_REG_FUNC(grid_weight_render);
_REG_FUNC(sample_cubemap);

// Optimizer
_REG_FUNC(rmsprop_step);
_REG_FUNC(sgd_step);
#undef _REG_FUNC

py::class_<SparseGridSpec>(m, "SparseGridSpec")
.def(py::init<>())
.def_readwrite("density_data", &SparseGridSpec::density_data)
.def_readwrite("sh_data", &SparseGridSpec::sh_data)
.def_readwrite("links", &SparseGridSpec::links)
.def_readwrite("_offset", &SparseGridSpec::_offset)
.def_readwrite("_scaling", &SparseGridSpec::_scaling)
.def_readwrite("basis_dim", &SparseGridSpec::basis_dim)
.def_readwrite("basis_type", &SparseGridSpec::basis_type)
.def_readwrite("basis_data", &SparseGridSpec::basis_data)
.def_readwrite("background_cubemap", &SparseGridSpec::background_cubemap);

py::class_<CameraSpec>(m, "CameraSpec")
.def(py::init<>())
.def_readwrite("c2w", &CameraSpec::c2w)
.def_readwrite("fx", &CameraSpec::fx)
.def_readwrite("fy", &CameraSpec::fy)
.def_readwrite("cx", &CameraSpec::cx)
.def_readwrite("cy", &CameraSpec::cy)
.def_readwrite("width", &CameraSpec::width)
.def_readwrite("height", &CameraSpec::height)
.def_readwrite("ndc_coeffx", &CameraSpec::ndc_coeffx)
.def_readwrite("ndc_coeffy", &CameraSpec::ndc_coeffy);

py::class_<RaysSpec>(m, "RaysSpec")
.def(py::init<>())
.def_readwrite("origins", &RaysSpec::origins)
.def_readwrite("dirs", &RaysSpec::dirs);

py::class_<RenderOptions>(m, "RenderOptions")
.def(py::init<>())
.def_readwrite("background_brightness",
&RenderOptions::background_brightness)
// .def_readwrite("step_epsilon", &RenderOptions::step_epsilon)
.def_readwrite("step_size", &RenderOptions::step_size)
.def_readwrite("sigma_thresh", &RenderOptions::sigma_thresh)
.def_readwrite("stop_thresh", &RenderOptions::stop_thresh)
.def_readwrite("last_sample_opaque", &RenderOptions::last_sample_opaque)
.def_readwrite("background_msi_scale",
&RenderOptions::background_msi_scale);
// .def_readwrite("randomize", &RenderOptions::randomize)
// .def_readwrite("_m1", &RenderOptions::_m1)
// .def_readwrite("_m2", &RenderOptions::_m2)
// .def_readwrite("_m3", &RenderOptions::_m3);

py::class_<GridOutputGrads>(m, "GridOutputGrads")
.def(py::init<>())
.def_readwrite("grad_density_out", &GridOutputGrads::grad_density_out)
.def_readwrite("grad_sh_out", &GridOutputGrads::grad_sh_out)
.def_readwrite("grad_basis_out", &GridOutputGrads::grad_basis_out)
.def_readwrite("grad_background_out",
&GridOutputGrads::grad_background_out)
.def_readwrite("mask_out", &GridOutputGrads::mask_out)
.def_readwrite("mask_background_out",
&GridOutputGrads::mask_background_out);
py::class_<SparseGridSpec>(m, "SparseGridSpec")
.def(py::init<>())
.def_readwrite("density_data", &SparseGridSpec::density_data)
.def_readwrite("sh_data", &SparseGridSpec::sh_data)
.def_readwrite("links", &SparseGridSpec::links)
.def_readwrite("_offset", &SparseGridSpec::_offset)
.def_readwrite("_scaling", &SparseGridSpec::_scaling)
.def_readwrite("basis_dim", &SparseGridSpec::basis_dim)
.def_readwrite("basis_type", &SparseGridSpec::basis_type)
.def_readwrite("basis_data", &SparseGridSpec::basis_data)
.def_readwrite("background_cubemap",
&SparseGridSpec::background_cubemap);

py::class_<CameraSpec>(m, "CameraSpec")
.def(py::init<>())
.def_readwrite("c2w", &CameraSpec::c2w)
.def_readwrite("fx", &CameraSpec::fx)
.def_readwrite("fy", &CameraSpec::fy)
.def_readwrite("cx", &CameraSpec::cx)
.def_readwrite("cy", &CameraSpec::cy)
.def_readwrite("width", &CameraSpec::width)
.def_readwrite("height", &CameraSpec::height)
.def_readwrite("ndc_coeffx", &CameraSpec::ndc_coeffx)
.def_readwrite("ndc_coeffy", &CameraSpec::ndc_coeffy);

py::class_<RaysSpec>(m, "RaysSpec")
.def(py::init<>())
.def_readwrite("origins", &RaysSpec::origins)
.def_readwrite("dirs", &RaysSpec::dirs);

py::class_<RenderOptions>(m, "RenderOptions")
.def(py::init<>())
.def_readwrite("background_brightness",
&RenderOptions::background_brightness)
// .def_readwrite("step_epsilon", &RenderOptions::step_epsilon)
.def_readwrite("step_size", &RenderOptions::step_size)
.def_readwrite("sigma_thresh", &RenderOptions::sigma_thresh)
.def_readwrite("stop_thresh", &RenderOptions::stop_thresh)
.def_readwrite("last_sample_opaque", &RenderOptions::last_sample_opaque)
.def_readwrite("background_msi_scale",
&RenderOptions::background_msi_scale);
// .def_readwrite("randomize", &RenderOptions::randomize)
// .def_readwrite("_m1", &RenderOptions::_m1)
// .def_readwrite("_m2", &RenderOptions::_m2)
// .def_readwrite("_m3", &RenderOptions::_m3);

py::class_<GridOutputGrads>(m, "GridOutputGrads")
.def(py::init<>())
.def_readwrite("grad_density_out", &GridOutputGrads::grad_density_out)
.def_readwrite("grad_sh_out", &GridOutputGrads::grad_sh_out)
.def_readwrite("grad_basis_out", &GridOutputGrads::grad_basis_out)
.def_readwrite("grad_background_out",
&GridOutputGrads::grad_background_out)
.def_readwrite("mask_out", &GridOutputGrads::mask_out)
.def_readwrite("mask_background_out",
&GridOutputGrads::mask_background_out);
}
4 changes: 4 additions & 0 deletions svox2/svox2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,7 @@ def inplace_tv_background_grad(
self,
grad: torch.Tensor,
scaling: float = 1.0,
scaling_density: Optional[float] = None,
sparse_frac: float = 0.01
):
"""
Expand All @@ -1476,11 +1477,14 @@ def inplace_tv_background_grad(

rand_cells_bg = self._get_rand_cells_background(sparse_frac)
indexer = self._get_sparse_background_grad_indexer()
if scaling_density is None:
scaling_density = scaling
_C.msi_tv_grad_sparse(
self.background_cubemap,
rand_cells_bg,
indexer,
scaling,
scaling_density,
grad)

def inplace_tv_basis_grad(
Expand Down

0 comments on commit 89cb978

Please sign in to comment.