Skip to content

Commit

Permalink
Added missing files
Browse files Browse the repository at this point in the history
  • Loading branch information
sxyu committed Dec 9, 2021
1 parent c42c46c commit d296bc9
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 0 deletions.
24 changes: 24 additions & 0 deletions opt/configs/co3d.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"reso": "[[128, 128, 128], [256, 256, 256], [512, 512, 512], [640, 640, 640]]",
"seq_id": 5075,
"n_iters": 102400,
"background_nlayers": 64,
"background_reso": 1024,
"upsamp_every": 25600,
"near_clip": 0.35,
"lr_sigma": 3e1,
"lr_sh": 1e-2,
"lr_sigma_delay_steps": 0,
"lr_fg_begin_step": 1000,
"thresh_type": "weight",
"weight_thresh": 1.28,
"lambda_tv": 5e-5,
"lambda_tv_sh": 5e-3,
"lambda_tv_background_sigma": 1e-3,
"lambda_tv_background_color": 1e-3,
"lambda_beta": 1e-5,
"lambda_sparsity": 1e-11,
"background_brightness": 0.5,
"tv_early_only": 0,
"tv_decay": 0.5
}
247 changes: 247 additions & 0 deletions opt/render_imgs_circle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
# Copyright 2021 Alex Yu
# Render 360 circle path

import torch
import svox2
import svox2.utils
import math
import argparse
import numpy as np
import os
from os import path
from util.dataset import datasets
from util.util import Timing, compute_ssim, viridis_cmap, pose_spherical
from util import config_util

import imageio
import cv2
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Command-line interface.
# NOTE(review): several original help strings referred to a "--traj" option
# that does not exist in this script (stale copy from a sibling render
# script); those references have been removed.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('ckpt', type=str)

# Shared dataset/rendering options (project helper).
config_util.define_common_args(parser)

parser.add_argument('--n_eval', '-n', type=int, default=100000,
                    help='images to evaluate (equal interval), at most evals every image')
parser.add_argument('--traj_type',
                    choices=['spiral', 'circle'],
                    default='spiral',
                    # FIX: typo "cirle" -> "circle"
                    help="Render a spiral (doubles length, using 2 elevations), or just a circle")
parser.add_argument('--fps',
                    type=int,
                    default=30,
                    help="FPS of video")
# FIX: width/height are pixel counts (defaults come from
# dset.get_image_size(0) and are passed as integer sizes to svox2.Camera);
# they were declared type=float.
parser.add_argument(
    "--width", "-W", type=int, default=None,
    help="Rendering image width (default: dataset image width)"
)
parser.add_argument(
    "--height", "-H", type=int, default=None,
    help="Rendering image height (default: dataset image height)"
)
parser.add_argument(
    "--num_views", "-N", type=int, default=600,
    help="Number of frames to render"
)

# Path adjustment
parser.add_argument(
    "--offset", type=str, default="0,0,0",
    help="Center point to rotate around, as 'x,y,z'"
)
parser.add_argument("--radius", type=float, default=0.85,
                    help="Radius of orbit")
parser.add_argument(
    "--elevation",
    type=float,
    default=-45.0,
    help="Elevation of orbit in deg, negative is above",
)
parser.add_argument(
    "--elevation2",
    type=float,
    default=-12.0,
    help="Max elevation, only for spiral",
)
parser.add_argument(
    "--vec_up",
    type=str,
    default=None,
    help="up axis for camera views; "
         "3 floats separated by ','; if not given automatically determined",
)
parser.add_argument(
    "--vert_shift",
    type=float,
    default=0.0,
    help="vertical shift by up axis"
)

# Camera adjustment
parser.add_argument('--crop',
                    type=float,
                    default=1.0,
                    help="Crop (0, 1], 1.0 = full image")

# Foreground/background only
parser.add_argument('--nofg',
                    action='store_true',
                    default=False,
                    help="Do not render foreground (if using BG model)")
parser.add_argument('--nobg',
                    action='store_true',
                    default=False,
                    help="Do not render background (if using BG model)")

