diff --git a/opt/opt.py b/opt/opt.py index 9930fd17..ef108ad8 100644 --- a/opt/opt.py +++ b/opt/opt.py @@ -106,6 +106,7 @@ group.add_argument('--lr_basis_begin_step', type=int, default=0)#4 * 12800) group.add_argument('--lr_basis_delay_mult', type=float, default=1e-2) +group.add_argument('--rms_beta', type=float, default=0.95, help="RMSProp exponential averaging factor") group.add_argument('--n_iters', type=int, default=20 * 12800, help='number of iters to optimize for') group.add_argument('--print_every', type=int, default=20, help='print every') @@ -138,7 +139,8 @@ group.add_argument('--tune_mode', action='store_true', default=False, help='hypertuning mode (do not save, for speed)') -group.add_argument('--rms_beta', type=float, default=0.95) +group = parser.add_argument_group("misc experiments") +# Foreground TV group.add_argument('--lambda_tv', type=float, default=1e-5) group.add_argument('--tv_sparsity', type=float, default=0.01) group.add_argument('--tv_logalpha', action='store_true', default=False, @@ -148,11 +150,17 @@ group.add_argument('--tv_sh_sparsity', type=float, default=0.01) group.add_argument('--lambda_l2_sh', type=float, default=0.0)#1e-4) +# End Foreground TV + + +# Background TV +group.add_argument('--lambda_tv_background_sigma', type=float, default=1e-4) +group.add_argument('--lambda_tv_background_color', type=float, default=1e-4) -group.add_argument('--lambda_tv_background', type=float, default=1e-3) group.add_argument('--tv_background_sparsity', type=float, default=0.01) +# End Background TV -group.add_argument('--lambda_tv_basis', type=float, default=0.0) +group.add_argument('--lambda_tv_basis', type=float, default=0.0, help='Learned basis total variation loss') group.add_argument('--weight_decay_sigma', type=float, default=1.0) group.add_argument('--weight_decay_sh', type=float, default=1.0) @@ -202,8 +210,8 @@ basis_reso=args.basis_reso, basis_type=svox2.__dict__['BASIS_TYPE_' + args.basis_type.upper()], mlp_posenc_size=4, - background_nlayers=16, - background_reso=256)#1024) + background_nlayers=32, + background_reso=512)#1024) grid.opt.last_sample_opaque = dset.last_sample_opaque @@ -448,9 +456,10 @@ def train_step(): if args.lambda_l2_sh > 0.0: grid.inplace_l2_color_grad(grid.sh_data.grad, scaling=args.lambda_l2_sh) - if args.lambda_tv_background > 0.0: + if args.lambda_tv_background_sigma > 0.0 or args.lambda_tv_background_color > 0.0: grid.inplace_tv_background_grad(grid.background_cubemap.grad, - scaling=args.lambda_tv_background, + scaling=args.lambda_tv_background_color, + scaling_density=args.lambda_tv_background_sigma, sparse_frac=args.tv_background_sparsity) if args.lambda_tv_basis > 0.0: tv_basis = grid.tv_basis() diff --git a/svox2/csrc/loss_kernel.cu b/svox2/csrc/loss_kernel.cu index 403f0c0f..5d0ac234 100644 --- a/svox2/csrc/loss_kernel.cu +++ b/svox2/csrc/loss_kernel.cu @@ -410,10 +410,11 @@ __global__ void msi_tv_grad_sparse_kernel( const torch::PackedTensorAccessor32 cubemap, const int32_t* __restrict__ rand_cells, float scale, + float scale_last, size_t Q, // Output torch::PackedTensorAccessor32 cubemap_mask, - torch::PackedTensorAccessor32 grad_cubemap) { + torch::PackedTensorAccessor32 grad_cubemap) { CUDA_GET_THREAD_ID_U64(tid, Q); const int channel_id = tid % cubemap.size(4); const int msi_idx = rand_cells[tid / cubemap.size(4)]; @@ -430,6 +431,10 @@ __global__ void msi_tv_grad_sparse_kernel( const float v10 = cubemap[layer_id][face_id][u + 1][v][channel_id]; const float v_nxl = cubemap[layer_id + 1][face_id][u][v][channel_id]; + if (channel_id == cubemap.size(4) - 1) { + scale = scale_last; + } + float dx = (v10 - v00); float dy = (v01 - v00); float dz = (v_nxl - v00); @@ -641,6 +646,7 @@ void msi_tv_grad_sparse(torch::Tensor cubemap, torch::Tensor rand_cells, torch::Tensor mask_out, float scale, + float scale_last, torch::Tensor grad_cubemap) { DEVICE_GUARD(cubemap); CHECK_INPUT(cubemap); @@ -662,6 +668,7 @@ void msi_tv_grad_sparse(torch::Tensor cubemap, cubemap.packed_accessor32(), rand_cells.data_ptr(), scale / nl, + scale_last / nl, Q, // Output mask_out.packed_accessor32(), diff --git a/svox2/csrc/svox2.cpp b/svox2/csrc/svox2.cpp index e559c007..ad822e87 100644 --- a/svox2/csrc/svox2.cpp +++ b/svox2/csrc/svox2.cpp @@ -41,7 +41,7 @@ void tv_grad(Tensor, Tensor, int, int, float, bool, float, bool, float, float, Tensor); void tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, int, int, float, bool, float, bool, float, float, Tensor); -void msi_tv_grad_sparse(Tensor, Tensor, Tensor, float, Tensor); +void msi_tv_grad_sparse(Tensor, Tensor, Tensor, float, float, Tensor); // Optim void rmsprop_step(Tensor, Tensor, Tensor, Tensor, float, float, float, float, @@ -50,84 +50,85 @@ void sgd_step(Tensor, Tensor, Tensor, float, float); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #define _REG_FUNC(funname) m.def(#funname, &funname) - _REG_FUNC(sample_grid); - _REG_FUNC(sample_grid_backward); - _REG_FUNC(volume_render_cuvol); - _REG_FUNC(volume_render_cuvol_backward); - _REG_FUNC(volume_render_cuvol_fused); - // _REG_FUNC(volume_render_cuvol_image); - // _REG_FUNC(volume_render_cuvol_image_backward); - - // Loss - _REG_FUNC(tv); - _REG_FUNC(tv_grad); - _REG_FUNC(tv_grad_sparse); - _REG_FUNC(msi_tv_grad_sparse); - - // Misc - _REG_FUNC(dilate); - _REG_FUNC(accel_dist_prop); - _REG_FUNC(grid_weight_render); - _REG_FUNC(sample_cubemap); - - // Optimizer - _REG_FUNC(rmsprop_step); - _REG_FUNC(sgd_step); + _REG_FUNC(sample_grid); + _REG_FUNC(sample_grid_backward); + _REG_FUNC(volume_render_cuvol); + _REG_FUNC(volume_render_cuvol_backward); + _REG_FUNC(volume_render_cuvol_fused); + // _REG_FUNC(volume_render_cuvol_image); + // _REG_FUNC(volume_render_cuvol_image_backward); + + // Loss + _REG_FUNC(tv); + _REG_FUNC(tv_grad); + _REG_FUNC(tv_grad_sparse); + _REG_FUNC(msi_tv_grad_sparse); + + // Misc + _REG_FUNC(dilate); + _REG_FUNC(accel_dist_prop); + _REG_FUNC(grid_weight_render); + _REG_FUNC(sample_cubemap); + + // Optimizer + _REG_FUNC(rmsprop_step); + _REG_FUNC(sgd_step); #undef _REG_FUNC - py::class_(m, "SparseGridSpec") - .def(py::init<>()) - .def_readwrite("density_data", &SparseGridSpec::density_data) - .def_readwrite("sh_data", &SparseGridSpec::sh_data) - .def_readwrite("links", &SparseGridSpec::links) - .def_readwrite("_offset", &SparseGridSpec::_offset) - .def_readwrite("_scaling", &SparseGridSpec::_scaling) - .def_readwrite("basis_dim", &SparseGridSpec::basis_dim) - .def_readwrite("basis_type", &SparseGridSpec::basis_type) - .def_readwrite("basis_data", &SparseGridSpec::basis_data) - .def_readwrite("background_cubemap", &SparseGridSpec::background_cubemap); - - py::class_(m, "CameraSpec") - .def(py::init<>()) - .def_readwrite("c2w", &CameraSpec::c2w) - .def_readwrite("fx", &CameraSpec::fx) - .def_readwrite("fy", &CameraSpec::fy) - .def_readwrite("cx", &CameraSpec::cx) - .def_readwrite("cy", &CameraSpec::cy) - .def_readwrite("width", &CameraSpec::width) - .def_readwrite("height", &CameraSpec::height) - .def_readwrite("ndc_coeffx", &CameraSpec::ndc_coeffx) - .def_readwrite("ndc_coeffy", &CameraSpec::ndc_coeffy); - - py::class_(m, "RaysSpec") - .def(py::init<>()) - .def_readwrite("origins", &RaysSpec::origins) - .def_readwrite("dirs", &RaysSpec::dirs); - - py::class_(m, "RenderOptions") - .def(py::init<>()) - .def_readwrite("background_brightness", - &RenderOptions::background_brightness) - // .def_readwrite("step_epsilon", &RenderOptions::step_epsilon) - .def_readwrite("step_size", &RenderOptions::step_size) - .def_readwrite("sigma_thresh", &RenderOptions::sigma_thresh) - .def_readwrite("stop_thresh", &RenderOptions::stop_thresh) - .def_readwrite("last_sample_opaque", &RenderOptions::last_sample_opaque) - .def_readwrite("background_msi_scale", - &RenderOptions::background_msi_scale); - // .def_readwrite("randomize", &RenderOptions::randomize) - // .def_readwrite("_m1", &RenderOptions::_m1) - // .def_readwrite("_m2", &RenderOptions::_m2) - // .def_readwrite("_m3", &RenderOptions::_m3); - - py::class_(m, "GridOutputGrads") - .def(py::init<>()) - .def_readwrite("grad_density_out", &GridOutputGrads::grad_density_out) - .def_readwrite("grad_sh_out", &GridOutputGrads::grad_sh_out) - .def_readwrite("grad_basis_out", &GridOutputGrads::grad_basis_out) - .def_readwrite("grad_background_out", - &GridOutputGrads::grad_background_out) - .def_readwrite("mask_out", &GridOutputGrads::mask_out) - .def_readwrite("mask_background_out", - &GridOutputGrads::mask_background_out); + py::class_(m, "SparseGridSpec") + .def(py::init<>()) + .def_readwrite("density_data", &SparseGridSpec::density_data) + .def_readwrite("sh_data", &SparseGridSpec::sh_data) + .def_readwrite("links", &SparseGridSpec::links) + .def_readwrite("_offset", &SparseGridSpec::_offset) + .def_readwrite("_scaling", &SparseGridSpec::_scaling) + .def_readwrite("basis_dim", &SparseGridSpec::basis_dim) + .def_readwrite("basis_type", &SparseGridSpec::basis_type) + .def_readwrite("basis_data", &SparseGridSpec::basis_data) + .def_readwrite("background_cubemap", + &SparseGridSpec::background_cubemap); + + py::class_(m, "CameraSpec") + .def(py::init<>()) + .def_readwrite("c2w", &CameraSpec::c2w) + .def_readwrite("fx", &CameraSpec::fx) + .def_readwrite("fy", &CameraSpec::fy) + .def_readwrite("cx", &CameraSpec::cx) + .def_readwrite("cy", &CameraSpec::cy) + .def_readwrite("width", &CameraSpec::width) + .def_readwrite("height", &CameraSpec::height) + .def_readwrite("ndc_coeffx", &CameraSpec::ndc_coeffx) + .def_readwrite("ndc_coeffy", &CameraSpec::ndc_coeffy); + + py::class_(m, "RaysSpec") + .def(py::init<>()) + .def_readwrite("origins", &RaysSpec::origins) + .def_readwrite("dirs", &RaysSpec::dirs); + + py::class_(m, "RenderOptions") + .def(py::init<>()) + .def_readwrite("background_brightness", + &RenderOptions::background_brightness) + // .def_readwrite("step_epsilon", &RenderOptions::step_epsilon) + .def_readwrite("step_size", &RenderOptions::step_size) + .def_readwrite("sigma_thresh", &RenderOptions::sigma_thresh) + .def_readwrite("stop_thresh", &RenderOptions::stop_thresh) + .def_readwrite("last_sample_opaque", &RenderOptions::last_sample_opaque) + .def_readwrite("background_msi_scale", + &RenderOptions::background_msi_scale); + // .def_readwrite("randomize", &RenderOptions::randomize) + // .def_readwrite("_m1", &RenderOptions::_m1) + // .def_readwrite("_m2", &RenderOptions::_m2) + // .def_readwrite("_m3", &RenderOptions::_m3); + + py::class_(m, "GridOutputGrads") + .def(py::init<>()) + .def_readwrite("grad_density_out", &GridOutputGrads::grad_density_out) + .def_readwrite("grad_sh_out", &GridOutputGrads::grad_sh_out) + .def_readwrite("grad_basis_out", &GridOutputGrads::grad_basis_out) + .def_readwrite("grad_background_out", + &GridOutputGrads::grad_background_out) + .def_readwrite("mask_out", &GridOutputGrads::mask_out) + .def_readwrite("mask_background_out", + &GridOutputGrads::mask_background_out); } diff --git a/svox2/svox2.py b/svox2/svox2.py index 8268f532..a67f71c2 100644 --- a/svox2/svox2.py +++ b/svox2/svox2.py @@ -1464,6 +1464,7 @@ def inplace_tv_background_grad( self, grad: torch.Tensor, scaling: float = 1.0, + scaling_density: Optional[float] = None, sparse_frac: float = 0.01 ): """ @@ -1476,11 +1477,14 @@ def inplace_tv_background_grad( rand_cells_bg = self._get_rand_cells_background(sparse_frac) indexer = self._get_sparse_background_grad_indexer() + if scaling_density is None: + scaling_density = scaling _C.msi_tv_grad_sparse( self.background_cubemap, rand_cells_bg, indexer, scaling, + scaling_density, grad) def inplace_tv_basis_grad(