Skip to content

Commit

Permalink
Add SSIM
Browse files (browse the repository at this point in the history)
  • Loading branch information
sxyu committed Nov 7, 2021
1 parent 620cc62 commit 05d0c14
Show file tree
Hide file tree
Showing 9 changed files with 254 additions and 56 deletions.
68 changes: 46 additions & 22 deletions opt/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from util.util import Timing, get_expon_lr_func, generate_dirs_equirect, viridis_cmap
from util import config_util

from warnings import warn
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm
Expand All @@ -41,18 +42,20 @@
type=str,
default=
# "[[128, 128, 128], [256, 256, 256], [512, 512, 512], [768, 768, 768]]",
"[[256, 256, 256], [512, 512, 512], [768, 768, 768]]",
"[[128, 128, 128], [256, 256, 256], [512, 512, 512]]",
help='List of grid resolution (will be evaled as json);'
'resamples to the next one every upsamp_every iters, then ' +
'stays at the last one; ' +
'should be a list where each item is a list of 3 ints or an int')
group.add_argument('--upsamp_every', type=int, default=
2 * 12800,
# 3 * 12800,
help='upsample the grid every x iters')
group.add_argument('--upsample_density_factor', type=float, default=
1.0,
help='multiply the remaining density by this amount when upsampling')
group.add_argument('--init_iters', type=int, default=
0, #-12800,
help='do not upsample for first x iters')
group.add_argument('--upsample_density_add', type=float, default=
0.0,
help='add the remaining density by this amount when upsampling')