# Random debugging features
parser.add_argument('--blackbg',
                    action='store_true',
                    default=False,
                    help="Force a black BG (behind BG model) color; useful for debugging 'clouds'")
# Parse CLI args, then overlay values from an optional JSON config file
# (project helper; invalid/unknown keys are tolerated).
args = parser.parse_args()
config_util.maybe_merge_config_file(args, allow_invalid=True)
device = 'cuda:0'


# Dataset is loaded only for its camera poses / intrinsics / image sizes.
dset = datasets[args.dataset_type](args.data_dir, split="test",
                                   **config_util.build_data_options(args))

if args.vec_up is None:
    # Auto-estimate the world up vector: average the camera -y axis
    # (OpenCV convention, y points down) over all dataset c2w rotations,
    # then normalize.
    up_rot = dset.c2w[:, :3, :3].cpu().numpy()
    ups = np.matmul(up_rot, np.array([0, -1.0, 0])[None, :, None])[..., 0]
    args.vec_up = np.mean(ups, axis=0)
    args.vec_up /= np.linalg.norm(args.vec_up)
    print(' Auto vec_up', args.vec_up)
else:
    # User-provided "x,y,z" string -> float ndarray.
    args.vec_up = np.array(list(map(float, args.vec_up.split(","))))


# Build the camera trajectory (sequence of c2w matrices) for the video.
args.offset = np.array([float(v) for v in args.offset.split(",")])


def _orbit_pose(azimuth_deg, elevation_deg):
    """One spherical camera-to-world pose on the orbit around `offset`."""
    return pose_spherical(
        azimuth_deg,
        elevation_deg,
        args.radius,
        args.offset,
        vec_up=args.vec_up,
    )


azimuths = np.linspace(-180, 180, args.num_views + 1)[:-1]
if args.traj_type == 'spiral':
    # Spiral: elevation sweeps from `elevation` to `elevation2` over one
    # full turn, then back over a second turn (total 2 * num_views frames).
    elevs = np.linspace(args.elevation, args.elevation2, args.num_views)
    c2ws = [_orbit_pose(az, el) for el, az in zip(elevs, azimuths)]
    c2ws.extend(_orbit_pose(az, el)
                for el, az in zip(reversed(elevs), azimuths))
else:
    # Circle: one full turn at a fixed elevation.
    c2ws = [_orbit_pose(az, args.elevation) for az in azimuths]
c2ws = np.stack(c2ws, axis=0)
if args.vert_shift != 0.0:
    # Translate every camera center along the up axis.
    c2ws[:, :3, 3] += np.array(args.vec_up) * args.vert_shift
c2ws = torch.from_numpy(c2ws).to(device=device)

# Accept either a checkpoint file path or a directory containing ckpt.npz.
if not path.isfile(args.ckpt):
    args.ckpt = path.join(args.ckpt, 'ckpt.npz')

# Output video goes next to the checkpoint.
render_out_path = path.join(path.dirname(args.ckpt), 'circle_renders')

# Handle various image transforms
# Non-default options are encoded into the filename so different runs
# don't overwrite each other.
if args.crop != 1.0:
    render_out_path += f'_crop{args.crop}'
if args.vert_shift != 0.0:
    render_out_path += f'_vshift{args.vert_shift}'

grid = svox2.SparseGrid.load(args.ckpt, device=device)
print(grid.center, grid.radius)

# DEBUG
# grid.background_data.data[:, 32:, -1] = 0.0
# render_out_path += '_front'

if grid.use_background:
    if args.nobg:
        # Zero out the last background channel (presumably density/alpha --
        # confirm against svox2) so the BG model contributes nothing.
        grid.background_data.data[..., -1] = 0.0
        render_out_path += '_nobg'
    if args.nofg:
        # Zero foreground density: only the background model renders.
        grid.density_data.data[:] = 0.0
        # grid.sh_data.data[..., 0] = 1.0 / svox2.utils.SH_C0
        # grid.sh_data.data[..., 9] = 1.0 / svox2.utils.SH_C0
        # grid.sh_data.data[..., 18] = 1.0 / svox2.utils.SH_C0
        render_out_path += '_nofg'

