diff --git a/manual_install.sh b/manual_install.sh
index c33a92b3..2ab6d2db 100644
--- a/manual_install.sh
+++ b/manual_install.sh
@@ -1,4 +1,5 @@
 cp svox2/svox2.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/svox2.py
 cp svox2/utils.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/utils.py
 cp svox2/version.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/version.py
+cp svox2/defs.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/defs.py
 cp svox2/__init__.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/__init__.py
diff --git a/opt/opt.py b/opt/opt.py
index 60c96f16..161e490e 100644
--- a/opt/opt.py
+++ b/opt/opt.py
@@ -5,6 +5,7 @@
 # or use launching script:   sh launch.sh
 import torch
 import torch.cuda
+import torch.optim
 import torch.nn.functional as F
 import svox2
 import json
@@ -64,12 +65,12 @@
 # TODO: make the lr higher near the end
 group.add_argument('--sigma_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Density optimizer")
-group.add_argument('--lr_sigma', type=float, default=
-                    3e1,
+group.add_argument('--lr_sigma', type=float, default=3e1,
                    help='SGD/rmsprop lr for sigma')
 group.add_argument('--lr_sigma_final', type=float, default=5e-2)
 group.add_argument('--lr_sigma_decay_steps', type=int, default=250000)
-group.add_argument('--lr_sigma_delay_steps', type=int, default=15000,
+group.add_argument('--lr_sigma_delay_steps', type=int, default=
+                    15000,
                    help="Reverse cosine steps (0 means disable)")
 group.add_argument('--lr_sigma_delay_mult', type=float, default=1e-2)
@@ -90,11 +91,11 @@
 group.add_argument('--basis_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Learned basis optimizer")
 group.add_argument('--lr_basis', type=float, default=#2e6,
-                    1e-6,
+                    1e-3,
                    help='SGD/rmsprop lr for SH')
 group.add_argument('--lr_basis_final', type=float, default=
-                    1e-6
+                    1e-4
                     )
 group.add_argument('--lr_basis_decay_steps', type=int, default=250000)
 group.add_argument('--lr_basis_delay_steps', type=int, default=0,#15000,
@@ -140,16 +141,20 @@
 group.add_argument('--tv_logalpha', action='store_true', default=False,
                    help='Use log(1-exp(-delta * sigma)) as in neural volumes')
-group.add_argument('--lambda_tv_sh', type=float, default=1e-2)
+group.add_argument('--lambda_tv_sh', type=float, default=0.0)  # FIXME
 group.add_argument('--tv_sh_sparsity', type=float, default=0.01)
+group.add_argument('--lambda_l2_sh', type=float, default=1e-4)  # FIXME
+
 group.add_argument('--lambda_tv_basis', type=float, default=0.0)
 
 group.add_argument('--weight_decay_sigma', type=float, default=1.0)
 group.add_argument('--weight_decay_sh', type=float, default=1.0)
 
 group.add_argument('--lr_decay', action='store_true', default=True)
-group.add_argument('--use_learned_basis', action='store_true', default=False)
+group.add_argument('--basis_type',
+                   choices=['sh', '3d_texture', 'mlp'],
+                   default='sh')
 
 args = parser.parse_args()
 
 assert args.lr_sigma_final <= args.lr_sigma, "lr_sigma must be >= lr_sigma_final"
@@ -189,7 +194,8 @@
                       use_z_order=True,
                       device=device,
                       basis_reso=args.basis_reso,
-                      use_learned_basis=args.use_learned_basis)
+                      basis_type=svox2.__dict__['BASIS_TYPE_' + args.basis_type.upper()],
+                      mlp_posenc_size=4)
 
 grid.opt.last_sample_opaque = dset.last_sample_opaque
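The `svox2.__dict__['BASIS_TYPE_' + ...]` lookup works because `svox2/defs.py` (added below) exports the integer constants and `svox2/__init__.py` re-exports them. A minimal sketch of the same mapping, assuming only the three constants shipped in `defs.py`:

```python
import svox2

def resolve_basis_type(name: str) -> int:
    # 'sh' -> svox2.BASIS_TYPE_SH (1), '3d_texture' -> 4, 'mlp' -> 255
    return getattr(svox2, 'BASIS_TYPE_' + name.upper())

assert resolve_basis_type('mlp') == svox2.BASIS_TYPE_MLP
```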
@@ -205,12 +211,21 @@
 # #  den[:, :, 0] = 1e9
 #  grid.density_data.data = den.view(osh)
 
-if grid.use_learned_basis:
-    grid.reinit_learned_bases(init_type='sh')
-#  grid.reinit_learned_bases(init_type='fourier')
-#  grid.reinit_learned_bases(init_type='sg', upper_hemi=True)
-#  grid.basis_data.data.normal_(mean=0.0, std=0.01)
-#  grid.basis_data.data += 0.28209479177387814
+optim_basis_mlp = None
+
+if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
+    # grid.reinit_learned_bases(init_type='sh')
+    # grid.reinit_learned_bases(init_type='fourier')
+    # grid.reinit_learned_bases(init_type='sg', upper_hemi=True)
+    grid.basis_data.data.normal_(mean=0.28209479177387814, std=0.001)
+
+elif grid.basis_type == svox2.BASIS_TYPE_MLP:
+    # MLP!
+    optim_basis_mlp = torch.optim.Adam(
+        grid.basis_mlp.parameters(),
+        lr=args.lr_basis
+    )
+
 grid.requires_grad_(True)
 
 step_size = 0.5  # 0.5 of a voxel!
@@ -303,14 +318,18 @@ def eval_step():
             stats_test['psnr'] += psnr
             n_images_gen += 1
 
-        if grid.use_learned_basis:
-            # Add spherical map visualization
+        if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE or \
+           grid.basis_type == svox2.BASIS_TYPE_MLP:
+            # Add spherical map visualization
             EQ_RESO = 256
             eq_dirs = generate_dirs_equirect(EQ_RESO * 2, EQ_RESO)
             eq_dirs = torch.from_numpy(eq_dirs).to(device=device).view(-1, 3)
 
-            sphfuncs = grid._eval_learned_bases(eq_dirs).view(EQ_RESO, EQ_RESO*2, -1)
-            sphfuncs = sphfuncs.permute([2, 0, 1]).cpu().numpy()
+            if grid.basis_type == svox2.BASIS_TYPE_MLP:
+                sphfuncs = grid._eval_basis_mlp(eq_dirs)
+            else:
+                sphfuncs = grid._eval_learned_bases(eq_dirs)
+            sphfuncs = sphfuncs.view(EQ_RESO, EQ_RESO*2, -1).permute([2, 0, 1]).cpu().numpy()
 
             stats = [(sphfunc.min(), sphfunc.mean(), sphfunc.max())
                      for sphfunc in sphfuncs]
@@ -404,6 +423,9 @@ def train_step():
             grid.inplace_tv_color_grad(grid.sh_data.grad,
                     scaling=args.lambda_tv_sh,
                     sparse_frac=args.tv_sh_sparsity)
+        if args.lambda_l2_sh > 0.0:
+            grid.inplace_l2_color_grad(grid.sh_data.grad,
+                    scaling=args.lambda_l2_sh)
         if args.lambda_tv_basis > 0.0:
             tv_basis = grid.tv_basis()
             loss_tv_basis = tv_basis * args.lambda_tv_basis
@@ -414,8 +436,12 @@ def train_step():
         # Manual SGD/rmsprop step
         grid.optim_density_step(lr_sigma, beta=args.rms_beta, optim=args.sigma_optim)
         grid.optim_sh_step(lr_sh, beta=args.rms_beta, optim=args.sh_optim)
-        if grid.use_learned_basis and gstep_id >= args.lr_basis_begin_step:
-            grid.optim_basis_step(lr_basis, beta=args.rms_beta, optim=args.basis_optim)
+        if gstep_id >= args.lr_basis_begin_step:
+            if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
+                grid.optim_basis_step(lr_basis, beta=args.rms_beta, optim=args.basis_optim)
+            elif grid.basis_type == svox2.BASIS_TYPE_MLP:
+                optim_basis_mlp.step()
+                optim_basis_mlp.zero_grad()
 
 train_step()
 gc.collect()
diff --git a/opt/render_imgs.py b/opt/render_imgs.py
index e16950dd..47a49c40 100644
--- a/opt/render_imgs.py
+++ b/opt/render_imgs.py
@@ -56,13 +56,13 @@
     img_eval_interval = max(n_images // args.n_eval, 1)
     avg_psnr = 0.0
     n_images_gen = 0
+    cam = svox2.Camera(torch.tensor(0), dset.intrins.fx, dset.intrins.fy,
+                       dset.intrins.cx, dset.intrins.cy,
+                       dset.w, dset.h,
+                       ndc_coeffs=dset.ndc_coeffs)
+    c2ws = dset.render_c2w.to(device=device) if args.render_path else dset.c2w.to(device=device)
    for img_id in tqdm(range(0, n_images, img_eval_interval)):
-        c2w = dset.render_c2w[img_id] if args.render_path else dset.c2w[img_id]
-        c2w = c2w.to(device=device)
-        cam = svox2.Camera(c2w, dset.intrins.fx, dset.intrins.fy,
-                           dset.intrins.cx, dset.intrins.cy,
-                           dset.w, dset.h,
-                           ndc_coeffs=dset.ndc_coeffs)
+        cam.c2w = c2ws[img_id]
         im = grid.volume_render_image(cam, use_kernel=True)
         im.clamp_(0.0, 1.0)
         if not args.render_path:
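The training loop now drives two optimizers side by side: the hand-rolled CUDA RMSProp/SGD steps for the raw grid tensors, and a standard `torch.optim.Adam` for the basis MLP. A stripped-down sketch of that pattern (the linear layer and loss are stand-ins, not the real model):

```python
import torch

mlp = torch.nn.Linear(3, 9)                       # stand-in for grid.basis_mlp
grid_param = torch.zeros(100, requires_grad=True) # stand-in for a grid tensor
optim_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-3)

for _ in range(10):
    loss = (mlp(torch.randn(4, 3)).sum() + grid_param.sum()) ** 2
    loss.backward()
    with torch.no_grad():               # manual step on the raw tensor
        grid_param -= 0.1 * grid_param.grad
        grid_param.grad.zero_()
    optim_mlp.step()                    # Adam handles the MLP parameters
    optim_mlp.zero_grad()
```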
diff --git a/opt/tasks/sanity.json b/opt/tasks/sanity.json
index 0701313e..34144eed 100644
--- a/opt/tasks/sanity.json
+++ b/opt/tasks/sanity.json
@@ -1,17 +1,19 @@
 {
     "data_root": "/home/sxyu/data/nerf_synthetic/lego",
-    "train_root": "/home/sxyu/proj/svox2/opt/ckpt_tune/fast_256_sweepup_and_lrcdec_10epoch",
+    "train_root": "/home/sxyu/proj/svox2/opt/ckpt_tune/mlp_sweep",
     "variables": {
-        "lr_sh_final": "loglin(1e-5, 5e-2, 10)",
-        "upsamp_every": [3, 4, 6]
+        "lr_sh": "loglin(5e-3, 5e-2, 4)",
+        "lr_basis": "loglin(1e-6, 1e-3, 8)",
+        "upsamp_every": [25600, 38400, 51200],
+        "lambda_tv_sh": "loglin(1e-5, 1e-2, 6)"
     },
     "tasks": [{
-        "train_dir": "up{upsamp_every}_lrc5e-2-d{lr_sh_final:01.7f}",
+        "train_dir": "lrc{lr_sh}_lrb{lr_basis}_tvc{lambda_tv_sh}_ups{upsamp_every}",
         "flags": [
-            "--n_epochs", "10",
-            "--init_reso", "256",
             "--upsamp_every", "{upsamp_every}",
-            "--lr_sh_final", "{lr_sh_final}"
+            "--lr_sh", "{lr_sh}",
+            "--lr_basis", "{lr_basis}",
+            "--upsamp_every", "{upsamp_every}"
         ]
    }]
}
diff --git a/opt/util/ff_dataset.py b/opt/util/ff_dataset.py
index cd542656..c7751dce 100644
--- a/opt/util/ff_dataset.py
+++ b/opt/util/ff_dataset.py
@@ -29,29 +29,7 @@
 from .load_llff import load_llff_data
 from typing import Union
 
-
-def convert_to_ndc(origins, directions, ndc_coeffs, near: float = 1.0):
-    """Convert a set of rays to NDC coordinates."""
-    # Shift ray origins to near plane, not sure if needed
-    t = (near - origins[Ellipsis, 2]) / directions[Ellipsis, 2]
-    origins = origins + t[Ellipsis, None] * directions
-
-    dx, dy, dz = directions.unbind(-1)
-    ox, oy, oz = origins.unbind(-1)
-
-    # Projection
-    o0 = ndc_coeffs[0] * (ox / oz)
-    o1 = ndc_coeffs[1] * (oy / oz)
-    o2 = 1 - 2 * near / oz
-
-    d0 = ndc_coeffs[0] * (dx / dz - ox / oz)
-    d1 = ndc_coeffs[1] * (dy / dz - oy / oz)
-    d2 = 2 * near / oz;
-
-    origins = torch.stack([o0, o1, o2], -1)
-    directions = torch.stack([d0, d1, d2], -1)
-    return origins, directions
-
+from svox2.utils import convert_to_ndc
 
 class LLFFDataset(Dataset):
     def __init__(
diff --git a/svox2/__init__.py b/svox2/__init__.py
index a95e5b5a..e65c1c52 100644
--- a/svox2/__init__.py
+++ b/svox2/__init__.py
@@ -1,2 +1,3 @@
+from .defs import *
 from .svox2 import SparseGrid, Camera, Rays, RenderOptions
 from .version import __version__
diff --git a/svox2/csrc/include/data_spec.hpp b/svox2/csrc/include/data_spec.hpp
index ef57443b..29c20b74 100644
--- a/svox2/csrc/include/data_spec.hpp
+++ b/svox2/csrc/include/data_spec.hpp
@@ -5,6 +5,16 @@
 using torch::Tensor;
 
+enum BasisType {
+    // For svox 1 compatibility
+    // BASIS_TYPE_RGBA = 0
+    BASIS_TYPE_SH = 1,
+    // BASIS_TYPE_SG = 2
+    // BASIS_TYPE_ASG = 3
+    BASIS_TYPE_3D_TEXTURE = 4,
+    BASIS_TYPE_MLP = 255,
+};
+
 struct SparseGridSpec {
     Tensor density_data;
     Tensor sh_data;
@@ -13,7 +23,7 @@ struct SparseGridSpec {
     Tensor _scaling;
 
     int basis_dim;
-    bool use_learned_basis;
+    uint8_t basis_type;
     Tensor basis_data;
 
     inline void check() {
@@ -32,7 +42,6 @@ struct SparseGridSpec {
         TORCH_CHECK(density_data.ndimension() == 2);
         TORCH_CHECK(sh_data.ndimension() == 2);
         TORCH_CHECK(links.ndimension() == 3);
-        TORCH_CHECK(basis_data.ndimension() == 4);
     }
 };
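`BasisType` now lives in two places: the C++ enum above and the plain integers in `svox2/defs.py`. A small guard test (hypothetical, not part of this patch) can keep the two from drifting by mirroring the enum on the Python side:

```python
from enum import IntEnum
import svox2

class BasisType(IntEnum):  # mirror of data_spec.hpp
    SH = 1
    TEXTURE_3D = 4
    MLP = 255

assert svox2.BASIS_TYPE_SH == BasisType.SH
assert svox2.BASIS_TYPE_3D_TEXTURE == BasisType.TEXTURE_3D
assert svox2.BASIS_TYPE_MLP == BasisType.MLP
```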
diff --git a/svox2/csrc/include/data_spec_packed.cuh b/svox2/csrc/include/data_spec_packed.cuh
index 897ff9e2..dcbfa983 100644
--- a/svox2/csrc/include/data_spec_packed.cuh
+++ b/svox2/csrc/include/data_spec_packed.cuh
@@ -13,7 +13,7 @@ struct PackedSparseGridSpec {
         density_data(spec.density_data.data_ptr<float>()),
         sh_data(spec.sh_data.data_ptr<float>()),
         links(spec.links.data_ptr<int32_t>()),
-        use_learned_basis(spec.use_learned_basis),
+        basis_type(spec.basis_type),
         basis_data(spec.basis_data.data_ptr<float>()),
         size{(int)spec.links.size(0),
              (int)spec.links.size(1),
@@ -33,7 +33,7 @@ struct PackedSparseGridSpec {
     float* __restrict__ density_data;
     float* __restrict__ sh_data;
     const int32_t* __restrict__ links;
-    const bool use_learned_basis;
+    const uint8_t basis_type;
     float* __restrict__ basis_data;
 
     const int size[3], stride_x;
diff --git a/svox2/csrc/include/render_util.cuh b/svox2/csrc/include/render_util.cuh
index ef9df833..aa1e98a4 100644
--- a/svox2/csrc/include/render_util.cuh
+++ b/svox2/csrc/include/render_util.cuh
@@ -35,8 +35,10 @@ __device__ __inline__ float trilerp_one(
 __device__ __inline__ float compute_skip_dist(
         SingleRaySpec& __restrict__ ray,
         const int32_t* __restrict__ links,
-        int offx, int offy) {
-    const int32_t link_val = links[offx * ray.l[0] + offy * ray.l[1] + ray.l[2]];
+        int offx, int offy,
+        int pos_offset = 0) {
+    const int32_t link_val = links[offx * (ray.l[0] + pos_offset) + offy * (ray.l[1] + pos_offset)
+                                   + (ray.l[2] + pos_offset)];
     if (link_val >= -1) return 0.f; // Not worth
 
     const uint32_t dist = -link_val;
@@ -49,20 +51,42 @@ __device__ __inline__ float compute_skip_dist(
     float tmax = 1e9f;
 #pragma unroll
     for (int i = 0; i < 3; ++i) {
-        int ul = ((ray.l[i] >> cell_ul_shift) << cell_ul_shift);
-        ul -= ray.l[i];
+        int ul = (((ray.l[i] + pos_offset) >> cell_ul_shift) << cell_ul_shift);
+        ul -= ray.l[i] + pos_offset;
 
         const float invdir = 1.0 / ray.dir[i];
-        const float t1 = (ul - ray.pos[i]) * invdir;
-        const float t2 = (ul + cell_side_len - ray.pos[i]) * invdir;
+        const float t1 = (ul - ray.pos[i] + pos_offset) * invdir;
+        const float t2 = (ul + cell_side_len - ray.pos[i] + pos_offset) * invdir;
         if (ray.dir[i] != 0.f) {
             tmin = max(tmin, min(t1, t2));
             tmax = min(tmax, max(t1, t2));
         }
     }
+
+//     const uint32_t cell_ul_shift = 1 - dist;
+//     const uint32_t cell_br_shift = -cell_ul_shift;
+//
+//     // AABB intersection
+//     // Consider caching the invdir for the ray
+//     float tmin = 0.f;
+//     float tmax = 1e9f;
+// #pragma unroll
+//     for (int i = 0; i < 3; ++i) {
+//         const float invdir = 1.0 / ray.dir[i];
+//         const float t1 = (cell_ul_shift - ray.pos[i] + pos_offset) * invdir;
+//         const float t2 = (cell_br_shift - ray.pos[i] + pos_offset) * invdir;
+//         if (ray.dir[i] != 0.f) {
+//             tmin = max(tmin, min(t1, t2));
+//             tmax = min(tmax, max(t1, t2));
+//         }
+//     }
+
     if (tmin > 0.f) {
         // Somehow the origin is not in the cube
-        // Will happen near the lowest vertex of a cell,
+        // Should not happen for distance transform
+
+        // If using geometric distances:
+        // will happen near the lowest vertex of a cell,
         // since l is always the lowest neighbor
         return 0.f;
     }
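`compute_skip_dist` is the classic slab-method ray/AABB intersection, applied to the empty-space cube encoded in the negative link value. A NumPy-free Python sketch of the bare slab test (illustrative only; no `pos_offset` handling):

```python
def slab_intersect(pos, dirv, box_lo, box_hi):
    """Return (tmin, tmax) of the ray pos + t*dirv against an AABB."""
    tmin, tmax = 0.0, 1e9
    for i in range(3):
        if dirv[i] != 0.0:
            invdir = 1.0 / dirv[i]
            t1 = (box_lo[i] - pos[i]) * invdir
            t2 = (box_hi[i] - pos[i]) * invdir
            tmin = max(tmin, min(t1, t2))
            tmax = min(tmax, max(t1, t2))
    return tmin, tmax
```

When the origin lies inside the box (`tmin <= 0`), `tmax` is the distance the tracer may safely skip; the kernel returns 0 otherwise, as the early-out above shows.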
@@ -257,10 +281,11 @@ __device__ __inline__ void calc_sh(
 __device__ __inline__ void calc_sphfunc(
     const PackedSparseGridSpec& grid,
     const int lane_id,
+    const int ray_id,
     const float* __restrict__ dir, // Pre-normalized
     float* __restrict__ out) {
     // Placeholder
-    if (grid.use_learned_basis) {
+    if (grid.basis_type == BASIS_TYPE_3D_TEXTURE) {
         float p[3];
         int32_t l[3];
         for (int j = 0; j < 3; ++j) {
@@ -272,13 +297,23 @@ __device__ __inline__ void calc_sphfunc(
             p[j] -= static_cast<float>(l[j]);
         }
 
+        if (lane_id > 0 && lane_id < grid.basis_dim) {
+            out[lane_id] =
+                fmaxf(
+                    trilerp_one(
+                        grid.basis_data,
+                        grid.basis_reso,
+                        grid.basis_dim - 1,
+                        l, p,
+                        lane_id - 1),
+                    0.f);
+        }
+        out[0] = C0;
+    } else if (grid.basis_type == BASIS_TYPE_MLP) {
+        const float* __restrict__ basis_ptr = grid.basis_data + grid.basis_dim * ray_id;
         if (lane_id < grid.basis_dim) {
-            out[lane_id] = trilerp_one(grid.basis_data,
-                    grid.basis_reso, grid.basis_dim,
-                    l, p,
-                    lane_id);
+            out[lane_id] = _SIGMOID(basis_ptr[lane_id]);
         }
-
     } else {
         calc_sh(grid.basis_dim, dir, out);
     }
@@ -287,11 +322,13 @@ __device__ __inline__ void calc_sphfunc(
 __device__ __inline__ void calc_sphfunc_backward(
     const PackedSparseGridSpec& grid,
     const int lane_id,
+    const int ray_id,
     const float* __restrict__ dir, // Pre-normalized
+    const float* __restrict__ output_saved,
     const float* __restrict__ grad_output,
-    float* __restrict__ grad_out) {
+    float* __restrict__ grad_basis_data) {
     // Placeholder
-    if (grid.use_learned_basis) {
+    if (grid.basis_type == BASIS_TYPE_3D_TEXTURE) {
         float p[3];
         int32_t l[3];
         for (int j = 0; j < 3; ++j) {
@@ -304,15 +341,21 @@ __device__ __inline__ void calc_sphfunc_backward(
         }
 
         __syncwarp((1U << grid.sh_data_dim) - 1);
-        if (lane_id < grid.basis_dim) {
-            trilerp_backward_one(grad_out,
+        if (lane_id > 0 && lane_id < grid.basis_dim && output_saved[lane_id] > 0.f) {
+            trilerp_backward_one(grad_basis_data,
                     grid.basis_reso,
-                    grid.basis_dim,
+                    grid.basis_dim - 1,
                     l, p,
                     grad_output[lane_id],
-                    lane_id);
+                    lane_id - 1);
+        }
+    } else if (grid.basis_type == BASIS_TYPE_MLP) {
+        float* __restrict__ grad_basis_ptr = grad_basis_data + grid.basis_dim * ray_id;
+        if (lane_id < grid.basis_dim) {
+            const float sig = output_saved[lane_id];
+            grad_basis_ptr[lane_id] =
+                sig * (1.f - sig) * grad_output[lane_id];
         }
-
     } else {
         // nothing needed
     }
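For the MLP basis the kernel reuses the forward sigmoid outputs (`output_saved`), so the backward pass is just `sig * (1 - sig) * grad_output`, the analytic derivative of the logistic function. A quick PyTorch check of that identity:

```python
import torch

x = torch.randn(16, requires_grad=True)
sig = torch.sigmoid(x)
grad_out = torch.randn(16)

sig.backward(grad_out)                                  # autograd result
manual = sig.detach() * (1 - sig.detach()) * grad_out   # the kernel's formula
assert torch.allclose(x.grad, manual, atol=1e-6)
```

Saving the forward output instead of recomputing the sigmoid is also what lets the 3D-texture branch cheaply skip gradients for values clamped to zero (`output_saved[lane_id] > 0.f`).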
diff --git a/svox2/csrc/misc_kernel.cu b/svox2/csrc/misc_kernel.cu
index a066e6d3..795e67ed 100644
--- a/svox2/csrc/misc_kernel.cu
+++ b/svox2/csrc/misc_kernel.cu
@@ -45,8 +45,61 @@ __global__ void dilate_kernel(
 
 // Probably can speed up the following functions, however they are really not
 // the bottleneck
-// A kind of L-infty distance transform-ish thing
+// ** Distance transforms
 // TODO: Maybe replace this with an euclidean distance transform eg PBA
+// Actual L-infty distance transform; turns out this is slower than the geometric way
+__global__ void accel_linf_dist_transform_kernel(
+        torch::PackedTensorAccessor32<int32_t, 3, torch::RestrictPtrTraits> grid,
+        int32_t* __restrict__ tmp,
+        int d2) {
+    const int d0 = d2 == 0 ? 1 : 0;
+    const int d1 = 3 - d0 - d2;
+    CUDA_GET_THREAD_ID(tid, grid.size(d0) * grid.size(d1));
+    int32_t* __restrict__ tmp_ptr = tmp + tid * grid.size(d2);
+    int l[3];
+    l[d0] = tid / grid.size(d1);
+    l[d1] = tid % grid.size(d1);
+    l[d2] = 0;
+    const int INF = 0x3f3f3f3f;
+    int d01_dist = INF;
+    int d2_dist = INF;
+    for (; l[d2] < grid.size(d2); ++l[d2]) {
+        const int val = grid[l[0]][l[1]][l[2]];
+        int32_t cdist = max(- val, 0);
+        if (d2 == 0 && cdist)
+            cdist = INF;
+        const int32_t pdist = max(d2_dist, d01_dist);
+
+        if (cdist < pdist) {
+            d01_dist = cdist;
+            d2_dist = 0;
+        }
+        tmp_ptr[l[d2]] = min(cdist, pdist);
+        ++d2_dist;
+    }
+
+    l[d2] = grid.size(d2) - 1;
+    d01_dist = INF;
+    d2_dist = INF;
+    for (; l[d2] >= 0; --l[d2]) {
+        const int val = grid[l[0]][l[1]][l[2]];
+        int32_t cdist = max(- val, 0);
+        if (d2 == 0 && cdist)
+            cdist = INF;
+        const int32_t pdist = max(d2_dist, d01_dist);
+
+        if (cdist < pdist) {
+            d01_dist = cdist;
+            d2_dist = 0;
+        }
+        if (cdist) {
+            grid[l[0]][l[1]][l[2]] = -min(tmp_ptr[l[d2]], min(cdist, pdist));
+        }
+        ++d2_dist;
+    }
+}
+
+// Geometric L-infty distance transform-ish thing
 __global__ void accel_dist_set_kernel(
         const torch::PackedTensorAccessor32<int32_t, 3, torch::RestrictPtrTraits> grid,
         bool* __restrict__ tmp) {
@@ -62,7 +115,6 @@
     bool* tmp_base = tmp;
-
     if (grid[x][y][z] >= 0) {
         while (sz_x > 1 && sz_y > 1 && sz_z > 1) {
             // Propagate occupied cell throughout the temp tree parent nodes
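The new kernel is the standard two-sweep distance transform: a forward pass carrying the distance to the most recent occupied cell, then a backward pass taking the minimum, run once per axis (`d2`). A 1D NumPy sketch of the same idea, simplified to a single axis (the kernel additionally combines axes via the `max(d2_dist, d01_dist)` trick to get the L∞ metric):

```python
import numpy as np

def linf_dist_1d(occupied: np.ndarray) -> np.ndarray:
    """Distance to the nearest occupied cell along one axis (two-pass)."""
    n, INF = len(occupied), 10 ** 9
    dist = np.full(n, INF, dtype=np.int64)
    d = INF
    for i in range(n):               # forward sweep
        d = 0 if occupied[i] else min(d + 1, INF)
        dist[i] = d
    d = INF
    for i in range(n - 1, -1, -1):   # backward sweep
        d = 0 if occupied[i] else min(d + 1, INF)
        dist[i] = min(dist[i], d)
    return dist

print(linf_dist_1d(np.array([0, 0, 1, 0, 0, 0, 1])))  # [2 1 0 1 2 1 0]
```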
@@ -274,15 +326,15 @@ void accel_dist_prop(torch::Tensor grid) {
     TORCH_CHECK(!grid.is_floating_point());
     TORCH_CHECK(grid.ndimension() == 3);
 
+    int sz_x = grid.size(0);
+    int sz_y = grid.size(1);
+    int sz_z = grid.size(2);
+
     int Q = grid.size(0) * grid.size(1) * grid.size(2);
 
     const int cuda_n_threads = std::min(Q, CUDA_MAX_THREADS);
     const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
 
-    int sz_x = grid.size(0);
-    int sz_y = grid.size(1);
-    int sz_z = grid.size(2);
-
     size_t req_size = 0;
     while (sz_x > 1 && sz_y > 1 && sz_z > 1) {
         sz_x = int_div2_ceil(sz_x);
@@ -303,6 +355,28 @@ void accel_dist_prop(torch::Tensor grid) {
             grid.packed_accessor32<int32_t, 3, torch::RestrictPtrTraits>(),
             tmp);
+
+    // int32_t* tmp;
+    // const int req_size = sz_x * sz_y * sz_z;
+    // cuda(Malloc(&tmp, req_size * sizeof(int32_t)));
+    // cuda(MemsetAsync(tmp, 0, req_size * sizeof(int32_t)));
+    //
+    // {
+    //     for (int d2 = 0; d2 < 3; ++d2) {
+    //         int d0 = d2 == 0 ? 1 : 0;
+    //         int d1 = 3 - d0 - d2;
+    //         int Q = grid.size(d0) * grid.size(d1);
+    //
+    //         const int cuda_n_threads = std::min(Q, CUDA_MAX_THREADS);
+    //         const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
+    //
+    //         device::accel_linf_dist_transform_kernel<<<blocks, cuda_n_threads>>>(
+    //             grid.packed_accessor32<int32_t, 3, torch::RestrictPtrTraits>(),
+    //             tmp,
+    //             d2);
+    //     }
+    // }
+
     cuda(Free(tmp));
     CUDA_CHECK_ERRORS;
 }
diff --git a/svox2/csrc/optim_kernel.cu b/svox2/csrc/optim_kernel.cu
index 80444678..36245847 100644
--- a/svox2/csrc/optim_kernel.cu
+++ b/svox2/csrc/optim_kernel.cu
@@ -16,10 +16,10 @@ __inline__ __device__ void rmsprop_once(
         float* __restrict__ ptr_data,
         float* __restrict__ ptr_rms,
         float* __restrict__ ptr_grad,
-        const float beta, const float lr, const float epsilon) {
+        const float beta, const float lr, const float epsilon, float minval) {
     float rms = lerp(_SQR(*ptr_grad), *ptr_rms, beta);
     *ptr_rms = rms;
-    *ptr_data -= lr * (*ptr_grad) / (sqrtf(rms) + epsilon);
+    *ptr_data = fmaxf(*ptr_data - lr * (*ptr_grad) / (sqrtf(rms) + epsilon), minval);
     *ptr_grad = 0.f;
 }
 
@@ -30,14 +30,16 @@ __global__ void rmsprop_step_kernel(
        torch::PackedTensorAccessor64<float, 2> all_grad,
        float beta,
        float lr,
-       float epsilon) {
+       float epsilon,
+       float minval) {
    CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
    rmsprop_once(all_data.data() + tid,
                 all_rms.data() + tid,
                 all_grad.data() + tid,
                 beta,
                 lr,
-                epsilon);
+                epsilon,
+                minval);
 }
 
@@ -49,7 +51,8 @@ __global__ void rmsprop_mask_step_kernel(
        const bool* __restrict__ mask,
        float beta,
        float lr,
-       float epsilon) {
+       float epsilon,
+       float minval) {
    CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
    if (mask[tid / all_data.size(1)] == false) return;
    rmsprop_once(all_data.data() + tid,
@@ -57,7 +60,8 @@
                 all_grad.data() + tid,
                 beta,
                 lr,
-                epsilon);
+                epsilon,
+                minval);
 }
 
 __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
@@ -68,7 +72,8 @@ __global__ void rmsprop_index_step_kernel(
        torch::PackedTensorAccessor32<int32_t, 1> indices,
        float beta,
        float lr,
-       float epsilon) {
+       float epsilon,
+       float minval) {
    CUDA_GET_THREAD_ID(tid, indices.size(0) * all_data.size(1));
    int32_t i = indices[tid / all_data.size(1)];
    int32_t j = tid % all_data.size(1);
@@ -77,7 +82,8 @@
                 all_grad.data() + off,
                 beta,
                 lr,
-                epsilon);
+                epsilon,
+                minval);
 }
 
@@ -141,7 +147,8 @@ void rmsprop_step(
        torch::Tensor indexer,
        float beta,
        float lr,
-       float epsilon) {
+       float epsilon,
+       float minval) {
 
    DEVICE_GUARD(data);
    CHECK_INPUT(data);
@@ -160,7 +167,8 @@
                grad.packed_accessor64<float, 2>(),
                beta,
                lr,
-               epsilon);
+               epsilon,
+               minval);
    } else if (indexer.size(0) == 0) {
        // Skip
    } else if (indexer.scalar_type() == at::ScalarType::Bool) {
@@ -173,7 +181,8 @@
                indexer.data_ptr<bool>(),
                beta,
                lr,
-               epsilon);
+               epsilon,
+               minval);
    } else {
        const size_t Q = indexer.size(0) * data.size(1);
        const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
@@ -184,7 +193,8 @@
                indexer.packed_accessor32<int32_t, 1>(),
                beta,
                lr,
-               epsilon);
+               epsilon,
+               minval);
    }
 
    CUDA_CHECK_ERRORS;
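`rmsprop_once` is plain RMSProp with one twist: the updated value is clamped from below by the new `minval` argument. A PyTorch sketch of the same update rule (note `lerp(g², rms, beta) = beta·rms + (1-beta)·g²`):

```python
import torch

def rmsprop_clamped_step(data, rms, grad, beta=0.9, lr=0.01,
                         eps=1e-8, minval=-1e9):
    # rms <- beta * rms + (1 - beta) * grad^2
    rms.mul_(beta).add_(grad * grad, alpha=1.0 - beta)
    # clamped parameter update, then zero the grad buffer in place
    data.sub_(lr * grad / (rms.sqrt() + eps)).clamp_(min=minval)
    grad.zero_()
```

In this diff both Python call sites pass `-1e9`, so the clamp is effectively disabled; the change just threads a lower bound through the kernels for later use.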
diff --git a/svox2/csrc/render_lerp_kernel_cuvol.cu b/svox2/csrc/render_lerp_kernel_cuvol.cu
index cebf97ca..04803204 100644
--- a/svox2/csrc/render_lerp_kernel_cuvol.cu
+++ b/svox2/csrc/render_lerp_kernel_cuvol.cu
@@ -68,9 +68,10 @@ __device__ __inline__ void trace_ray_cuvol(
             ray.pos[j] -= static_cast<float>(ray.l[j]);
         }
 
-        float skip = compute_skip_dist(ray,
-                       grid.links, grid.stride_x,
-                       grid.size[2]);
+        const float skip = compute_skip_dist(ray,
+                       grid.links, grid.stride_x,
+                       grid.size[2], 0);
+
         if (skip >= opt.step_size) {
             // For consistency, we skip the by step size
             t += ceilf(skip / opt.step_size) * opt.step_size;
@@ -104,6 +105,10 @@ __device__ __inline__ void trace_ray_cuvol(
                                          lane_color, lane_colorgrp_id == 0);
             outv += weight * fmaxf(lane_color_total + 0.5f, 0.f); // Clamp to [+0, 1]
             if (_EXP(light_intensity) < opt.stop_thresh) {
+                const float renorm_val = 1.f / (1.f - _EXP(light_intensity));
+                if (lane_colorgrp_id == 0) {
+                    out[lane_colorgrp] *= renorm_val;
+                }
                 break;
             }
         }
@@ -153,9 +158,9 @@ __device__ __inline__ void trace_ray_cuvol_backward(
             ray.l[j] = min(static_cast<int32_t>(ray.pos[j]), grid.size[j] - 2);
             ray.pos[j] -= static_cast<float>(ray.l[j]);
         }
-        float skip = compute_skip_dist(ray,
-                       grid.links, grid.stride_x,
-                       grid.size[2]);
+        const float skip = compute_skip_dist(ray,
+                       grid.links, grid.stride_x,
+                       grid.size[2], 0);
         if (skip >= opt.step_size) {
             // For consistency, we skip the by step size
             t += ceilf(skip / opt.step_size) * opt.step_size;
@@ -261,6 +266,7 @@ __global__ void render_ray_kernel(
        ray_spec[ray_blk_id].set(rays.origins[ray_id].data(),
                                 rays.dirs[ray_id].data());
        calc_sphfunc(grid, lane_id,
+                    ray_id,
                     ray_spec[ray_blk_id].dir, sphfunc_val[ray_blk_id]);
        if (lane_id == 0) {
@@ -311,7 +317,9 @@ __global__ void render_ray_backward_kernel(
    if (lane_id < grid.basis_dim) {
        grad_sphfunc_val[ray_blk_id][lane_id] = 0.f;
    }
-    calc_sphfunc(grid, lane_id, vdir, sphfunc_val[ray_blk_id]);
+    calc_sphfunc(grid, lane_id,
+                 ray_id,
+                 vdir, sphfunc_val[ray_blk_id]);
    if (lane_id == 0) {
        ray_find_bounds(ray_spec[ray_blk_id], grid, opt);
    }
@@ -329,8 +337,11 @@
            mask_out,
            grad_density_data_out,
            grad_sh_data_out);
-    calc_sphfunc_backward(grid, lane_id,
+    calc_sphfunc_backward(
+                 grid, lane_id,
+                 ray_id,
                  vdir,
+                 sphfunc_val[ray_blk_id],
                  grad_sphfunc_val[ray_blk_id],
                  grad_basis_data_out);
 }
@@ -369,7 +380,8 @@ __global__ void render_ray_fused_kernel(
    if (lane_id < grid.basis_dim) {
        grad_sphfunc_val[ray_blk_id][lane_id] = 0.f;
    }
-    calc_sphfunc(grid, lane_id, vdir, sphfunc_val[ray_blk_id]);
+    calc_sphfunc(grid, lane_id,
+                 ray_id, vdir, sphfunc_val[ray_blk_id]);
    if (lane_id == 0) {
        ray_find_bounds(ray_spec[ray_blk_id], grid, opt);
    }
@@ -409,7 +421,8 @@
    }
 
    calc_sphfunc_backward(grid, lane_id,
-                 vdir,
+                 ray_id, vdir,
+                 sphfunc_val[ray_blk_id],
                  grad_sphfunc_val[ray_blk_id],
                  grad_basis_data_out);
 }
@@ -436,7 +449,7 @@ __global__ void render_image_kernel(
                           TRACE_RAY_CUDA_RAYS_PER_BLOCK];
    __shared__ SingleRaySpec ray_spec[TRACE_RAY_CUDA_RAYS_PER_BLOCK];
    ray_spec[ray_blk_id].set(origin, dir);
-    calc_sphfunc(grid, lane_id,
+    calc_sphfunc(grid, lane_id, ray_id,
                 dir, sphfunc_val[ray_blk_id]);
    if (lane_id == 0) {
        ray_find_bounds(ray_spec[ray_blk_id], grid, opt);
    }
@@ -484,7 +497,7 @@ __global__ void render_image_backward_kernel(
    if (lane_id < grid.basis_dim) {
        grad_sphfunc_val[ray_blk_id][lane_id] = 0.f;
    }
-    calc_sphfunc(grid, lane_id,
+    calc_sphfunc(grid, lane_id, ray_id,
                 dir, sphfunc_val[ray_blk_id]);
    if (lane_id == 0) {
        ray_find_bounds(ray_spec[ray_blk_id], grid, opt);
    }
@@ -503,8 +516,9 @@
            mask_out,
            grad_density_data_out,
            grad_sh_data_out);
-    calc_sphfunc_backward(grid, lane_id,
+    calc_sphfunc_backward(grid, lane_id, ray_id,
                 dir,
+                sphfunc_val[ray_blk_id],
                 grad_sphfunc_val[ray_blk_id],
                 grad_basis_data_out);
 }
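The new early-exit branch rescales the accumulated color by `1/(1 - T)` once the remaining transmittance `T = exp(light_intensity)` falls below `stop_thresh`, so a truncated ray behaves like one whose compositing weights summed to 1 (the weights emitted so far always sum to exactly `1 - T`). A scalar PyTorch sketch of front-to-back compositing with that renormalization (`sigmas`/`rgbs` are per-sample stand-ins):

```python
import math
import torch

def composite_with_renorm(sigmas, rgbs, delta=0.01, stop_thresh=1e-4):
    """Front-to-back compositing; renormalize if the ray exits early."""
    out = torch.zeros(3)
    log_T = 0.0                      # log of remaining transmittance
    for sigma, rgb in zip(sigmas, rgbs):
        alpha = 1.0 - math.exp(-sigma * delta)
        out += math.exp(log_T) * alpha * rgb
        log_T -= sigma * delta
        if math.exp(log_T) < stop_thresh:
            # Weights so far sum to 1 - T; rescale them to sum to 1.
            out /= 1.0 - math.exp(log_T)
            break
    return out
```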
diff --git a/svox2/csrc/svox2.cpp b/svox2/csrc/svox2.cpp
index 7924da8d..021e2adf 100644
--- a/svox2/csrc/svox2.cpp
+++ b/svox2/csrc/svox2.cpp
@@ -42,7 +42,7 @@ void tv_grad_sparse(Tensor, Tensor, Tensor, Tensor,
                     int, int, float, bool,
                     float, bool, Tensor);
 
 // Optim
-void rmsprop_step(Tensor, Tensor, Tensor, Tensor, float, float, float);
+void rmsprop_step(Tensor, Tensor, Tensor, Tensor, float, float, float, float);
 void sgd_step(Tensor, Tensor, Tensor, float);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -78,7 +78,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def_readwrite("_offset", &SparseGridSpec::_offset)
         .def_readwrite("_scaling", &SparseGridSpec::_scaling)
         .def_readwrite("basis_dim", &SparseGridSpec::basis_dim)
-        .def_readwrite("use_learned_basis", &SparseGridSpec::use_learned_basis)
+        .def_readwrite("basis_type", &SparseGridSpec::basis_type)
         .def_readwrite("basis_data", &SparseGridSpec::basis_data);
 
     py::class_<CameraSpec>(m, "CameraSpec")
diff --git a/svox2/defs.py b/svox2/defs.py
new file mode 100644
index 00000000..6d7e320e
--- /dev/null
+++ b/svox2/defs.py
@@ -0,0 +1,4 @@
+# Basis types (copied from C++ data_spec.hpp)
+BASIS_TYPE_SH = 1
+BASIS_TYPE_3D_TEXTURE = 4
+BASIS_TYPE_MLP = 255
diff --git a/svox2/svox2.py b/svox2/svox2.py
index aba6f4ac..d0ec9f56 100644
--- a/svox2/svox2.py
+++ b/svox2/svox2.py
@@ -1,3 +1,4 @@
+from .defs import *
 from .utils import (
     isqrt,
     eval_sh_bases,
@@ -5,6 +6,11 @@
     is_pow2,
     spher2cart,
     eval_sg_at_dirs,
+    init_weights,
+    posenc,
+    net_to_dict,
+    net_from_dict,
+    convert_to_ndc,
     MAX_SH_BASIS,
     _get_c_extension,
 )
@@ -200,7 +206,10 @@ def backward(ctx, grad_out):
         cu_fn = _C.__dict__[f"volume_render_{ctx.backend}_backward"]
         grad_density_grid = torch.zeros_like(ctx.grid.density_data.data)
         grad_sh_grid = torch.zeros_like(ctx.grid.sh_data.data)
-        grad_basis = torch.zeros_like(ctx.grid.basis_data.data)
+        if ctx.grid.basis_type == BASIS_TYPE_MLP:
+            grad_basis = torch.zeros_like(ctx.basis_data)
+        else:
+            grad_basis = torch.zeros_like(ctx.grid.basis_data.data)
         # TODO save the sparse mask
         sparse_mask = cu_fn(
             ctx.grid,
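Note the two branches above must not be swapped: for the MLP the CUDA kernel writes one `basis_dim` row of gradient per ray (indexed by `ray_id` in `calc_sphfunc_backward`), so the buffer has to match the per-ray MLP output saved on `ctx`, while the 3D texture scatters into the grid-shaped `basis_data` parameter. A shape-only sketch of the distinction:

```python
import torch

basis_dim, reso, n_rays = 9, 16, 4096

# 3D texture: gradient lives on the (reso^3, basis_dim - 1) voxel grid
# (the DC component is fixed to C0 in the kernel, hence basis_dim - 1).
grad_texture = torch.zeros(reso, reso, reso, basis_dim - 1)

# MLP: gradient is per-ray w.r.t. the MLP outputs, fed back via .backward().
grad_mlp_out = torch.zeros(n_rays, basis_dim)
```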
@@ -313,35 +322,49 @@ class SparseGrid(nn.Module):
     :param reso: int or List[int, int, int], resolution for resampled grid, as in the constructor
     :param radius: float or List[float, float, float], the 1/2 side length of the grid, optionally in each direction
     :param center: float or List[float, float, float], the center of the grid
-    :param use_learned_basis: bool, whether to use learned spherical function (false = SH)
+    :param basis_type: int, basis type; may use svox2.BASIS_TYPE_* (1 = SH, 4 = learned 3D texture, 255 = learned MLP)
     :param basis_dim: int, size of basis / number of SH components (must be square number in case of SH)
-    :param basis_reso: int, resolution of learned spherical function
+    :param basis_reso: int, resolution of grid if using BASIS_TYPE_3D_TEXTURE
     :param use_z_order: bool, if true, stores the data initially in a Z-order curve if possible
+    :param mlp_posenc_size: int, if using BASIS_TYPE_MLP, then enables standard axis-aligned positional encoding of
+                            given size on MLP; if 0 then does not use positional encoding
+    :param mlp_width: int, if using BASIS_TYPE_MLP, specifies MLP width (hidden dimension)
     :param device: torch.device, device to store the grid
     """

     def __init__(
         self,
-        reso: Union[int, List[int]] = 128,
+        reso: Union[int, List[int], Tuple[int, int, int]] = 128,
         radius: Union[float, List[float]] = 1.0,
         center: Union[float, List[float]] = [0.0, 0.0, 0.0],
-        use_learned_basis: bool = True,
+        basis_type: int = BASIS_TYPE_SH,
         basis_dim: int = 9,  # SH/learned basis size; in SH case, square number
         basis_reso: int = 16,  # Learned basis resolution (x^3 embedding grid)
         use_z_order : bool=False,
         use_sphere_bound : bool=False,
+        mlp_posenc_size : int = 0,
+        mlp_width : int = 16,
+        background_nlayers : int = 0,  # BG MSI layers
+        background_reso : int = 256,  # BG MSI cubemap face size
         device: Union[torch.device, str] = "cpu",
     ):
         super().__init__()
-        self.use_learned_basis = use_learned_basis
-        if not use_learned_basis:
+        self.basis_type = basis_type
+        if basis_type == BASIS_TYPE_SH:
             assert isqrt(basis_dim) is not None, "basis_dim (SH) must be a square number"
         assert (
             basis_dim >= 1 and basis_dim <= MAX_SH_BASIS
         ), f"basis_dim 1-{MAX_SH_BASIS} supported"
         self.basis_dim = basis_dim
 
+        self.mlp_posenc_size = mlp_posenc_size
+        self.mlp_width = mlp_width
+
+        self.background_nlayers = background_nlayers
+        self.background_reso = background_reso
+        self.use_background = background_nlayers > 0
+
         if isinstance(reso, int):
             reso = [reso] * 3
         else:
@@ -409,28 +432,65 @@ def __init__(
             torch.zeros(self.capacity, 1, dtype=torch.float32, device=device)
         )
         # Called sh for legacy reasons, but it's just the coeffients for whatever
-        # representation
+        # spherical basis functions
         self.sh_data = nn.Parameter(
             torch.zeros(
                 self.capacity, self.basis_dim * 3, dtype=torch.float32, device=device
             )
         )
-        if use_learned_basis:
+        self.basis_data = nn.Parameter(
+            torch.zeros(
+                (0, 0, 0, 0), dtype=torch.float32, device=device
+            ),
+            requires_grad=False
+        )
+        if self.basis_type == BASIS_TYPE_3D_TEXTURE:
             # Unit sphere embedded in a cube
             self.basis_data = nn.Parameter(
                 torch.zeros(
                     basis_reso, basis_reso, basis_reso,
-                    self.basis_dim, dtype=torch.float32, device=device
+                    self.basis_dim - 1, dtype=torch.float32, device=device
                 )
             )
-        else:
-            self.basis_data = nn.Parameter(
+        elif self.basis_type == BASIS_TYPE_MLP:
+            D_rgb = mlp_width
+            dir_in_dims = 3 + 6 * self.mlp_posenc_size
+            # Hard-coded 4 layer MLP
+            self.basis_mlp = nn.Sequential(
+                nn.Linear(dir_in_dims, D_rgb),
+                nn.ReLU(),
+                nn.Linear(D_rgb, D_rgb),
+                nn.ReLU(),
+                nn.Linear(D_rgb, D_rgb),
+                nn.ReLU(),
+                nn.Linear(D_rgb, self.basis_dim)
+            )
+            self.basis_mlp = self.basis_mlp.to(device=self.sh_data.device)
+            self.basis_mlp.apply(init_weights)
+
+        if self.use_background:
+            bg_total = 6 * self.background_reso * self.background_reso * self.background_nlayers  # 6 cubemap faces
+            self.background_density_data = nn.Parameter(
+                torch.zeros(
+                    bg_total,
+                    1, dtype=torch.float32, device=device
+                )
+            )
+            self.background_sh_data = nn.Parameter(
                 torch.zeros(
-                    (0, 0, 0, 0), dtype=torch.float32, device=device
-                ),
-                requires_grad=False
+                    bg_total,
+                    self.basis_dim * 3, dtype=torch.float32, device=device
+                )
             )
+            init_bg_links = torch.arange(bg_total, device=device, dtype=torch.int32)
+            self.register_buffer("background_links",
+                    init_bg_links.view(
+                        2,
+                        3,
+                        self.background_reso,
+                        self.background_reso,
+                        self.background_nlayers))
 
         self.register_buffer("links", init_links.view(reso))
         self.links: torch.Tensor
@@ -455,10 +515,10 @@ def data_dim(self):
     @property
     def basis_reso(self):
         """
-        Return the resolution of the learned spherical function data,
-        or 0 if only using SH
+        Return the resolution of the learned spherical function data if using
+        3D learned texture, or 0 if only using SH
         """
-        return self.basis_data.size(0) if self.use_learned_basis else 0
+        return self.basis_data.size(0) if self.basis_type == BASIS_TYPE_3D_TEXTURE else 0
 
     @property
     def shape(self):
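`dir_in_dims = 3 + 6 * mlp_posenc_size` because each of the 3 direction components contributes a sine and a cosine at every encoding frequency (3 × 2 = 6 features per frequency), plus the raw direction itself when `include_identity` is set. A quick shape check, assuming the `posenc` helper this patch adds to `svox2/utils.py`:

```python
import torch
from svox2.utils import posenc

mlp_posenc_size = 4
dirs = torch.randn(100, 3)
enc = posenc(dirs, None, 0, mlp_posenc_size,
             include_identity=True, enable_ipe=False)
assert enc.shape == (100, 3 + 6 * mlp_posenc_size)
```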
@@ -504,7 +564,7 @@ def sample(self, points: torch.Tensor,
         if use_kernel and self.links.is_cuda and _C is not None:
             assert points.is_cuda
             return _SampleGridAutogradFunction.apply(
-                self.density_data, self.sh_data, self._to_cpp(grid_coords), points, want_colors
+                self.density_data, self.sh_data, self._to_cpp(grid_coords=grid_coords), points, want_colors
             )
         else:
             if not grid_coords:
@@ -575,8 +635,10 @@ def _volume_render_gradcheck_lerp(self, rays: Rays):
         delta_scale = 1.0 / dirs.norm(dim=1)
         dirs *= delta_scale.unsqueeze(-1)
 
-        if self.use_learned_basis:
+        if self.basis_type == BASIS_TYPE_3D_TEXTURE:
             sh_mult = self._eval_learned_bases(viewdirs)
+        elif self.basis_type == BASIS_TYPE_MLP:
+            sh_mult = self._eval_basis_mlp(viewdirs)
         else:
             sh_mult = eval_sh_bases(self.basis_dim, viewdirs)
         invdirs = 1.0 / dirs
@@ -700,11 +762,13 @@ def volume_render(
         assert self.opt.backend in ["cuvol"]  # , 'lerp', 'nn']
         if use_kernel and self.links.is_cuda and _C is not None:
             assert rays.is_cuda
+            basis_data = self._eval_basis_mlp(rays.dirs) if self.basis_type == BASIS_TYPE_MLP \
+                         else None
             return _VolumeRenderFunction.apply(
                 self.density_data,
                 self.sh_data,
-                self.basis_data,
-                self._to_cpp(),
+                basis_data,
+                self._to_cpp(replace_basis_data=basis_data),
                 rays._to_cpp(),
                 self.opt._to_cpp(randomize=randomize),
                 self.opt.backend,
@@ -735,8 +799,13 @@ def volume_render_fused(
             assert rays.is_cuda
             grad_density, grad_sh, grad_basis = self._get_data_grads()
             rgb_out = torch.zeros_like(rgb_gt)
+            basis_data : Optional[torch.Tensor] = None
+            if self.basis_type == BASIS_TYPE_MLP:
+                with torch.enable_grad():
+                    basis_data = self._eval_basis_mlp(rays.dirs)
+                grad_basis = torch.empty_like(basis_data)
             self.sparse_grad_indexer: torch.Tensor = _C.volume_render_cuvol_fused(
-                self._to_cpp(),
+                self._to_cpp(replace_basis_data=basis_data),
                 rays._to_cpp(),
                 self.opt._to_cpp(),
                 rgb_gt,
@@ -745,11 +814,16 @@
                 grad_sh,
                 grad_basis,
             )
+            if self.basis_type == BASIS_TYPE_MLP:
+                # Manually trigger MLP backward!
+                basis_data.backward(grad_basis)
+
             self.sparse_sh_grad_indexer = self.sparse_grad_indexer.clone()
             return rgb_out
 
     def volume_render_image(
-        self, camera: Camera, use_kernel: bool = True, randomize: bool = False
+        self, camera: Camera, use_kernel: bool = True, randomize: bool = False,
+        batch_size : int = 5000,
     ):
         """
         Standard volume rendering (entire image version).
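`basis_data.backward(grad_basis)` is the bridge between the fused CUDA kernel and autograd: the kernel fills `grad_basis` with dL/d(MLP output), and calling `.backward(...)` with that tensor back-propagates it through `basis_mlp`'s parameters. The same pattern in miniature:

```python
import torch

mlp = torch.nn.Linear(3, 9)
dirs = torch.randn(4096, 3)

with torch.enable_grad():
    out = mlp(dirs)                  # forward through the module

grad_out = torch.randn_like(out)     # stand-in for the kernel-computed grad
out.backward(grad_out)               # populates mlp.weight.grad / mlp.bias.grad

assert mlp.weight.grad is not None
```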
@@ -761,21 +835,55 @@ def volume_render_image(
         :return: (H, W, 3), predicted RGB image
         """
         assert self.opt.backend in ["cuvol"]  # , 'lerp', 'nn']
-        if use_kernel and self.links.is_cuda and _C is not None:
-            assert camera.is_cuda
-            return _VolumeRenderImageFunction.apply(
-                self.density_data,
-                self.sh_data,
-                self.basis_data,
-                self._to_cpp(),
-                camera._to_cpp(),
-                self.opt._to_cpp(randomize=randomize),
-                self.opt.backend,
-            )
-        else:
-            raise NotImplementedError(
-                "Pure PyTorch image rendering not implemented, " + "please use rays"
-            )
+        assert camera.ndc_coeffs[0] < 0.0, "To be impl"
+        # For now we're just generating the rays (due to MLP the _VolumeRenderImageFunction no longer always works)
+
+        origins = camera.c2w[None, :3, 3].expand(camera.height * camera.width, -1).contiguous()
+        yy, xx = torch.meshgrid(
+            torch.arange(camera.height, dtype=torch.float64, device=camera.c2w.device) + 0.5,
+            torch.arange(camera.width, dtype=torch.float64, device=camera.c2w.device) + 0.5,
+        )
+        xx = (xx - camera.cx) / camera.fx
+        yy = (yy - camera.cy) / camera.fy
+        zz = torch.ones_like(xx)
+        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV
+        del xx, yy, zz
+        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
+        dirs = dirs.reshape(-1, 3, 1)
+        dirs = (camera.c2w[None, :3, :3].double() @ dirs)[..., 0]
+        dirs = dirs.reshape(-1, 3).float()
+
+        if camera.ndc_coeffs[0] > 0.0:
+            origins, dirs = convert_to_ndc(
+                    origins,
+                    dirs,
+                    camera.ndc_coeffs)
+            dirs /= torch.norm(dirs, dim=-1, keepdim=True)
+
+        all_rgb_out = []
+        for batch_start in range(0, camera.height * camera.width, batch_size):
+            rays = Rays(origins[batch_start:batch_start+batch_size], dirs[batch_start:batch_start+batch_size])
+            rgb_out_part = self.volume_render(rays, use_kernel=use_kernel, randomize=randomize)
+            all_rgb_out.append(rgb_out_part)
+
+        all_rgb_out = torch.cat(all_rgb_out, dim=0)
+        return all_rgb_out.view(camera.height, camera.width, -1)
+
+        # if use_kernel and self.links.is_cuda and _C is not None:
+        #     assert camera.is_cuda
+        #     return _VolumeRenderImageFunction.apply(
+        #         self.density_data,
+        #         self.sh_data,
+        #         self.basis_data,
+        #         self._to_cpp(replace_basis_data=basis_data),
+        #         camera._to_cpp(),
+        #         self.opt._to_cpp(randomize=randomize),
+        #         self.opt.backend,
+        #     )
+        # else:
+        #     raise NotImplementedError(
+        #         "Pure PyTorch image rendering not implemented, " + "please use rays"
+        #     )
 
     def resample(
         self,
@@ -1027,8 +1135,14 @@ def save(self, path: str, compress: bool = False):
             "density_data":self.density_data.data.cpu().numpy(),
             "sh_data":self.sh_data.data.cpu().numpy().astype(np.float16),
         }
-        if self.use_learned_basis:
+        if self.basis_type == BASIS_TYPE_3D_TEXTURE:
             data['basis_data'] = self.basis_data.data.cpu().numpy()
+        elif self.basis_type == BASIS_TYPE_MLP:
+            net_to_dict(data, "basis_mlp", self.basis_mlp)
+            data['mlp_posenc_size'] = np.int32(self.mlp_posenc_size)
+            data['mlp_width'] = np.int32(self.mlp_width)
+        data['basis_type'] = self.basis_type
+
         save_fn(
             path,
             **data
@@ -1059,7 +1173,9 @@ def load(cls, path: str, device: Union[torch.device, str] = "cpu"):
             basis_dim=basis_dim,
             use_z_order=False,
             device="cpu",
-            use_learned_basis=False
+            basis_type=z['basis_type'].item() if 'basis_type' in z else BASIS_TYPE_SH,
+            mlp_posenc_size=z['mlp_posenc_size'].item() if 'mlp_posenc_size' in z else 0,
+            mlp_width=z['mlp_width'].item() if 'mlp_width' in z else 16
         )
         if sh_data.dtype != np.float32:
             sh_data = sh_data.astype(np.float32)
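Rendering a full image is now just `volume_render` looped over ray chunks, which bounds peak memory (important once the MLP produces a per-ray basis tensor). The generic chunking pattern, for reference:

```python
import torch

def apply_chunked(fn, inputs: torch.Tensor, batch_size: int = 5000):
    """Apply fn to row-chunks of inputs and concatenate the results."""
    outs = [fn(inputs[i:i + batch_size])
            for i in range(0, inputs.shape[0], batch_size)]
    return torch.cat(outs, dim=0)
```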
@@ -1073,9 +1189,15 @@
         grid.capacity = grid.sh_data.size(0)
 
         # Maybe load basis_data
-        if "basis_data" in z.keys():
+        if grid.basis_type == BASIS_TYPE_MLP:
+            net_from_dict(z, "basis_mlp", grid.basis_mlp)
+            grid.basis_mlp = grid.basis_mlp.to(device=device)
+        elif grid.basis_type == BASIS_TYPE_3D_TEXTURE or \
+             "basis_data" in z.keys():
+            # Note: Checking for basis_data for compatibility with earlier vers
+            # where basis_type not stored
             basis_data = torch.from_numpy(z.f.basis_data).to(device=device)
-            grid.use_learned_basis = True
+            grid.basis_type = BASIS_TYPE_3D_TEXTURE
             grid.basis_data = nn.Parameter(basis_data)
         else:
             grid.basis_data = nn.Parameter(grid.basis_data.data.to(device=device))
@@ -1252,6 +1374,41 @@ def inplace_tv_color_grad(
                 grad)
             self.sparse_sh_grad_indexer = None
 
+    def inplace_l2_color_grad(
+        self,
+        grad: torch.Tensor,
+        start_dim: int = 0,
+        end_dim: Optional[int] = None,
+        scaling: float = 1.0,
+    ):
+        """
+        Add gradient of L2 regularization for color
+        directly into the gradient tensor, multiplied by 'scaling'
+
+        :param start_dim: int, first color channel dimension to compute the L2 penalty over (inclusive).
+                          Default 0.
+        :param end_dim: int, last color channel dimension to compute the L2 penalty over (exclusive).
+                        Default None = all dimensions until the end.
+        """
+        assert (
+            _C is not None and self.sh_data.is_cuda and grad.is_cuda
+        ), "CUDA extension is currently required"
+        with torch.no_grad():
+            if end_dim is None:
+                end_dim = self.sh_data.size(1)
+            end_dim = end_dim + self.sh_data.size(1) if end_dim < 0 else end_dim
+            start_dim = start_dim + self.sh_data.size(1) if start_dim < 0 else start_dim
+
+            if self.sparse_sh_grad_indexer is None:
+                scale = scaling / self.sh_data.size(0)
+                grad[:, start_dim:end_dim] += scale * self.sh_data[:, start_dim:end_dim]
+            else:
+                indexer = self._maybe_convert_sparse_grad_indexer(sh=True)
+                nz : int = torch.count_nonzero(indexer).item() if indexer.dtype == torch.bool else \
+                           indexer.size(0)
+                scale = scaling / nz
+                grad[indexer, start_dim:end_dim] += scale * self.sh_data[indexer, start_dim:end_dim]
+
     def inplace_tv_basis_grad(
         self,
         grad: torch.Tensor,
@@ -1274,7 +1431,7 @@ def optim_density_step(self, lr: float, beta: float=0.9, epsilon: float = 1e-8,
             _C is not None and self.sh_data.is_cuda
         ), "CUDA extension is currently required for optimizers"
 
-        self._maybe_convert_sparse_grad_indexer()
+        indexer = self._maybe_convert_sparse_grad_indexer()
         if optim == 'rmsprop':
             if (
                 self.density_rms is None
@@ -1286,16 +1443,17 @@
                 self.density_data.data,
                 self.density_rms,
                 self.density_data.grad,
-                self._get_sparse_grad_indexer(),
+                indexer,
                 beta,
                 lr,
                 epsilon,
+                -1e9
             )
         elif optim == 'sgd':
             _C.sgd_step(
                 self.density_data.data,
                 self.density_data.grad,
-                self._get_sparse_grad_indexer(),
+                indexer,
                 lr,
             )
         else:
@@ -1310,7 +1468,7 @@ def optim_sh_step(self, lr: float, beta: float=0.9, epsilon: float = 1e-8,
             _C is not None and self.sh_data.is_cuda
         ), "CUDA extension is currently required for optimizers"
 
-        self._maybe_convert_sparse_sh_grad_indexer()
+        indexer = self._maybe_convert_sparse_grad_indexer(sh=True)
         if optim == 'rmsprop':
             if self.sh_rms is None or self.sh_rms.shape != self.sh_data.shape:
                 del self.sh_rms
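`inplace_l2_color_grad` never materializes a loss: since d/dθ of (λ/2N)·Σθ² is (λ/N)·θ, it adds the regularizer's gradient straight into `sh_data.grad`, restricted to voxels the current batch touched. A dense-tensor check of that identity:

```python
import torch

sh = torch.randn(1000, 27, requires_grad=True)
lam = 1e-4

loss = 0.5 * lam / sh.size(0) * (sh ** 2).sum()
loss.backward()

manual = (lam / sh.size(0)) * sh.detach()   # what the in-place version adds
assert torch.allclose(sh.grad, manual, atol=1e-8)
```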
@@ -1319,14 +1477,15 @@ def optim_sh_step(self, lr: float, beta: float=0.9, epsilon: float = 1e-8,
                 self.sh_data.data,
                 self.sh_rms,
                 self.sh_data.grad,
-                self._get_sparse_sh_grad_indexer(),
+                indexer,
                 beta,
                 lr,
                 epsilon,
+                -1e9
             )
         elif optim == 'sgd':
             _C.sgd_step(
-                self.sh_data.data, self.sh_data.grad, self._get_sparse_sh_grad_indexer(), lr
+                self.sh_data.data, self.sh_data.grad, indexer, lr
             )
         else:
             raise NotImplementedError(f'Unsupported optimizer {optim}')
@@ -1354,9 +1513,20 @@ def optim_basis_step(self, lr: float, beta: float=0.9, epsilon: float = 1e-8,
             raise NotImplementedError(f'Unsupported optimizer {optim}')
         self.basis_data.grad.zero_()
 
+    @property
+    def basis_type_name(self):
+        if self.basis_type == BASIS_TYPE_SH:
+            return "SH"
+        elif self.basis_type == BASIS_TYPE_3D_TEXTURE:
+            return "3D_TEXTURE"
+        elif self.basis_type == BASIS_TYPE_MLP:
+            return "MLP"
+        return "UNKNOWN"
+
     def __repr__(self):
         return (
-            f"svox2.SparseGrid(basis_dim={self.basis_dim}, "
+            f"svox2.SparseGrid(basis_type={self.basis_type_name}, "
+            + f"basis_dim={self.basis_dim}, "
             + f"reso={list(self.links.shape)}, "
             + f"capacity:{self.sh_data.size(0)})"
         )
@@ -1369,7 +1539,7 @@ def is_cubic_pow2(self):
         reso = self.links.shape
         return reso[0] == reso[1] and reso[0] == reso[2] and is_pow2(reso[0])
 
-    def _to_cpp(self, grid_coords: bool = False):
+    def _to_cpp(self, grid_coords: bool = False, replace_basis_data: Optional[torch.Tensor] = None):
         """
         Generate object to pass to C++
         """
@@ -1386,8 +1556,8 @@ def _to_cpp(self, grid_coords: bool = False):
         gspec._scaling = self._scaling * gsz
 
         gspec.basis_dim = self.basis_dim
-        gspec.use_learned_basis = self.use_learned_basis
-        gspec.basis_data = self.basis_data
+        gspec.basis_type = self.basis_type
+        gspec.basis_data = replace_basis_data if replace_basis_data is not None else self.basis_data
         return gspec
 
     def _grid_size(self):
@@ -1423,37 +1593,22 @@ def _get_sparse_sh_grad_indexer(self):
             indexer = torch.empty((), device=self.density_data.device)
         return indexer
 
-    def _maybe_convert_sparse_grad_indexer(self):
+    def _maybe_convert_sparse_grad_indexer(self, sh=False):
         """
         Automatically convert sparse grad indexer from mask to
         indices, if it is efficient
         """
+        indexer = self.sparse_sh_grad_indexer if sh else self.sparse_grad_indexer
+        if indexer is None:
+            return torch.empty((), device=self.density_data.device)
         if (
-            self.sparse_grad_indexer is not None and
-            self.sparse_grad_indexer.dtype == torch.bool and
-            torch.count_nonzero(self.sparse_grad_indexer).item()
-            < self.sparse_grad_indexer.size(0) // 8
-        ):
-            # Highly sparse (use index)
-            self.sparse_grad_indexer = torch.nonzero(
-                self.sparse_grad_indexer.flatten(), as_tuple=False
-            ).flatten()
-
-    def _maybe_convert_sparse_sh_grad_indexer(self):
-        """
-        Automatically convert sparse grad indexer from mask to
-        indices, if it is efficient
-        """
-        if (
-            self.sparse_sh_grad_indexer is not None and
-            self.sparse_sh_grad_indexer.dtype == torch.bool and
-            torch.count_nonzero(self.sparse_sh_grad_indexer).item()
-            < self.sparse_sh_grad_indexer.size(0) // 8
+            indexer.dtype == torch.bool and
+            torch.count_nonzero(indexer).item()
+            < indexer.size(0) // 8
         ):
             # Highly sparse (use index)
-            self.sparse_sh_grad_indexer = torch.nonzero(
-                self.sparse_sh_grad_indexer.flatten(), as_tuple=False
-            ).flatten()
+            indexer = torch.nonzero(indexer.flatten(), as_tuple=False).flatten()
+        return indexer
 
     def _get_rand_cells(self, sparse_frac: float):
         if sparse_frac < 1.0:
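The merged `_maybe_convert_sparse_grad_indexer` keeps a boolean mask unless fewer than 1/8 of its entries are set, at which point a flat index list is both smaller and cheaper for the scatter kernels. The conversion in isolation:

```python
import torch

def maybe_to_indices(mask: torch.Tensor) -> torch.Tensor:
    """Convert a bool mask to flat indices when it is sparse enough."""
    if mask.dtype == torch.bool and \
            torch.count_nonzero(mask).item() < mask.numel() // 8:
        return torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    return mask

mask = torch.zeros(80, dtype=torch.bool)
mask[3] = True
print(maybe_to_indices(mask))   # tensor([3]) -- index form
```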
@@ -1469,10 +1624,24 @@ def _eval_learned_bases(self, dirs: torch.Tensor):
         basis_data = self.basis_data.permute([3, 2, 1, 0])[None]
         samples = F.grid_sample(basis_data, dirs[None, None, None], mode='bilinear', padding_mode='zeros', align_corners=True)
         samples = samples[0, :, 0, 0, :].permute([1, 0])
-        # dc = torch.ones_like(samples[:, :1])
-        # samples = torch.cat([dc, samples], dim=-1)
+        dc = torch.full_like(samples[:, :1], fill_value=0.28209479177387814)
+        samples = torch.cat([dc, samples], dim=-1)
         return samples
 
+    def _eval_basis_mlp(self, dirs: torch.Tensor):
+        if self.mlp_posenc_size > 0:
+            dirs_enc = posenc(
+                dirs,
+                None,
+                0,
+                self.mlp_posenc_size,
+                include_identity=True,
+                enable_ipe=False,
+            )
+        else:
+            dirs_enc = dirs
+        return self.basis_mlp(dirs_enc)
+
     def reinit_learned_bases(self,
                              init_type: str = 'sh',
                              sg_lambda_max: float = 1.0,
@@ -1531,7 +1700,7 @@ def reinit_learned_bases(self,
             u2 = torch.arange(0, n_comps) + torch.rand((n_comps,))
             u2 /= n_comps
             fourier_dirvecs = spher2cart(u1 * np.pi, u2 * np.pi * 2)
-            fourier_freqs = torch.linspace(0.0, 2.0, n_comps + 1)[:-1]
+            fourier_freqs = torch.linspace(0.0, 1.0, n_comps + 1)[:-1]
             fourier_freqs += torch.rand_like(fourier_freqs) * (fourier_freqs[1] - fourier_freqs[0])
             fourier_freqs = torch.exp(fourier_freqs)
             fourier_freqs = fourier_freqs[torch.randperm(n_comps)]
diff --git a/svox2/utils.py b/svox2/utils.py
index 7b744ba1..5403d4a1 100644
--- a/svox2/utils.py
+++ b/svox2/utils.py
@@ -1,5 +1,8 @@
 from functools import partial
 import torch
+from torch import nn
+from typing import Optional
+import numpy as np
 
 def inthroot(x : int, n : int):
     if x <= 0:
@@ -272,3 +275,140 @@ def eval_sg_at_dirs(sg_lambda : torch.Tensor, sg_mu : torch.Tensor, dirs : torch.Tensor):
     basis = torch.exp(torch.einsum(
         "i,...i->...i", sg_lambda, product - 1))  # [..., N]
     return basis
+
+def init_weights(m):
+    if type(m) == nn.Linear:
+        nn.init.xavier_uniform_(m.weight)
+        m.bias.data.fill_(0.0)
+
+
+def cross_broadcast(x : torch.Tensor, y : torch.Tensor):
+    """
+    Cross broadcasting for 2 tensors
+
+    :param x: torch.Tensor
+    :param y: torch.Tensor, should have the same ndim as x
+    :return: tuple of cross-broadcasted tensors x, y. Any dimension where the size of x or y is 1
+             is expanded to the maximum size in that dimension among the 2.
+             Formally, say the shape of x is (a1, ... an)
+             and of y is (b1, ... bn);
+             then the result has shape (a'1, ... a'n), (b'1, ... b'n)
+             where
+                :code:`a'i = ai if (ai > 1 and bi > 1) else max(ai, bi)`
+                :code:`b'i = bi if (ai > 1 and bi > 1) else max(ai, bi)`
+    """
+    assert x.ndim == y.ndim, "Only available if ndim is same for all tensors"
+    max_shape = [(-1 if (a > 1 and b > 1) else max(a,b)) for i, (a, b)
+                 in enumerate(zip(x.shape, y.shape))]
+    shape_x = [max(a, m) for m, a in zip(max_shape, x.shape)]
+    shape_y = [max(b, m) for m, b in zip(max_shape, y.shape)]
+    x = x.broadcast_to(shape_x)
+    y = y.broadcast_to(shape_y)
+    return x, y
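`cross_broadcast` differs from ordinary broadcasting in that it only expands size-1 dimensions against the other tensor and leaves dimensions where both sizes exceed 1 untouched. For instance:

```python
import torch
from svox2.utils import cross_broadcast

x = torch.zeros(4, 1, 3)
y = torch.zeros(1, 5, 3)
xb, yb = cross_broadcast(x, y)
print(xb.shape, yb.shape)  # torch.Size([4, 5, 3]) torch.Size([4, 5, 3])
```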
+def posenc(
+    x: torch.Tensor,
+    cov_diag: Optional[torch.Tensor],
+    min_deg: int,
+    max_deg: int,
+    include_identity: bool = True,
+    enable_ipe: bool = True,
+    cutoff: float = 1.0,
+):
+    """
+    Positional encoding function. Adapted from jaxNeRF
+    (https://github.com/google-research/google-research/tree/master/jaxnerf).
+    With support for mip-NeRF IPE (by passing cov_diag != 0, keeping enable_ipe=True).
+    And BARF-nerfies frequency attenuation (setting cutoff)
+
+    Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1],
+    Instead of computing [sin(x), cos(x)], we use the trig identity
+    cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
+
+    :param x: torch.Tensor (..., D), variables to be encoded. Note that x should be in [-pi, pi].
+    :param cov_diag: torch.Tensor (..., D), diagonal cov for each variable to be encoded (IPE)
+    :param min_deg: int, the minimum (inclusive) degree of the encoding.
+    :param max_deg: int, the maximum (exclusive) degree of the encoding. if min_deg >= max_deg,
+                    positional encoding is disabled.
+    :param include_identity: bool, if true then concatenates the identity
+    :param enable_ipe: bool, if true then uses cov_diag to compute IPE, if available.
+                       Note cov_diag = 0 will give the same effect.
+    :param cutoff: float, in [0, 1], a relative frequency cutoff as in BARF/nerfies. 1 = all frequencies,
+                   0 = no frequencies
+
+    :return: encoded torch.Tensor (..., D * (max_deg - min_deg) * 2 [+ D if include_identity]),
+             encoded variables.
+    """
+    if min_deg >= max_deg:
+        return x
+    scales = torch.tensor([2 ** i for i in range(min_deg, max_deg)], device=x.device)
+
+    half_enc_dim = x.shape[-1] * scales.shape[0]
+    shapeb = list(x.shape[:-1]) + [half_enc_dim]  # (..., D * (max_deg - min_deg))
+    xb = torch.reshape((x[..., None, :] * scales[:, None]), shapeb)
+    four_feat = torch.sin(
+        torch.cat([xb, xb + 0.5 * np.pi], dim=-1)
+    )  # (..., D * (max_deg - min_deg) * 2)
+
+    if enable_ipe and cov_diag is not None:
+        # Apply integrated positional encoding (IPE)
+        xb_var = torch.reshape((cov_diag[..., None, :] * scales[:, None] ** 2), shapeb)
+        xb_var = torch.tile(xb_var, (2,))  # (..., D * (max_deg - min_deg) * 2)
+        four_feat = four_feat * torch.exp(-0.5 * xb_var)
+
+    if cutoff < 1.0:
+        # BARF/nerfies, could be made cleaner
+        cutoff_mask = _cutoff_mask(
+            scales, cutoff * (max_deg - min_deg)
+        )  # (max_deg - min_deg,)
+        four_feat = four_feat.view(shapeb[:-1] + [2, scales.shape[0], x.shape[-1]])
+        four_feat = four_feat * cutoff_mask[..., None]
+        four_feat = four_feat.view(shapeb[:-1] + [2 * scales.shape[0] * x.shape[-1]])
+
+    if include_identity:
+        four_feat = torch.cat([x] + [four_feat], dim=-1)
+
+    return four_feat
+
+
+def net_to_dict(out_dict : dict,
+                prefix : str,
+                model : nn.Module):
+    for child in model.named_children():
+        layer_name = child[0]
+        for param in child[1].named_parameters():
+            param_name = param[0]
+            param_value = param[1].data.cpu().numpy()
+            out_dict['pt__' + prefix + '__' + layer_name + '__' + param_name] = param_value
+
+def net_from_dict(in_dict,
+                  prefix : str,
+                  model : nn.Module):
+    for child in model.named_children():
+        layer_name = child[0]
+        for param in child[1].named_parameters():
+            param_name = param[0]
+            value = in_dict['pt__' + prefix + '__' + layer_name + '__' + param_name]
+            param[1].data[:] = torch.from_numpy(value).to(
+                    device=param[1].data.device)
+
+
+def convert_to_ndc(origins, directions, ndc_coeffs, near: float = 1.0):
+    """Convert a set of rays to NDC coordinates."""
+    # Shift ray origins to near plane, not sure if needed
+    t = (near - origins[Ellipsis, 2]) / directions[Ellipsis, 2]
+    origins = origins + t[Ellipsis, None] * directions
+
+    dx, dy, dz = directions.unbind(-1)
+    ox, oy, oz = origins.unbind(-1)
+
+    # Projection
+    o0 = ndc_coeffs[0] * (ox / oz)
+    o1 = ndc_coeffs[1] * (oy / oz)
+    o2 = 1 - 2 * near / oz
+
+    d0 = ndc_coeffs[0] * (dx / dz - ox / oz)
+    d1 = ndc_coeffs[1] * (dy / dz - oy / oz)
+    d2 = 2 * near / oz
+
+    origins = torch.stack([o0, o1, o2], -1)
+    directions = torch.stack([d0, d1, d2], -1)
+    return origins, directions
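`net_to_dict`/`net_from_dict` flatten a module's parameters into prefixed numpy entries so the basis MLP can ride along inside the same `.npz` checkpoint as the grid arrays. A round-trip sketch:

```python
import numpy as np
import torch
from torch import nn
from svox2.utils import net_to_dict, net_from_dict

mlp = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 9))
data = {}
net_to_dict(data, "basis_mlp", mlp)   # keys like 'pt__basis_mlp__0__weight'
np.savez("/tmp/ckpt.npz", **data)

mlp2 = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 9))
net_from_dict(np.load("/tmp/ckpt.npz"), "basis_mlp", mlp2)
assert torch.equal(mlp[0].weight, mlp2[0].weight)
```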
diff --git a/test/sanity.py b/test/sanity.py
index 66d6b430..9c746795 100644
--- a/test/sanity.py
+++ b/test/sanity.py
@@ -6,7 +6,8 @@
 torch.random.manual_seed(123)
 
 g = svox2.SparseGrid(center=[0.0, 0.0, 0.0],
                      radius=[1.0, 1.0, 1.0],
-                     device=device)
+                     device=device,
+                     basis_type=svox2.BASIS_TYPE_MLP)
 g.opt.sigma_thresh = 0.0
 g.opt.stop_thresh = 0.0
@@ -17,7 +18,7 @@
 g.sh_data.data[..., 1:] = 0.0
 g.basis_data.data.normal_()
 g.basis_data.data *= 10.0
-print('use frustum?', g.use_frustum)
+# print('use frustum?', g.use_frustum)
 
 N_RAYS = 1