group.add_argument('--basis_type',
choices=['sh', '3d_texture', 'mlp'],
Expand All @@ -73,7 +76,7 @@


group = parser.add_argument_group("optimization")
group.add_argument('--n_iters', type=int, default=20 * 12800, help='total number of iters to optimize for')
group.add_argument('--n_iters', type=int, default=10 * 12800, help='total number of iters to optimize for')
group.add_argument('--batch_size', type=int, default=
5000,
#100000,
Expand All @@ -83,13 +86,14 @@

# TODO: make the lr higher near the end
group.add_argument('--sigma_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Density optimizer")
group.add_argument('--lr_sigma', type=float, default=3e1,
group.add_argument('--lr_sigma', type=float, default=
3e1,
help='SGD/rmsprop lr for sigma')
group.add_argument('--lr_sigma_final', type=float, default=5e-2)
group.add_argument('--lr_sigma_decay_steps', type=int, default=250000)
group.add_argument('--lr_sigma_delay_steps', type=int, default=15000,
help="Reverse cosine steps (0 means disable)")
group.add_argument('--lr_sigma_delay_mult', type=float, default=1e-2)
group.add_argument('--lr_sigma_delay_mult', type=float, default=1e-2)#1e-4)#1e-4)


group.add_argument('--sh_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="SH optimizer")
Expand All @@ -98,7 +102,7 @@
help='SGD/rmsprop lr for SH')
group.add_argument('--lr_sh_final', type=float,
default=
5e-5
5e-6
)
group.add_argument('--lr_sh_decay_steps', type=int, default=250000)
group.add_argument('--lr_sh_delay_steps', type=int, default=0, help="Reverse cosine steps (0 means disable)")
Expand Down Expand Up @@ -141,8 +145,10 @@

group = parser.add_argument_group("misc experiments")
group.add_argument('--weight_thresh', type=float,
default=0.0005,
help='Resample (upsample to 512) weight threshold')
# default=0.0005,
# default=0.005 * 256,
default=0.0005 * 512,
help='Upsample weight threshold; will be divided by resulting z-resolution')

group.add_argument('--tune_mode', action='store_true', default=False,
help='hypertuning mode (do not save, for speed)')
Expand All @@ -166,6 +172,16 @@
group.add_argument('--lambda_l2_sh', type=float, default=0.0)#1e-4)
# End Foreground TV

group.add_argument('--lambda_sparsity', type=float, default=
0.0,
# 1e-11,
help="Weight for sparsity loss as in SNeRG/PlenOctrees " +
"(but applied on the ray)")
group.add_argument('--lambda_beta', type=float, default=
0.0,
# 1e-5,
help="Weight for beta distribution sparsity loss as in neural volumes")


# Background TV
group.add_argument('--lambda_tv_background_sigma', type=float, default=1e-5)
Expand All @@ -174,8 +190,10 @@
group.add_argument('--tv_background_sparsity', type=float, default=0.01)
# End Background TV

# Basis TV
group.add_argument('--lambda_tv_basis', type=float, default=0.0,
help='Learned basis total variation loss')
# End Basis TV

group.add_argument('--weight_decay_sigma', type=float, default=1.0)
group.add_argument('--weight_decay_sh', type=float, default=1.0)
Expand Down Expand Up @@ -213,7 +231,7 @@
**config_util.build_data_options(args))

if args.background_nlayers > 0 and not dset.should_use_background:
print('Using a background model for dataset type ', type(dset), 'which typically does not use background')
warn('Using a background model for dataset type ' + str(type(dset)) + ' which typically does not use background')

dset_test = datasets[args.dataset_type](
args.data_dir, split="test", **config_util.build_data_options(args))
Expand Down Expand Up @@ -297,7 +315,7 @@
lr_sh_factor = 1.0
lr_basis_factor = 1.0

last_upsamp_step = 0
last_upsamp_step = args.init_iters

epoch_id = -1
while True:
Expand Down Expand Up @@ -354,7 +372,8 @@ def eval_step():
mse_num : float = all_mses.mean().item()
psnr = -10.0 * math.log10(mse_num)
if math.isnan(psnr):
print('NAN PSNR', i, img_id)
print('NAN PSNR', i, img_id, mse_num)
assert False
stats_test['mse'] += mse_num
stats_test['psnr'] += psnr
n_images_gen += 1
Expand Down Expand Up @@ -419,7 +438,9 @@ def train_step():
rays = svox2.Rays(batch_origins, batch_dirs)

# with Timing("volrend_fused"):
rgb_pred = grid.volume_render_fused(rays, rgb_gt)
rgb_pred = grid.volume_render_fused(rays, rgb_gt,
beta_loss=args.lambda_beta,
sparsity_loss=args.lambda_sparsity)

# with Timing("loss_comp"):
mse = F.mse_loss(rgb_gt, rgb_pred)
Expand Down Expand Up @@ -457,29 +478,31 @@ def train_step():
grid.sh_data.data *= args.weight_decay_sigma
if args.weight_decay_sigma < 1.0:
grid.density_data.data *= args.weight_decay_sh
torch.cuda.synchronize() # FIXME remove
# torch.cuda.synchronize() # FIXME remove

# Apply TV/Sparsity regularizers
if args.lambda_tv > 0.0:
# with Timing("tv_inpl"):
grid.inplace_tv_grad(grid.density_data.grad,
scaling=args.lambda_tv,
sparse_frac=args.tv_sparsity,
logalpha=args.tv_logalpha,
ndc_coeffs=dset.ndc_coeffs)
torch.cuda.synchronize() # FIXME remove
# torch.cuda.synchronize() # FIXME remove
if args.lambda_tv_sh > 0.0:
# with Timing("tv_color_inpl"):
grid.inplace_tv_color_grad(grid.sh_data.grad,
scaling=args.lambda_tv_sh,
sparse_frac=args.tv_sh_sparsity,
ndc_coeffs=dset.ndc_coeffs)
torch.cuda.synchronize() # FIXME remove
# torch.cuda.synchronize() # FIXME remove
if args.lambda_tv_lumisphere > 0.0:
grid.inplace_tv_lumisphere_grad(grid.sh_data.grad,
scaling=args.lambda_tv_lumisphere,
dir_factor=args.tv_lumisphere_dir_factor,
sparse_frac=args.tv_lumisphere_sparsity,
ndc_coeffs=dset.ndc_coeffs)
torch.cuda.synchronize() # FIXME remove
# torch.cuda.synchronize() # FIXME remove
if args.lambda_l2_sh > 0.0:
grid.inplace_l2_color_grad(grid.sh_data.grad,
scaling=args.lambda_l2_sh)
Expand Down Expand Up @@ -524,14 +547,15 @@ def train_step():
print('* Upsampling from', reso_list[reso_id], 'to', reso_list[reso_id + 1])
reso_id += 1
use_sparsify = True
z_reso = reso_list[reso_id] if isinstance(reso_list[reso_id], int) else reso_list[reso_id][2]
grid.resample(reso=reso_list[reso_id],
# sigma_thresh=args.sigma_thresh if use_sparsify else 0.0,
weight_thresh=args.weight_thresh if use_sparsify else 0.0,
weight_thresh=args.weight_thresh / z_reso if use_sparsify else 0.0,
dilate=2, #use_sparsify,
cameras=resample_cameras)

if args.upsample_density_factor:
grid.density_data.data[:] *= args.upsample_density_factor
if args.upsample_density_add:
grid.density_data.data[:] += args.upsample_density_add
if args.lr_sh_upscale_factor > 1:
lr_sh_factor *= args.lr_sh_upscale_factor
print('Increased lr to (sigma:)', args.lr_sigma, '(sh:)', args.lr_sh)
Expand Down
26 changes: 21 additions & 5 deletions opt/render_imgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
from os import path
from util.dataset import datasets
from util.util import Timing
from util.util import Timing, compute_ssim
from util import config_util

import imageio
Expand All @@ -35,6 +35,10 @@
action='store_true',
default=False,
help="Do not render background (if using BG model)")
parser.add_argument('--blackbg',
action='store_true',
default=False,
help="Force a black BG (behind BG model) color; useful for debugging 'clouds'")
args = parser.parse_args()
config_util.maybe_merge_config_file(args)
device = 'cuda:0'
Expand All @@ -60,20 +64,27 @@
grid.density_data.data[:] = 0.0
render_dir += '_nofg'

print('Writing to', render_dir)
os.makedirs(render_dir, exist_ok=True)

grid.opt.step_size = args.step_size
grid.opt.sigma_thresh = args.sigma_thresh
grid.opt.stop_thresh = args.stop_thresh
grid.opt.background_brightness = 1.0
grid.opt.backend = args.renderer_backend

if args.blackbg:
print('Using black bg')
render_dir += '_blackbg'
grid.opt.background_brightness = 0.0

print('Writing to', render_dir)
os.makedirs(render_dir, exist_ok=True)

with torch.no_grad():
im_size = dset.h * dset.w
n_images = dset.render_c2w.size(0) if args.render_path else dset.n_images
img_eval_interval = max(n_images // args.n_eval, 1)
avg_psnr = 0.0
avg_ssim = 0.0
n_images_gen = 0
cam = svox2.Camera(torch.tensor(0), dset.intrins.fx, dset.intrins.fy,
dset.intrins.cx, dset.intrins.cy,
Expand All @@ -89,8 +100,10 @@
mse = (im - im_gt) ** 2
mse_num : float = mse.mean().item()
psnr = -10.0 * math.log10(mse_num)
ssim = compute_ssim(im_gt, im)
avg_psnr += psnr
print(img_id, 'PSNR', psnr)
avg_ssim += ssim
print(img_id, 'PSNR', psnr, 'SSIM', ssim)
# all_rgbs = []
# all_mses = []
# for batch_begin in range(0, im_size, args.eval_batch_size):
Expand All @@ -113,6 +126,9 @@
im = None
n_images_gen += 1
avg_psnr /= n_images_gen
print('average PSNR', avg_psnr)
avg_ssim /= n_images_gen
print('average PSNR', avg_psnr, 'SSIM', avg_ssim)
with open(path.join(render_dir, 'psnr.txt'), 'w') as f:
f.write(str(avg_psnr))
with open(path.join(render_dir, 'ssim.txt'), 'w') as f:
f.write(str(avg_ssim))
20 changes: 9 additions & 11 deletions opt/tasks/sanity.json
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
{
"data_root": "/home/sxyu/data/nerf_synthetic/lego",
"train_root": "/home/sxyu/proj/svox2/opt/ckpt_tune/mlp_sweep",
"data_root": "/home/sxyu/data/nerf_synthetic/ship",
"train_root": "/home/sxyu/proj/svox2/opt/ckpt_tune/ship_sweep",
"variables": {
"lr_sh": "loglin(5e-3, 5e-2, 4)",
"lr_basis": "loglin(1e-6, 1e-3, 8)",
"upsamp_every": [25600, 38400, 51200],
"lambda_tv_sh": "loglin(1e-5, 1e-2, 6)"
"lr_sh_final": "loglin(5e-7, 5e-2, 10)",
"lr_sigma": "loglin(5e0, 2e2, 4)",
"lr_sigma_delay_steps": [25000, 40000, 55000]
},
"tasks": [{
"train_dir": "lrc{lr_sh}_lrb{lr_basis}_tvc{lambda_tv_sh}_ups{upsamp_every}",
"train_dir": "lrcf{lr_sh_final}_lrs{lr_sigma}_del{lr_sigma_delay_steps}",
"flags": [
"--upsamp_every", "{upsamp_every}",
"--lr_sh", "{lr_sh}",
"--lr_basis", "{lr_basis}",
"--upsamp_every", "{upsamp_every}"
"--lr_sh_final", "{lr_sh_final}",
"--lr_sigma", "{lr_sigma}",
"--lr_sigma_delay_steps", "{lr_sigma_delay_steps}"
]
}]
}
Loading

0 comments on commit 05d0c14

Please sign in to comment.