# # DEBUG
# grid.background_data.data[..., -1] = 100.0
# a1 = torch.linspace(0, 1, grid.background_data.size(0) // 2, dtype=torch.float32, device=device)[:, None]
# a2 = torch.linspace(1, 0, (grid.background_data.size(0) - 1) // 2 + 1, dtype=torch.float32, device=device)[:, None]
# a = torch.cat([a1, a2], dim=0)
# c = torch.stack([a, 1-a, torch.zeros_like(a)], dim=-1)
# grid.background_data.data[..., :-1] = c
# render_out_path += "_gradient"

# Apply shared render options (step size, sigma threshold, etc.) from args.
config_util.setup_render_opts(grid.opt, args)

if args.blackbg:
    print('Forcing black bg')
    render_out_path += '_blackbg'
    grid.opt.background_brightness = 0.0

render_out_path += '.mp4'
print('Writing to', render_out_path)

# NOTE: no_grad enables the fast image-level rendering kernel for cuvol backend only
# other backends will manually generate rays per frame (slow)
with torch.no_grad():
    n_images = c2ws.size(0)
    # Render every img_eval_interval-th pose (caps output at ~n_eval frames).
    img_eval_interval = max(n_images // args.n_eval, 1)
    # NOTE(review): these metric accumulators are never updated or reported
    # in this script -- dead variables left over from the evaluation script
    # this file was derived from.
    avg_psnr = 0.0
    avg_ssim = 0.0
    avg_lpips = 0.0
    n_images_gen = 0
    frames = []
    # if args.near_clip >= 0.0:
    grid.opt.near_clip = 0.0 #args.near_clip
    # Default render size is the dataset's image size; get_image_size
    # returns (H, W).
    if args.width is None:
        args.width = dset.get_image_size(0)[1]
    if args.height is None:
        args.height = dset.get_image_size(0)[0]

    for img_id in tqdm(range(0, n_images, img_eval_interval)):
        dset_h, dset_w = args.height, args.width
        im_size = dset_h * dset_w  # NOTE(review): unused
        # Optional center crop of the rendered frame.
        w = dset_w if args.crop == 1.0 else int(dset_w * args.crop)
        h = dset_h if args.crop == 1.0 else int(dset_h * args.crop)

        # Principal point placed at the image center; fx/fy taken from
        # dataset frame 0.  ndc_coeffs=(-1,-1) presumably disables NDC --
        # confirm against svox2.Camera.
        cam = svox2.Camera(c2ws[img_id],
                           dset.intrins.get('fx', 0),
                           dset.intrins.get('fy', 0),
                           w * 0.5,
                           h * 0.5,
                           w, h,
                           ndc_coeffs=(-1.0, -1.0))
        torch.cuda.synchronize()
        im = grid.volume_render_image(cam, use_kernel=True)
        torch.cuda.synchronize()
        im.clamp_(0.0, 1.0)

        # Convert to uint8 RGB for the video writer; drop the GPU tensor
        # reference promptly to bound memory use.
        im = im.cpu().numpy()
        im = (im * 255).astype(np.uint8)
        frames.append(im)
        im = None
        n_images_gen += 1
    if len(frames):
        vid_path = render_out_path
        imageio.mimwrite(vid_path, frames, fps=args.fps, macro_block_size=8) # pip install imageio-ffmpeg


78 changes: 78 additions & 0 deletions opt/util/dataset_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import torch
import torch.nn.functional as F
from typing import Union, Optional, List
from .util import select_or_shuffle_rays, Rays, Intrin

class DatasetBase:
    """Base class for camera-ray datasets.

    Subclasses are expected to populate the attributes declared below
    (images, poses, intrinsics) and then call :meth:`gen_rays` to build
    per-pixel rays for training/evaluation.
    """
    split: str                    # dataset split name, e.g. "train"
    permutation: bool             # forwarded to select_or_shuffle_rays
    epoch_size: Optional[int]     # forwarded to select_or_shuffle_rays
    n_images: int
    h_full: int                   # full-resolution image height
    w_full: int                   # full-resolution image width
    intrins_full: Intrin          # full-resolution camera intrinsics
    c2w: torch.Tensor  # C2W OpenCV poses
    gt: Union[torch.Tensor, List[torch.Tensor]]  # RGB images
    device: Union[str, torch.device]

    def __init__(self):
        """Set defaults; subclasses may override any of these."""
        self.ndc_coeffs = (-1, -1)  # (-1, -1) presumably means "no NDC" -- confirm
        # FIX: the original assigned use_sphere_bound = False and then,
        # two lines later, = True; the first assignment was dead.  The net
        # effect (True) is kept.
        self.use_sphere_bound = True
        self.should_use_background = True  # a hint
        self.scene_center = [0.0, 0.0, 0.0]
        self.scene_radius = [1.0, 1.0, 1.0]
        self.permutation = False

    def shuffle_rays(self):
        """
        Shuffle all rays (train split only), rebuilding ``self.rays`` from
        the pristine ``self.rays_init`` copy.
        """
        if self.split == "train":
            # Drop the previous epoch's rays before allocating new ones.
            del self.rays
            self.rays = select_or_shuffle_rays(self.rays_init, self.permutation,
                                               self.epoch_size, self.device)

    def gen_rays(self, factor=1):
        """Generate per-pixel rays at 1/factor of the full resolution.

        Sets ``self.h``, ``self.w``, ``self.intrins``, ``self.rays_init``
        and ``self.rays``.  For the train split rays are flattened across
        all images (so they can be shuffled globally); otherwise they stay
        grouped per image.
        """
        print(" Generating rays, scaling factor", factor)
        self.factor = factor
        self.h = self.h_full // factor
        self.w = self.w_full // factor
        # true_factor can differ slightly from `factor` when h_full is not
        # divisible by it; intrinsics are scaled by the actual ratio.
        true_factor = self.h_full / self.h
        self.intrins = self.intrins_full.scale(1.0 / true_factor)
        # Pixel-center grid (+0.5).
        # NOTE(review): relies on torch.meshgrid's default 'ij' indexing;
        # newer torch versions warn unless indexing= is passed explicitly.
        yy, xx = torch.meshgrid(
            torch.arange(self.h, dtype=torch.float32) + 0.5,
            torch.arange(self.w, dtype=torch.float32) + 0.5,
        )
        # Unproject to normalized camera coordinates (OpenCV convention:
        # x right, y down, z forward), then normalize to unit directions.
        xx = (xx - self.intrins.cx) / self.intrins.fx
        yy = (yy - self.intrins.cy) / self.intrins.fy
        zz = torch.ones_like(xx)
        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
        dirs = dirs.reshape(1, -1, 3, 1)
        del xx, yy, zz
        # Rotate directions into world space via each pose's rotation.
        dirs = (self.c2w[:, None, :3, :3] @ dirs)[..., 0]

        if factor != 1:
            # Downsample ground-truth images with area interpolation
            # (NCHW for F.interpolate, then back to NHWC).
            gt = F.interpolate(
                self.gt.permute([0, 3, 1, 2]), size=(self.h, self.w), mode="area"
            ).permute([0, 2, 3, 1])
            gt = gt.reshape(self.n_images, -1, 3)
        else:
            gt = self.gt.reshape(self.n_images, -1, 3)
        # Ray origins are the camera centers, broadcast over all pixels.
        origins = self.c2w[:, None, :3, 3].expand(-1, self.h * self.w, -1).contiguous()
        if self.split == "train":
            origins = origins.view(-1, 3)
            dirs = dirs.view(-1, 3)
            gt = gt.reshape(-1, 3)

        self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt)
        self.rays = self.rays_init

    def get_image_size(self, i: int):
        """Return (H, W) of image ``i`` (per-image if available, else global)."""
        # H, W
        if hasattr(self, 'image_size'):
            return tuple(self.image_size[i])
        else:
            return self.h, self.w

0 comments on commit d296bc9

Please sign in to comment.