Save before adding BG model
sxyu committed Nov 1, 2021
1 parent e21fb58 commit 878ac55
Showing 17 changed files with 667 additions and 195 deletions.
1 change: 1 addition & 0 deletions manual_install.sh
@@ -1,4 +1,5 @@
cp svox2/svox2.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/svox2.py
cp svox2/utils.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/utils.py
cp svox2/version.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/version.py
+ cp svox2/defs.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/defs.py
cp svox2/__init__.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/__init__.py
66 changes: 46 additions & 20 deletions opt/opt.py
@@ -5,6 +5,7 @@
# or use launching script: sh launch.sh <EXP_NAME> <GPU> <DATA_DIR>
import torch
import torch.cuda
+ import torch.optim
import torch.nn.functional as F
import svox2
import json
@@ -64,12 +65,12 @@

# TODO: make the lr higher near the end
group.add_argument('--sigma_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Density optimizer")
- group.add_argument('--lr_sigma', type=float, default=
- 3e1,
+ group.add_argument('--lr_sigma', type=float, default=3e1,
help='SGD/rmsprop lr for sigma')
group.add_argument('--lr_sigma_final', type=float, default=5e-2)
group.add_argument('--lr_sigma_decay_steps', type=int, default=250000)
- group.add_argument('--lr_sigma_delay_steps', type=int, default=15000,
+ group.add_argument('--lr_sigma_delay_steps', type=int, default=
+ 15000,
help="Reverse cosine steps (0 means disable)")
group.add_argument('--lr_sigma_delay_mult', type=float, default=1e-2)

@@ -90,11 +91,11 @@

group.add_argument('--basis_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Learned basis optimizer")
group.add_argument('--lr_basis', type=float, default=#2e6,
- 1e-6,
+ 1e-3,
help='SGD/rmsprop lr for SH')
group.add_argument('--lr_basis_final', type=float,
default=
- 1e-6
+ 1e-4
)
group.add_argument('--lr_basis_decay_steps', type=int, default=250000)
group.add_argument('--lr_basis_delay_steps', type=int, default=0,#15000,
@@ -140,16 +141,20 @@
group.add_argument('--tv_logalpha', action='store_true', default=False,
help='Use log(1-exp(-delta * sigma)) as in neural volumes')

- group.add_argument('--lambda_tv_sh', type=float, default=1e-2)
+ group.add_argument('--lambda_tv_sh', type=float, default=0.0) # FIXME
group.add_argument('--tv_sh_sparsity', type=float, default=0.01)

+ group.add_argument('--lambda_l2_sh', type=float, default=1e-4) # FIXME

group.add_argument('--lambda_tv_basis', type=float, default=0.0)

group.add_argument('--weight_decay_sigma', type=float, default=1.0)
group.add_argument('--weight_decay_sh', type=float, default=1.0)

group.add_argument('--lr_decay', action='store_true', default=True)
- group.add_argument('--use_learned_basis', action='store_true', default=False)
+ group.add_argument('--basis_type',
+                    choices=['sh', '3d_texture', 'mlp'],
+                    default='sh')
args = parser.parse_args()

assert args.lr_sigma_final <= args.lr_sigma, "lr_sigma must be >= lr_sigma_final"
@@ -189,7 +194,8 @@
use_z_order=True,
device=device,
basis_reso=args.basis_reso,
- use_learned_basis=args.use_learned_basis)
+ basis_type=svox2.__dict__['BASIS_TYPE_' + args.basis_type.upper()],
+ mlp_posenc_size=4)

grid.opt.last_sample_opaque = dset.last_sample_opaque

@@ -205,12 +211,21 @@
# # den[:, :, 0] = 1e9
# grid.density_data.data = den.view(osh)

- if grid.use_learned_basis:
-     grid.reinit_learned_bases(init_type='sh')
-     # grid.reinit_learned_bases(init_type='fourier')
-     # grid.reinit_learned_bases(init_type='sg', upper_hemi=True)
-     # grid.basis_data.data.normal_(mean=0.0, std=0.01)
-     # grid.basis_data.data += 0.28209479177387814
+ optim_basis_mlp = None
+
+ if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
+     # grid.reinit_learned_bases(init_type='sh')
+     # grid.reinit_learned_bases(init_type='fourier')
+     # grid.reinit_learned_bases(init_type='sg', upper_hemi=True)
+     grid.basis_data.data.normal_(mean=0.28209479177387814, std=0.001)
+
+ elif grid.basis_type == svox2.BASIS_TYPE_MLP:
+     # MLP!
+     optim_basis_mlp = torch.optim.Adam(
+         grid.basis_mlp.parameters(),
+         lr=args.lr_basis
+     )
+

grid.requires_grad_(True)
step_size = 0.5 # 0.5 of a voxel!
@@ -303,14 +318,18 @@ def eval_step():
stats_test['psnr'] += psnr
n_images_gen += 1

- if grid.use_learned_basis:
-     # Add spherical map visualization
+ if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE or \
+         grid.basis_type == svox2.BASIS_TYPE_MLP:
+     # Add spherical map visualization
EQ_RESO = 256
eq_dirs = generate_dirs_equirect(EQ_RESO * 2, EQ_RESO)
eq_dirs = torch.from_numpy(eq_dirs).to(device=device).view(-1, 3)

- sphfuncs = grid._eval_learned_bases(eq_dirs).view(EQ_RESO, EQ_RESO*2, -1)
- sphfuncs = sphfuncs.permute([2, 0, 1]).cpu().numpy()
+ if grid.basis_type == svox2.BASIS_TYPE_MLP:
+     sphfuncs = grid._eval_basis_mlp(eq_dirs)
+ else:
+     sphfuncs = grid._eval_learned_bases(eq_dirs)
+ sphfuncs = sphfuncs.view(EQ_RESO, EQ_RESO*2, -1).permute([2, 0, 1]).cpu().numpy()

stats = [(sphfunc.min(), sphfunc.mean(), sphfunc.max())
for sphfunc in sphfuncs]
@@ -404,6 +423,9 @@ def train_step():
grid.inplace_tv_color_grad(grid.sh_data.grad,
scaling=args.lambda_tv_sh,
sparse_frac=args.tv_sh_sparsity)
+ if args.lambda_l2_sh > 0.0:
+     grid.inplace_l2_color_grad(grid.sh_data.grad,
+                                scaling=args.lambda_l2_sh)
if args.lambda_tv_basis > 0.0:
tv_basis = grid.tv_basis()
loss_tv_basis = tv_basis * args.lambda_tv_basis
@@ -414,8 +436,12 @@
# Manual SGD/rmsprop step
grid.optim_density_step(lr_sigma, beta=args.rms_beta, optim=args.sigma_optim)
grid.optim_sh_step(lr_sh, beta=args.rms_beta, optim=args.sh_optim)
- if grid.use_learned_basis and gstep_id >= args.lr_basis_begin_step:
-     grid.optim_basis_step(lr_basis, beta=args.rms_beta, optim=args.basis_optim)
+ if gstep_id >= args.lr_basis_begin_step:
+     if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
+         grid.optim_basis_step(lr_basis, beta=args.rms_beta, optim=args.basis_optim)
+     elif grid.basis_type == svox2.BASIS_TYPE_MLP:
+         optim_basis_mlp.step()
+         optim_basis_mlp.zero_grad()

train_step()
gc.collect()
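
For context, a minimal sketch (not part of the commit) of how the new basis-type plumbing in opt.py fits together: the --basis_type string is mapped to an integer constant exported by svox2, and the per-step basis update is dispatched on grid.basis_type. The svox2 calls and BASIS_TYPE_* names are the ones used in the diff above; the helper functions themselves are hypothetical.

    import svox2

    def resolve_basis_type(name: str) -> int:
        # 'sh' -> svox2.BASIS_TYPE_SH, '3d_texture' -> svox2.BASIS_TYPE_3D_TEXTURE,
        # 'mlp' -> svox2.BASIS_TYPE_MLP (the same lookup opt.py performs).
        return svox2.__dict__['BASIS_TYPE_' + name.upper()]

    def basis_optim_step(grid, optim_basis_mlp, lr_basis, rms_beta, basis_optim):
        # One basis update per training step, dispatched as in train_step() above.
        if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
            # Fused manual SGD/RMSprop step implemented inside svox2.
            grid.optim_basis_step(lr_basis, beta=rms_beta, optim=basis_optim)
        elif grid.basis_type == svox2.BASIS_TYPE_MLP:
            # The MLP basis is an ordinary torch module, so a torch.optim
            # optimizer (Adam in opt.py) owns its parameters.
            optim_basis_mlp.step()
            optim_basis_mlp.zero_grad()
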
12 changes: 6 additions & 6 deletions opt/render_imgs.py
@@ -56,13 +56,13 @@
img_eval_interval = max(n_images // args.n_eval, 1)
avg_psnr = 0.0
n_images_gen = 0
+ cam = svox2.Camera(torch.tensor(0), dset.intrins.fx, dset.intrins.fy,
+                    dset.intrins.cx, dset.intrins.cy,
+                    dset.w, dset.h,
+                    ndc_coeffs=dset.ndc_coeffs)
+ c2ws = dset.render_c2w.to(device=device) if args.render_path else dset.c2w.to(device=device)
for img_id in tqdm(range(0, n_images, img_eval_interval)):
-     c2w = dset.render_c2w[img_id] if args.render_path else dset.c2w[img_id]
-     c2w = c2w.to(device=device)
-     cam = svox2.Camera(c2w, dset.intrins.fx, dset.intrins.fy,
-                        dset.intrins.cx, dset.intrins.cy,
-                        dset.w, dset.h,
-                        ndc_coeffs=dset.ndc_coeffs)
+     cam.c2w = c2ws[img_id]
im = grid.volume_render_image(cam, use_kernel=True)
im.clamp_(0.0, 1.0)
if not args.render_path:
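
The render_imgs.py change reads as a small refactor: the Camera is now constructed once, with a placeholder pose and fixed intrinsics, and only its camera-to-world matrix is swapped per frame instead of rebuilding the camera for every image. A rough sketch of that pattern follows, assuming cam.c2w is assignable as the new code implies; the intrinsics and poses are made-up values, and the render call is commented out because it needs a trained grid.

    import torch
    import svox2

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    fx = fy = 1111.0                                 # hypothetical pinhole intrinsics
    cx, cy, width, height = 400.0, 400.0, 800, 800

    # Build the camera once; the pose argument is a dummy, as in the diff.
    cam = svox2.Camera(torch.tensor(0), fx, fy, cx, cy, width, height,
                       ndc_coeffs=(-1.0, -1.0))      # assumed to mean "no NDC"

    c2ws = torch.eye(4, device=device).repeat(3, 1, 1)   # stand-in poses
    for i in range(c2ws.shape[0]):
        cam.c2w = c2ws[i]                            # only the pose changes per image
        # im = grid.volume_render_image(cam, use_kernel=True)
        # im.clamp_(0.0, 1.0)
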
16 changes: 9 additions & 7 deletions opt/tasks/sanity.json
@@ -1,17 +1,19 @@
{
"data_root": "/home/sxyu/data/nerf_synthetic/lego",
- "train_root": "/home/sxyu/proj/svox2/opt/ckpt_tune/fast_256_sweepup_and_lrcdec_10epoch",
+ "train_root": "/home/sxyu/proj/svox2/opt/ckpt_tune/mlp_sweep",
"variables": {
- "lr_sh_final": "loglin(1e-5, 5e-2, 10)",
- "upsamp_every": [3, 4, 6]
+ "lr_sh": "loglin(5e-3, 5e-2, 4)",
+ "lr_basis": "loglin(1e-6, 1e-3, 8)",
+ "upsamp_every": [25600, 38400, 51200],
+ "lambda_tv_sh": "loglin(1e-5, 1e-2, 6)"
},
"tasks": [{
- "train_dir": "up{upsamp_every}_lrc5e-2-d{lr_sh_final:01.7f}",
+ "train_dir": "lrc{lr_sh}_lrb{lr_basis}_tvc{lambda_tv_sh}_ups{upsamp_every}",
"flags": [
"--n_epochs", "10",
- "--init_reso", "256",
- "--upsamp_every", "{upsamp_every}",
- "--lr_sh_final", "{lr_sh_final}"
+ "--lr_sh", "{lr_sh}",
+ "--lr_basis", "{lr_basis}",
+ "--upsamp_every", "{upsamp_every}"
]
}]
}
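
The sweep file relies on the task launcher's "loglin(lo, hi, n)" strings, whose exact semantics are not shown in this commit. Assuming they expand to n log-uniformly spaced values between lo and hi, the expansion would behave like this purely illustrative helper:

    import numpy as np

    def loglin(lo: float, hi: float, n: int):
        # Assumed meaning of the "loglin(lo, hi, n)" strings in the task JSON:
        # n values evenly spaced on a log scale between lo and hi, inclusive.
        return np.geomspace(lo, hi, n).tolist()

    print(loglin(1e-6, 1e-3, 8))   # e.g. the candidate --lr_basis values
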
24 changes: 1 addition & 23 deletions opt/util/ff_dataset.py
@@ -29,29 +29,7 @@
from .load_llff import load_llff_data
from typing import Union


- def convert_to_ndc(origins, directions, ndc_coeffs, near: float = 1.0):
-     """Convert a set of rays to NDC coordinates."""
-     # Shift ray origins to near plane, not sure if needed
-     t = (near - origins[Ellipsis, 2]) / directions[Ellipsis, 2]
-     origins = origins + t[Ellipsis, None] * directions
-
-     dx, dy, dz = directions.unbind(-1)
-     ox, oy, oz = origins.unbind(-1)
-
-     # Projection
-     o0 = ndc_coeffs[0] * (ox / oz)
-     o1 = ndc_coeffs[1] * (oy / oz)
-     o2 = 1 - 2 * near / oz
-
-     d0 = ndc_coeffs[0] * (dx / dz - ox / oz)
-     d1 = ndc_coeffs[1] * (dy / dz - oy / oz)
-     d2 = 2 * near / oz;
-
-     origins = torch.stack([o0, o1, o2], -1)
-     directions = torch.stack([d0, d1, d2], -1)
-     return origins, directions
-
+ from svox2.utils import convert_to_ndc

class LLFFDataset(Dataset):
def __init__(
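
The local convert_to_ndc helper is dropped in favor of the shared copy in svox2.utils, so the NDC ray math lives in one place. A small usage sketch, assuming the shared function keeps the signature of the deleted copy (origins, directions, ndc_coeffs, near=1.0); the ray values and ndc_coeffs below are placeholders.

    import torch
    from svox2.utils import convert_to_ndc

    # Dummy rays with a nonzero z direction so the near-plane shift
    # t = (near - oz) / dz in the deleted code is well defined.
    origins = torch.zeros(4, 3)
    directions = torch.tensor([[0.0, 0.0, 1.0],
                               [0.1, 0.0, 1.0],
                               [0.0, 0.1, 1.0],
                               [0.1, 0.1, 1.0]])

    ndc_origins, ndc_dirs = convert_to_ndc(origins, directions, (1.0, 1.0))
    print(ndc_origins.shape, ndc_dirs.shape)   # torch.Size([4, 3]) twice
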
1 change: 1 addition & 0 deletions svox2/__init__.py
@@ -1,2 +1,3 @@
+ from .defs import *
from .svox2 import SparseGrid, Camera, Rays, RenderOptions
from .version import __version__
13 changes: 11 additions & 2 deletions svox2/csrc/include/data_spec.hpp
@@ -5,6 +5,16 @@

using torch::Tensor;

+ enum BasisType {
+     // For svox 1 compatibility
+     // BASIS_TYPE_RGBA = 0
+     BASIS_TYPE_SH = 1,
+     // BASIS_TYPE_SG = 2
+     // BASIS_TYPE_ASG = 3
+     BASIS_TYPE_3D_TEXTURE = 4,
+     BASIS_TYPE_MLP = 255,
+ };

struct SparseGridSpec {
Tensor density_data;
Tensor sh_data;
@@ -13,7 +23,7 @@ struct SparseGridSpec {
Tensor _scaling;

int basis_dim;
- bool use_learned_basis;
+ uint8_t basis_type;
Tensor basis_data;

inline void check() {
@@ -32,7 +42,6 @@
TORCH_CHECK(density_data.ndimension() == 2);
TORCH_CHECK(sh_data.ndimension() == 2);
TORCH_CHECK(links.ndimension() == 3);
- TORCH_CHECK(basis_data.ndimension() == 4);
}
};

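
The Python side refers to these constants as svox2.BASIS_TYPE_SH, svox2.BASIS_TYPE_3D_TEXTURE, and svox2.BASIS_TYPE_MLP (see opt.py above), presumably re-exported through the new "from .defs import *" in svox2/__init__.py. The contents of defs.py are not part of this diff; a hypothetical Python mirror of the enum would be:

    # Hypothetical svox2/defs.py mirror of the C++ enum above; only the numeric
    # values are taken from data_spec.hpp, and the names follow opt.py usage.
    BASIS_TYPE_SH = 1
    BASIS_TYPE_3D_TEXTURE = 4
    BASIS_TYPE_MLP = 255
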
4 changes: 2 additions & 2 deletions svox2/csrc/include/data_spec_packed.cuh
@@ -13,7 +13,7 @@ struct PackedSparseGridSpec {
density_data(spec.density_data.data_ptr<float>()),
sh_data(spec.sh_data.data_ptr<float>()),
links(spec.links.data_ptr<int32_t>()),
- use_learned_basis(spec.use_learned_basis),
+ basis_type(spec.basis_type),
basis_data(spec.basis_data.data_ptr<float>()),
size{(int)spec.links.size(0),
(int)spec.links.size(1),
@@ -33,7 +33,7 @@
float* __restrict__ density_data;
float* __restrict__ sh_data;
const int32_t* __restrict__ links;
- const bool use_learned_basis;
+ const uint8_t basis_type;
float* __restrict__ basis_data;

const int size[3], stride_x;
[Diffs for the remaining 9 of the 17 changed files did not load and are not shown here.]
