Skip to content

Commit

Permalink
Print averages in autotune; added a spline interpolation util
Browse files Browse the repository at this point in the history
  • Loading branch information
sxyu committed Nov 14, 2021
1 parent 3df92a4 commit f481fa5
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 5 deletions.
13 changes: 13 additions & 0 deletions opt/autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ def recursive_replace(data, variables):
with open(leaderboard_path, 'w') as leaderboard_file:
lines = [f'dir\tPSNR\tSSIM\tLPIPS\n']
all_tasks = sorted(all_tasks, key=lambda task:task['train_dir'])
all_psnr = []
all_ssim = []
all_lpips = []
for task in all_tasks:
train_dir = task['train_dir']
psnr_file_path = path.join(train_dir, 'test_renders', 'psnr.txt')
Expand All @@ -260,23 +263,33 @@ def recursive_replace(data, variables):
if path.isfile(psnr_file_path):
with open(psnr_file_path, 'r') as f:
psnr = float(f.read())
all_psnr.append(psnr)
psnr_txt = f'{psnr:.10f}'
else:
psnr_txt = 'ERR'
if path.isfile(ssim_file_path):
with open(ssim_file_path, 'r') as f:
ssim = float(f.read())
all_ssim.append(ssim)
ssim_txt = f'{ssim:.10f}'
else:
ssim_txt = 'ERR'
if path.isfile(lpips_file_path):
with open(lpips_file_path, 'r') as f:
lpips = float(f.read())
all_lpips.append(lpips)
lpips_txt = f'{lpips:.10f}'
else:
lpips_txt = 'ERR'
line = f'{path.basename(train_dir.rstrip("/"))}\t{psnr_txt}\t{ssim_txt}\t{lpips_txt}\n'
lines.append(line)
lines.append('---------\n')
if len(all_psnr):
lines.append('Average PSNR: ' + str(sum(all_psnr) / len(all_psnr)) + '\n')
if len(all_ssim):
lines.append('Average SSIM: ' + str(sum(all_ssim) / len(all_ssim)) + '\n')
if len(all_lpips):
lines.append('Average LPIPS: ' + str(sum(all_lpips) / len(all_lpips)) + '\n')
leaderboard_file.writelines(lines)

else:
Expand Down
3 changes: 2 additions & 1 deletion opt/calc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
psnr = -10.0 * math.log10(mse_num)
ssim = compute_ssim(im_gt, im).item()
lpips_i = lpips_vgg(im_gt.permute([2, 0, 1]).cuda().contiguous(),
im.permute([2, 0, 1]).cuda().contiguous(), normalize=True).item()
im.permute([2, 0, 1]).cuda().contiguous(),
normalize=True).item()

print(i, 'of', len(im_files), '; PSNR', psnr, 'SSIM', ssim, 'LPIPS', lpips_i)
avg_psnr += psnr
Expand Down
4 changes: 2 additions & 2 deletions opt/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@
default="weight",
help='Upsample threshold type')
group.add_argument('--weight_thresh', type=float,
# default=0.0005 * 512,
default=0.025 * 512,
default=0.0005 * 512,
# default=0.025 * 512,
help='Upsample weight threshold; will be divided by resulting z-resolution')
group.add_argument('--density_thresh', type=float,
default=5.0,
Expand Down
2 changes: 1 addition & 1 deletion opt/tasks/eval_tnt.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"eval": true,
"data_root": "/home/sxyu/data/TanksAndTempleBG",
"train_root": "/home/sxyu/proj/svox2/opt/ckpt_auto/tnt_spars",
"train_root": "/home/sxyu/proj/svox2/opt/ckpt_auto/tnt_spars_hdbg_var",
"config": "configs/tnt.json",
"tasks": [{
"train_dir": "Train",
Expand Down
3 changes: 2 additions & 1 deletion opt/util/nsvf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
scale : Optional[float] = 1.0, # Image scaling (on load)
permutation: bool = True,
white_bkgd: bool = True,
normalize_by_bbox: bool = True,
normalize_by_bbox: bool = False,
data_bbox_scale : float = 1.1,
**kwargs
):
Expand Down Expand Up @@ -126,6 +126,7 @@ def look_for_dir(cands, required=True):
self.c2w_f64 = torch.stack(all_c2w)

if normalize_by_bbox:
# Not used, but could be helpful
bbox_path = path.join(root, "bbox.txt")
if path.exists(bbox_path):
bbox_data = np.loadtxt(bbox_path)
Expand Down
40 changes: 40 additions & 0 deletions opt/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dataclasses import dataclass
import numpy as np
import cv2
from scipy.spatial.transform import Rotation
from scipy.interpolate import CubicSpline
from matplotlib import pyplot as plt
from warnings import warn

Expand Down Expand Up @@ -308,3 +310,41 @@ def generate_rays(w, h, focal, camtoworlds, equirect=False):
origins=origins, directions=directions, viewdirs=viewdirs
)
return rays


def jiggle_and_interp_poses(poses : torch.Tensor,
n_inter: int,
noise_std : float=0.0):
"""
For generating a novel trajectory close to known trajectory
:param poses: torch.Tensor (B, 4, 4)
:param n_inter: int, number of views to interpolate in total
:param noise_std: float, default 0
"""
n_views_in = poses.size(0)
poses_np = poses.cpu().numpy().copy()
rot = Rotation.from_matrix(poses_np[:, :3, :3])
trans = poses_np[:, :3, 3]
trans += np.random.randn(*trans.shape) * noise_std
pose_quat = rot.as_quat()

t_in = np.arange(n_views_in, dtype=np.float32)
t_out = np.linspace(t_in[0], t_in[-1], n_inter, dtype=np.float32)

q_new = CubicSpline(t_in, pose_quat)
q_new : np.ndarray = q_new(t_out)
q_new = q_new / np.linalg.norm(q_new, axis=-1)[..., None]

t_new = CubicSpline(t_in, trans)
t_new = t_new(t_out)

rot_new = Rotation.from_quat(q_new)
R_new = rot_new.as_matrix()

Rt_new = np.concatenate([R_new, t_new[..., None]], axis=-1)
bottom = np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)
bottom = bottom[None].repeat(Rt_new.shape[0], 0)
Rt_new = np.concatenate([Rt_new, bottom], axis=-2)
Rt_new = torch.from_numpy(Rt_new).to(device=poses.device, dtype=poses.dtype)
return Rt_new

0 comments on commit f481fa5

Please sign in to comment.