forked from sxyu/svox2
Showing 3 changed files with 349 additions and 0 deletions.
@@ -0,0 +1,24 @@
{
  "reso": "[[128, 128, 128], [256, 256, 256], [512, 512, 512], [640, 640, 640]]",
  "seq_id": 5075,
  "n_iters": 102400,
  "background_nlayers": 64,
  "background_reso": 1024,
  "upsamp_every": 25600,
  "near_clip": 0.35,
  "lr_sigma": 3e1,
  "lr_sh": 1e-2,
  "lr_sigma_delay_steps": 0,
  "lr_fg_begin_step": 1000,
  "thresh_type": "weight",
  "weight_thresh": 1.28,
  "lambda_tv": 5e-5,
  "lambda_tv_sh": 5e-3,
  "lambda_tv_background_sigma": 1e-3,
  "lambda_tv_background_color": 1e-3,
  "lambda_beta": 1e-5,
  "lambda_sparsity": 1e-11,
  "background_brightness": 0.5,
  "tv_early_only": 0,
  "tv_decay": 0.5
}
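The render script in the next file consumes a config like this through `config_util.maybe_merge_config_file(args, allow_invalid=True)`. As a rough sketch of what that merge amounts to (the real logic lives in `util/config_util.py` and is not reproduced here), each JSON key simply overrides the corresponding argparse default:

```python
# Hedged sketch: approximates maybe_merge_config_file, not the actual implementation.
import json
import argparse

def merge_config(args: argparse.Namespace, config_path: str) -> argparse.Namespace:
    with open(config_path, 'r') as f:
        overrides = json.load(f)
    for key, value in overrides.items():
        # JSON values override parser defaults; note "reso" stays a string here
        # and is parsed into a list of grid resolutions later in training code.
        setattr(args, key, value)
    return args
```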
@@ -0,0 +1,247 @@
# Copyright 2021 Alex Yu
# Render 360 circle path

import torch
import svox2
import svox2.utils
import math
import argparse
import numpy as np
import os
from os import path
from util.dataset import datasets
from util.util import Timing, compute_ssim, viridis_cmap, pose_spherical
from util import config_util

import imageio
import cv2
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('ckpt', type=str)

config_util.define_common_args(parser)

parser.add_argument('--n_eval', '-n', type=int, default=100000,
                    help='number of images to evaluate (at equal intervals); at most every image')
parser.add_argument('--traj_type',
                    choices=['spiral', 'circle'],
                    default='spiral',
                    help="Render a spiral (doubles length, using 2 elevations) or just a circle")
parser.add_argument('--fps',
                    type=int,
                    default=30,
                    help="FPS of video")
parser.add_argument(
    "--width", "-W", type=float, default=None, help="Rendering image width (only if not --traj)"
)
parser.add_argument(
    "--height", "-H", type=float, default=None, help="Rendering image height (only if not --traj)"
)
parser.add_argument(
    "--num_views", "-N", type=int, default=600,
    help="Number of frames to render"
)

# Path adjustment
parser.add_argument(
    "--offset", type=str, default="0,0,0", help="Center point to rotate around (only if not --traj)"
)
parser.add_argument("--radius", type=float, default=0.85, help="Radius of orbit (only if not --traj)")
parser.add_argument(
    "--elevation",
    type=float,
    default=-45.0,
    help="Elevation of orbit in deg, negative is above",
)
parser.add_argument(
    "--elevation2",
    type=float,
    default=-12.0,
    help="Max elevation, only for spiral",
)
parser.add_argument(
    "--vec_up",
    type=str,
    default=None,
    help="Up axis for camera views (only if not --traj); "
         "3 floats separated by ','; if not given, determined automatically",
)
parser.add_argument(
    "--vert_shift",
    type=float,
    default=0.0,
    help="Vertical shift along the up axis",
)

# Camera adjustment
parser.add_argument('--crop',
                    type=float,
                    default=1.0,
                    help="Crop (0, 1], 1.0 = full image")

# Foreground/background only
parser.add_argument('--nofg',
                    action='store_true',
                    default=False,
                    help="Do not render foreground (if using BG model)")
parser.add_argument('--nobg',
                    action='store_true',
                    default=False,
                    help="Do not render background (if using BG model)")

# Random debugging features
parser.add_argument('--blackbg',
                    action='store_true',
                    default=False,
                    help="Force a black BG (behind BG model) color; useful for debugging 'clouds'")

args = parser.parse_args()
config_util.maybe_merge_config_file(args, allow_invalid=True)
device = 'cuda:0'

dset = datasets[args.dataset_type](args.data_dir, split="test",
                                   **config_util.build_data_options(args))

if args.vec_up is None:
    # Estimate the up axis by averaging the negated y-axes of the dataset poses
    up_rot = dset.c2w[:, :3, :3].cpu().numpy()
    ups = np.matmul(up_rot, np.array([0, -1.0, 0])[None, :, None])[..., 0]
    args.vec_up = np.mean(ups, axis=0)
    args.vec_up /= np.linalg.norm(args.vec_up)
    print(' Auto vec_up', args.vec_up)
else:
    args.vec_up = np.array(list(map(float, args.vec_up.split(","))))

args.offset = np.array(list(map(float, args.offset.split(","))))
if args.traj_type == 'spiral':
    # Spiral: sweep azimuth once while elevation moves from --elevation to
    # --elevation2, then sweep again while it returns, doubling the frame count
    angles = np.linspace(-180, 180, args.num_views + 1)[:-1]
    elevations = np.linspace(args.elevation, args.elevation2, args.num_views)
    c2ws = [
        pose_spherical(
            angle,
            ele,
            args.radius,
            args.offset,
            vec_up=args.vec_up,
        )
        for ele, angle in zip(elevations, angles)
    ]
    c2ws += [
        pose_spherical(
            angle,
            ele,
            args.radius,
            args.offset,
            vec_up=args.vec_up,
        )
        for ele, angle in zip(reversed(elevations), angles)
    ]
else:
    # Circle: fixed elevation, one full azimuth turn
    c2ws = [
        pose_spherical(
            angle,
            args.elevation,
            args.radius,
            args.offset,
            vec_up=args.vec_up,
        )
        for angle in np.linspace(-180, 180, args.num_views + 1)[:-1]
    ]
c2ws = np.stack(c2ws, axis=0)
if args.vert_shift != 0.0:
    c2ws[:, :3, 3] += np.array(args.vec_up) * args.vert_shift
c2ws = torch.from_numpy(c2ws).to(device=device)

if not path.isfile(args.ckpt):
    args.ckpt = path.join(args.ckpt, 'ckpt.npz')

render_out_path = path.join(path.dirname(args.ckpt), 'circle_renders')

# Handle various image transforms
if args.crop != 1.0:
    render_out_path += f'_crop{args.crop}'
if args.vert_shift != 0.0:
    render_out_path += f'_vshift{args.vert_shift}'

grid = svox2.SparseGrid.load(args.ckpt, device=device)
print(grid.center, grid.radius)

# DEBUG
# grid.background_data.data[:, 32:, -1] = 0.0
# render_out_path += '_front'

if grid.use_background:
    if args.nobg:
        # Zero the background alpha channel to hide the background model
        grid.background_data.data[..., -1] = 0.0
        render_out_path += '_nobg'
    if args.nofg:
        # Zero all densities to hide the foreground grid
        grid.density_data.data[:] = 0.0
        # grid.sh_data.data[..., 0] = 1.0 / svox2.utils.SH_C0
        # grid.sh_data.data[..., 9] = 1.0 / svox2.utils.SH_C0
        # grid.sh_data.data[..., 18] = 1.0 / svox2.utils.SH_C0
        render_out_path += '_nofg'

# # DEBUG
# grid.background_data.data[..., -1] = 100.0
# a1 = torch.linspace(0, 1, grid.background_data.size(0) // 2, dtype=torch.float32, device=device)[:, None]
# a2 = torch.linspace(1, 0, (grid.background_data.size(0) - 1) // 2 + 1, dtype=torch.float32, device=device)[:, None]
# a = torch.cat([a1, a2], dim=0)
# c = torch.stack([a, 1-a, torch.zeros_like(a)], dim=-1)
# grid.background_data.data[..., :-1] = c
# render_out_path += "_gradient"

config_util.setup_render_opts(grid.opt, args)

if args.blackbg:
    print('Forcing black bg')
    render_out_path += '_blackbg'
    grid.opt.background_brightness = 0.0

render_out_path += '.mp4'
print('Writing to', render_out_path)

# NOTE: no_grad enables the fast image-level rendering kernel for the cuvol backend only;
# other backends will manually generate rays per frame (slow)
with torch.no_grad():
    n_images = c2ws.size(0)
    img_eval_interval = max(n_images // args.n_eval, 1)
    avg_psnr = 0.0
    avg_ssim = 0.0
    avg_lpips = 0.0
    n_images_gen = 0
    frames = []
    # if args.near_clip >= 0.0:
    grid.opt.near_clip = 0.0  # args.near_clip
    if args.width is None:
        args.width = dset.get_image_size(0)[1]
    if args.height is None:
        args.height = dset.get_image_size(0)[0]

    for img_id in tqdm(range(0, n_images, img_eval_interval)):
        dset_h, dset_w = args.height, args.width
        im_size = dset_h * dset_w
        w = dset_w if args.crop == 1.0 else int(dset_w * args.crop)
        h = dset_h if args.crop == 1.0 else int(dset_h * args.crop)

        # Principal point at the image center; ndc_coeffs=(-1, -1) disables NDC
        cam = svox2.Camera(c2ws[img_id],
                           dset.intrins.get('fx', 0),
                           dset.intrins.get('fy', 0),
                           w * 0.5,
                           h * 0.5,
                           w, h,
                           ndc_coeffs=(-1.0, -1.0))
        torch.cuda.synchronize()
        im = grid.volume_render_image(cam, use_kernel=True)
        torch.cuda.synchronize()
        im.clamp_(0.0, 1.0)

        im = im.cpu().numpy()
        im = (im * 255).astype(np.uint8)
        frames.append(im)
        im = None
        n_images_gen += 1
    if len(frames):
        vid_path = render_out_path
        imageio.mimwrite(vid_path, frames, fps=args.fps, macro_block_size=8)  # pip install imageio-ffmpeg
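To render one frame with the same APIs the script uses, the calls can be composed directly. This is a minimal sketch under stated assumptions: the checkpoint path, image size, and focal length are hypothetical placeholders, and `pose_spherical` is assumed to return a 4x4 OpenCV-convention camera-to-world matrix, as its usage above suggests:

```python
# Minimal single-frame sketch; ckpt path and intrinsics below are hypothetical.
import numpy as np
import torch
import svox2
from util.util import pose_spherical

device = 'cuda:0'
grid = svox2.SparseGrid.load('ckpt.npz', device=device)  # hypothetical path

# Assumed to return a 4x4 c2w matrix (see the trajectory loop above)
c2w = pose_spherical(
    0.0,                        # azimuth angle in degrees
    -45.0,                      # elevation in degrees (negative is above)
    0.85,                       # orbit radius
    np.array([0.0, 0.0, 0.0]),  # center offset
    vec_up=np.array([0.0, -1.0, 0.0]),
)
c2w = torch.as_tensor(np.asarray(c2w), dtype=torch.float32, device=device)

H = W = 800        # hypothetical image size
fx = fy = 1111.0   # hypothetical focal lengths
cam = svox2.Camera(c2w, fx, fy, W * 0.5, H * 0.5, W, H,
                   ndc_coeffs=(-1.0, -1.0))
with torch.no_grad():
    im = grid.volume_render_image(cam, use_kernel=True).clamp_(0.0, 1.0)
frame = (im.cpu().numpy() * 255).astype(np.uint8)  # H x W x 3 uint8 frame
```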
@@ -0,0 +1,78 @@
import torch
import torch.nn.functional as F
from typing import Union, Optional, List
from .util import select_or_shuffle_rays, Rays, Intrin

class DatasetBase:
    split: str
    permutation: bool
    epoch_size: Optional[int]
    n_images: int
    h_full: int
    w_full: int
    intrins_full: Intrin
    c2w: torch.Tensor  # C2W OpenCV poses
    gt: Union[torch.Tensor, List[torch.Tensor]]  # RGB images
    device: Union[str, torch.device]

    def __init__(self):
        self.ndc_coeffs = (-1, -1)
        self.use_sphere_bound = True
        self.should_use_background = True  # a hint
        self.scene_center = [0.0, 0.0, 0.0]
        self.scene_radius = [1.0, 1.0, 1.0]
        self.permutation = False

    def shuffle_rays(self):
        """
        Shuffle all rays
        """
        if self.split == "train":
            del self.rays
            self.rays = select_or_shuffle_rays(self.rays_init, self.permutation,
                                               self.epoch_size, self.device)

    def gen_rays(self, factor=1):
        print(" Generating rays, scaling factor", factor)
        # Generate rays
        self.factor = factor
        self.h = self.h_full // factor
        self.w = self.w_full // factor
        true_factor = self.h_full / self.h
        self.intrins = self.intrins_full.scale(1.0 / true_factor)
        # Pixel-center grid in image space
        yy, xx = torch.meshgrid(
            torch.arange(self.h, dtype=torch.float32) + 0.5,
            torch.arange(self.w, dtype=torch.float32) + 0.5,
        )
        xx = (xx - self.intrins.cx) / self.intrins.fx
        yy = (yy - self.intrins.cy) / self.intrins.fy
        zz = torch.ones_like(xx)
        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
        dirs = dirs.reshape(1, -1, 3, 1)
        del xx, yy, zz
        # Rotate camera-space directions into world space
        dirs = (self.c2w[:, None, :3, :3] @ dirs)[..., 0]

        if factor != 1:
            gt = F.interpolate(
                self.gt.permute([0, 3, 1, 2]), size=(self.h, self.w), mode="area"
            ).permute([0, 2, 3, 1])
            gt = gt.reshape(self.n_images, -1, 3)
        else:
            gt = self.gt.reshape(self.n_images, -1, 3)
        origins = self.c2w[:, None, :3, 3].expand(-1, self.h * self.w, -1).contiguous()
        if self.split == "train":
            # Flatten across images so training rays can be shuffled globally
            origins = origins.view(-1, 3)
            dirs = dirs.view(-1, 3)
            gt = gt.reshape(-1, 3)

        self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt)
        self.rays = self.rays_init

    def get_image_size(self, i: int):
        # Returns (H, W)
        if hasattr(self, 'image_size'):
            return tuple(self.image_size[i])
        else:
            return self.h, self.w
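To make the contract of `DatasetBase` concrete, here is a hypothetical toy subclass that wires up just the attributes `gen_rays` reads. The `Intrin(fx, fy, cx, cy)` constructor order is an assumption inferred from the attribute accesses above; real datasets in `util/dataset.py` load images and poses from disk instead:

```python
# Hypothetical toy subclass for illustration only.
# Assumes Intrin(fx, fy, cx, cy); random tensors stand in for images.
import torch

class ToyDataset(DatasetBase):
    def __init__(self, n_images=2, h=8, w=8, device='cpu'):
        super().__init__()
        self.split = 'train'
        self.device = device
        self.epoch_size = None
        self.n_images = n_images
        self.h_full, self.w_full = h, w
        self.intrins_full = Intrin(w, w, w * 0.5, h * 0.5)  # assumed arg order
        self.c2w = torch.eye(4).unsqueeze(0).repeat(n_images, 1, 1)  # identity poses
        self.gt = torch.rand(n_images, h, w, 3)  # stand-in RGB images in [0, 1]
        self.gen_rays(factor=1)  # builds self.rays / self.rays_init

# Usage: dset = ToyDataset(); dset.shuffle_rays()  # reshuffles training rays
```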