easyvolcap: control viewer init ratio for some of the datasets
dendenxu committed Jan 20, 2024
1 parent 02f465c commit 96b988b
Showing 7 changed files with 97 additions and 89 deletions.
3 changes: 3 additions & 0 deletions configs/datasets/NHR/NHR.yaml
@@ -13,3 +13,6 @@ val_dataloader_cfg:
<<: *dataset_cfg
sampler_cfg:
view_sample: [0, null, 20]

viewer_cfg:
use_window_focal: True
3 changes: 3 additions & 0 deletions configs/datasets/enerf_outdoor/enerf_outdoor.yaml
@@ -29,3 +29,6 @@ val_dataloader_cfg:
runner_cfg:
visualizer_cfg:
video_fps: 60 # this dataset is built differently

viewer_cfg:
use_window_focal: True
3 changes: 3 additions & 0 deletions configs/datasets/renbody/renbody.yaml
@@ -37,3 +37,6 @@ val_dataloader_cfg:
frame_sample: [0, 150, 30]
sampler_cfg:
view_sample: [0, 60, 20]

viewer_cfg:
use_window_focal: True
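All three dataset configs above flip the same switch: with `use_window_focal: True`, the viewer derives its initial focal length from the window size and the dataset's `focal_ratio` instead of reusing the calibrated camera intrinsics (see the `init_camera` change below). A minimal sketch of what that means for the initial field of view, assuming a hypothetical 1920x1080 window and a `focal_ratio` of 1.0:

```python
import math

# Hypothetical window size and focal ratio; both come from the viewer config.
W, H = 1920, 1080
focal_ratio = 1.0

f = max(H, W) * focal_ratio         # window-derived focal length in pixels
fov_x = 2 * math.atan(W / (2 * f))  # horizontal field of view
print(f'horizontal FoV: {math.degrees(fov_x):.1f} deg')  # ~53.1 deg
```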
33 changes: 25 additions & 8 deletions easyvolcap/runners/volumetric_video_viewer.py
@@ -83,6 +83,7 @@ def __init__(self,
compose: bool = False,
compose_power: float = 1.0,
render_ratio: float = 1.0,
use_window_focal: bool = False,

fullscreen: bool = False,
camera_cfg: dotdict = dotdict(type=Camera.__name__),
@@ -100,6 +101,7 @@ def __init__(self,
self.compose = compose # composing only works with cudagl for now
self.compose_power = compose_power
self.exp_name = exp_name
self.use_window_focal = use_window_focal

self.font_default = font_default
self.font_italic = font_italic
@@ -1177,14 +1179,29 @@ def init_camera(self, camera_cfg: dotdict = dotdict(), view_index: int = None):
# We load the first camera out of it
dataset = self.dataset
H, W = self.window_size # dimensions
M = max(H, W)
K = torch.as_tensor([
[M * dataset.focal_ratio, 0, W / 2], # smaller focal, large fov for a bigger picture
[0, M * dataset.focal_ratio, H / 2],
[0, 0, 1],
], dtype=torch.float)
if view_index is None: R, T = dataset.Rv.clone(), dataset.Tv.clone() # intrinsics and extrinsics
else: R, T = dataset.Rs[view_index, 0], dataset.Ts[view_index, 0]

if self.use_window_focal or not hasattr(dataset, 'Ks'):
M = max(H, W)
K = torch.as_tensor([
[M * dataset.focal_ratio, 0, W / 2], # smaller focal, large fov for a bigger picture
[0, M * dataset.focal_ratio, H / 2],
[0, 0, 1],
], dtype=torch.float)
else:
if view_index is None:
K = dataset.Ks[0, 0].clone()
K[0:1] *= W / dataset.Ws[0, 0]
K[1:2] *= H / dataset.Hs[0, 0]
else:
K = dataset.Ks[view_index, 0].clone()
K[0:1] *= W / dataset.Ws[view_index, 0]
K[1:2] *= H / dataset.Hs[view_index, 0]

if view_index is None:
R, T = dataset.Rv.clone(), dataset.Tv.clone() # intrinsics and extrinsics
else:
R, T = dataset.Rs[view_index, 0], dataset.Ts[view_index, 0]

n, f, t, v = dataset.near, dataset.far, 0, 0 # use 0 for default t
bounds = dataset.bounds.clone() # avoids modification
self.camera = Camera(H, W, K, R, T, n, f, t, v, bounds, **camera_cfg)
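The new branch above condenses to the following standalone helper (a sketch, not part of the commit): row 0 of K holds fx and cx, row 1 holds fy and cy, so row-wise scaling rescales the dataset's calibrated intrinsics to the current window, while the `use_window_focal` path (or a dataset without `Ks`) falls back to window-derived intrinsics.

```python
import torch

def init_intrinsics(H, W, dataset, use_window_focal=False, view_index=None):
    # Fallback: build K from the window itself (longer side, centered principal point)
    if use_window_focal or not hasattr(dataset, 'Ks'):
        M = max(H, W)
        return torch.as_tensor([
            [M * dataset.focal_ratio, 0, W / 2],
            [0, M * dataset.focal_ratio, H / 2],
            [0, 0, 1],
        ], dtype=torch.float)
    # Otherwise: reuse the calibrated K of the selected (or first) view,
    # rescaled from the dataset's image size to the window size
    v = 0 if view_index is None else view_index
    K = dataset.Ks[v, 0].clone()
    K[0:1] *= W / dataset.Ws[v, 0]  # scales fx and cx
    K[1:2] *= H / dataset.Hs[v, 0]  # scales fy and cy
    return K
```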
30 changes: 18 additions & 12 deletions easyvolcap/utils/enerf_utils.py
@@ -1,7 +1,7 @@
import torch

from torch import nn
from typing import Union, List
from typing import Union, List, Tuple
from torch.nn import functional as F
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
@@ -186,16 +186,19 @@ def sample_feature_volume(s_vals: torch.Tensor, feat_vol: torch.Tensor, ren_scal


@torch.jit.script
def depth_regression(depth_prob: torch.Tensor, depth_values: torch.Tensor, volume_render_depth: bool = False):
def depth_regression(depth_prob: torch.Tensor, depth_values: torch.Tensor, volume_render_depth: bool = False, use_dist: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
# depth_prob: B, D, H, W
# depth_values: B, D, H, W

if volume_render_depth:
B, D, H, W = depth_prob.shape
raws = depth_prob.permute(0, 2, 3, 1).reshape(B, H * W, D) # B, H, W, D -> B, HW, D
z_vals = depth_values.permute(0, 2, 3, 1).reshape(B, H * W, D) # B, H, W, D -> B, HW, D
dists = compute_dist(z_vals) # B, HW, D
occ = 1. - torch.exp(-raws * dists) # B, HW, D
if use_dist:
dists = compute_dist(z_vals) # B, HW, D
occ = 1. - torch.exp(-raws * dists) # B, HW, D
else:
occ = 1. - torch.exp(-raws) # B, HW, D
weights = render_weights(occ) # B, HW, D
acc_map = torch.sum(weights, -1, keepdim=True) # (B, HW, 1)
depth = weighted_percentile(torch.cat([z_vals, z_vals.max(dim=-1, keepdim=True)[0]], dim=-1),
@@ -220,8 +223,8 @@ def weight_regression(depth_prob: torch.Tensor, depth_values: torch.Tensor = Non
z_vals = depth_values.permute(0, 2, 3, 1).view(B, H * W, D) # B, H, W, D -> B, HW, D
dists = compute_dist(z_vals) # B, HW, D
occ = 1. - torch.exp(-raws * dists) # B, HW, D

occ = 1. - torch.exp(-raws) # B, HW, D
else:
occ = 1. - torch.exp(-raws) # B, HW, D
weights = render_weights(occ) # B, HW, D
return weights.view(B, H, W, D).permute(0, 3, 1, 2)
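Both functions above share the same volume-rendering core: densities become per-bin alphas, and alphas become transmittance-weighted rendering weights. A self-contained sketch with `compute_dist` and `render_weights` inlined in their standard NeRF forms (an assumption about those helpers), and an expected-depth readout in place of the `weighted_percentile` median that `depth_regression` actually uses:

```python
import torch

def volume_rendered_depth(raws, z_vals, use_dist=True):
    # raws, z_vals: B, HW, D (flattened depth probability volume and bin depths)
    if use_dist:
        dists = z_vals[..., 1:] - z_vals[..., :-1]           # bin widths
        dists = torch.cat([dists, dists[..., -1:]], dim=-1)  # pad the last bin
        occ = 1. - torch.exp(-raws * dists)                  # density * interval -> alpha
    else:
        occ = 1. - torch.exp(-raws)                          # treat raws directly as opacity
    trans = torch.cumprod(1. - occ + 1e-10, dim=-1)          # survival probability
    trans = torch.cat([torch.ones_like(trans[..., :1]), trans[..., :-1]], dim=-1)
    weights = occ * trans                                    # standard NeRF weights
    acc = weights.sum(dim=-1).clamp(min=1e-10)               # accumulated opacity
    return (weights * z_vals).sum(dim=-1) / acc              # expected (mean) depth
```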

@@ -449,9 +452,10 @@ def train(self, mode: bool):
@REGRESSORS.register_module()
class CostRegNet(nn.Module):
# TODO: compare the results of nn.BatchNorm3d and nn.InstanceNorm3d
def __init__(self, in_channels, norm_actvn=nn.BatchNorm3d, out_actvn=nn.Identity(), use_vox_feat=True):
def __init__(self, in_channels, norm_actvn=nn.BatchNorm3d, dpt_actvn=nn.Identity, use_vox_feat=True):
super(CostRegNet, self).__init__()
norm_actvn = getattr(nn, norm_actvn) if isinstance(norm_actvn, str) else norm_actvn
self.dpt_actvn = get_function(dpt_actvn)

self.conv0 = ConvBnReLU3D(in_channels, 8, norm_actvn=norm_actvn)

@@ -486,7 +490,6 @@ def __init__(self, in_channels, norm_actvn=nn.BatchNorm3d, out_actvn=nn.Identity

self.size_pad = 8 # input size should be divisible by 8
self.out_dim = 8
self.out_actvn = get_function(out_actvn) if isinstance(out_actvn, str) else out_actvn

def forward(self, x: torch.Tensor):
conv0 = self.conv0(x)
@@ -500,7 +503,7 @@ def forward(self, x: torch.Tensor):
x = conv0 + self.conv11(x)
del conv0
depth = self.depth_conv(x)
depth = self.out_actvn(depth.squeeze(1)) # softplus might change dtype
depth = self.dpt_actvn(depth.squeeze(1)) # softplus might change dtype

if self.use_vox_feat:
feat = self.feat_conv(x)
@@ -511,9 +514,10 @@

@REGRESSORS.register_module()
class MinCostRegNet(nn.Module):
def __init__(self, in_channels, norm_actvn=nn.BatchNorm3d):
def __init__(self, in_channels, norm_actvn=nn.BatchNorm3d, dpt_actvn=nn.Identity):
super(MinCostRegNet, self).__init__()
norm_actvn = getattr(nn, norm_actvn) if isinstance(norm_actvn, str) else norm_actvn
self.dpt_actvn = get_function(dpt_actvn)

self.conv0 = ConvBnReLU3D(in_channels, 8, norm_actvn=norm_actvn)

@@ -535,7 +539,7 @@ def __init__(self, in_channels, norm_actvn=nn.BatchNorm3d):

self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))

self.out_dim = 8
self.size_pad = 4 # input should be divisible by 4

def forward(self, x, use_vox_feat=True):
@@ -548,9 +552,11 @@ def forward(self, x, use_vox_feat=True):
x = conv0 + self.conv11(x)
del conv0
depth = self.depth_conv(x)
depth = self.dpt_actvn(depth.squeeze(1))

if not use_vox_feat: feat = None
else: feat = self.feat_conv(x)
return feat, depth.squeeze(1)
return feat, depth


# ? This could be refactored
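Both regressors now take a `dpt_actvn` argument resolved through `get_function`, replacing the old `out_actvn`, and apply it to the squeezed depth volume. A hedged usage sketch; the channel count, input shape, and `nn.Softplus` choice are hypothetical, and `get_function` is assumed to accept an `nn.Module` class:

```python
import torch
from torch import nn

# dpt_actvn is applied to depth.squeeze(1), i.e. to the B, D, H, W depth volume
cost_reg = MinCostRegNet(in_channels=32, norm_actvn=nn.BatchNorm3d, dpt_actvn=nn.Softplus)

cost_volume = torch.randn(1, 32, 16, 64, 64)  # B, C, D, H, W; sides divisible by 4
feat, depth = cost_reg(cost_volume)           # feat: B, 8, D, H, W; depth: B, D, H, W
```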
26 changes: 26 additions & 0 deletions easyvolcap/utils/fusion_utils.py
@@ -0,0 +1,26 @@
import torch
from torch.nn import functional as F

from easyvolcap.utils.console_utils import *
from easyvolcap.utils.chunk_utils import multi_gather
from easyvolcap.utils.fcds_utils import remove_outlier


def filter_global_points(points: dotdict[str, torch.Tensor] = dotdict()):

def gather_from_inds(ind: torch.Tensor, scalars: dotdict()):
return dotdict({k: multi_gather(v, ind[..., None]) for k, v in scalars.items()})

# Remove NaNs in point positions
ind = (~points.pts.isnan())[..., 0].nonzero()[..., 0] # P,
points = gather_from_inds(ind, points)

# Remove low density points
ind = (points.occ > 0.01)[..., 0].nonzero()[..., 0] # P,
points = gather_from_inds(ind, points)

# Remove statistic outliers (il_ind -> inlier indices)
ind = remove_outlier(points.pts[None], K=50, std_ratio=4.0, return_inds=True)[0] # P,
points = gather_from_inds(ind, points)

return points
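A usage sketch for the new helper (stand-in random tensors; the import path of `dotdict` is assumed). Every tensor in the dict is gathered with the same surviving indices, so `pts` and `occ` are required since they drive the filters, while any other key simply rides along:

```python
import torch
from easyvolcap.utils.console_utils import *  # assumed to provide dotdict
from easyvolcap.utils.fusion_utils import filter_global_points

P = 100_000
points = dotdict(
    pts=torch.randn(P, 3),  # positions: drive the NaN and outlier filters
    occ=torch.rand(P, 1),   # densities: drive the occ > 0.01 filter
    rgb=torch.rand(P, 3),   # carried along by the shared gather
)
points = filter_global_points(points)
print(points.pts.shape[0], 'points survive')
```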
88 changes: 19 additions & 69 deletions scripts/tools/volume_fusion.py
@@ -14,6 +14,7 @@
from easyvolcap.utils.data_utils import add_batch, to_cuda, export_pts, export_mesh, export_pcd, to_x
from easyvolcap.utils.math_utils import point_padding, affine_padding, affine_inverse
from easyvolcap.utils.chunk_utils import multi_gather, multi_scatter
from easyvolcap.utils.fusion_utils import filter_global_points

from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -74,29 +75,28 @@ def fuse(runner: "VolumetricVideoRunner", args: argparse.Namespace):
# Handle data movement
batch = dataset[inds[v, f]] # get the batch data for this view
batch = add_batch(to_cuda(batch))
meta = batch.meta
del batch.meta
batch.meta = meta

# Running inference
with torch.inference_mode(), torch.no_grad():
output = runner.model(batch) # get everything we need from the model, this performs the actual rendering

# Get output point clouds
pts = (batch.ray_o + output.dpt_map * batch.ray_d).detach().cpu()
rgb = batch.rgb.detach().cpu()
prd = output.rgb_map.detach().cpu()
occ = output.acc_map.detach().cpu()
dpt = output.dpt_map.detach().cpu()
dir = batch.ray_d.detach().cpu()
pts = (batch.ray_o + output.dpt_map * batch.ray_d)[0]
rgb = batch.rgb[0]
prd = output.rgb_map[0]
occ = output.acc_map[0]
dpt = output.dpt_map[0]
dir = batch.ray_d[0]

# Filter local points

# Store it into list
prds.append(prd)
ptss.append(pts)
rgbs.append(rgb)
occs.append(occ)
dpts.append(dpt)
dirs.append(dir)
prds.append(prd.detach().cpu()[0]) # reduce memory usage
ptss.append(pts.detach().cpu()[0]) # reduce memory usage
rgbs.append(rgb.detach().cpu()[0]) # reduce memory usage
occs.append(occ.detach().cpu()[0]) # reduce memory usage
dpts.append(dpt.detach().cpu()[0]) # reduce memory usage
dirs.append(dir.detach().cpu()[0]) # reduce memory usage

pbar.update()

@@ -108,60 +108,10 @@ def fuse(runner: "VolumetricVideoRunner", args: argparse.Namespace):
dpt = torch.cat(dpts, dim=-2).float()
dir = torch.cat(dirs, dim=-2).float()

# Remove NaNs in point positions
ind = (~pts.isnan())[0, ..., 0].nonzero()[..., 0] # P,
log(f'Removing NaNs: {pts.shape[1] - len(ind)}')
prd = multi_gather(prd, ind[None, ..., None]) # B, P, C
pts = multi_gather(pts, ind[None, ..., None]) # B, P, C
rgb = multi_gather(rgb, ind[None, ..., None]) # B, P, C
occ = multi_gather(occ, ind[None, ..., None]) # B, P, C
dpt = multi_gather(dpt, ind[None, ..., None]) # B, P, C
dir = multi_gather(dir, ind[None, ..., None]) # B, P, C

# Remove low density points
if not args.skip_density:
ind = (occ > args.occ_thresh)[0, ..., 0].nonzero()[..., 0] # P,
log(f'Removing low density points: {pts.shape[1] - len(ind)}')
prd = multi_gather(prd, ind[None, ..., None]) # B, P, C
pts = multi_gather(pts, ind[None, ..., None]) # B, P, C
rgb = multi_gather(rgb, ind[None, ..., None]) # B, P, C
occ = multi_gather(occ, ind[None, ..., None]) # B, P, C
dpt = multi_gather(dpt, ind[None, ..., None]) # B, P, C
dir = multi_gather(dir, ind[None, ..., None]) # B, P, C

# Remove statistic outliers (il_ind -> inlier indices)
if not args.skip_outlier:
ind = remove_outlier(pts, K=50, std_ratio=4.0, return_inds=True)[0] # P,
log(f'Removing outliers: {pts.shape[1] - len(ind)}')
prd = multi_gather(prd, ind[None, ..., None]) # B, P, C
pts = multi_gather(pts, ind[None, ..., None]) # B, P, C
rgb = multi_gather(rgb, ind[None, ..., None]) # B, P, C
occ = multi_gather(occ, ind[None, ..., None]) # B, P, C
dpt = multi_gather(dpt, ind[None, ..., None]) # B, P, C
dir = multi_gather(dir, ind[None, ..., None]) # B, P, C

# Remove points outside of the near far bounds
if not args.skip_near_far:
near, far = dataset.near, dataset.far # scalar for controlling camera near far
near_far_mask = pts.new_ones(pts.shape[1:-1], dtype=torch.bool)
for v in range(nv):
batch = dataset[inds[v, f]] # get the batch data for this view
H, W, K, R, T = batch.H, batch.W, batch.K, batch.R, batch.T
pts_view = pts @ R.mT + T.mT
pts_pix = pts_view @ K.mT # N, 3
pix = pts_pix[..., :2] / pts_pix[..., 2:]
pix = pix / pix.new_tensor([W, H]) * 2 - 1 # N, P, 2 to sample the msk (dimensionality normalization for sampling)
outside = ((pix[0] < -1) | (pix[0] > 1)).any(dim=-1) # P,
near_far = ((pts_view[0, ..., -1] < far + args.near_far_pad) & (pts_view[0, ..., -1] > near - args.near_far_pad)) # P,
near_far_mask &= near_far | outside
ind = near_far_mask.nonzero()[..., 0]
log(f'Removing out-of-near-far points: {pts.shape[1] - len(ind)}')
prd = multi_gather(prd, ind[None, ..., None]) # B, P, C
pts = multi_gather(pts, ind[None, ..., None]) # B, P, C
rgb = multi_gather(rgb, ind[None, ..., None]) # B, P, C
occ = multi_gather(occ, ind[None, ..., None]) # B, P, C
dpt = multi_gather(dpt, ind[None, ..., None]) # B, P, C
dir = multi_gather(dir, ind[None, ..., None]) # B, P, C
# Apply some global filtering
points = filter_global_points(dotdict(prd=prd, pts=pts, rgb=rgb, occ=occ, dpt=dpt, dir=dir))
log(f'Filtered {len(pts)} - {len(points.pts)} = {len(pts) - len(points.pts)} points globally')
prd, pts, rgb, occ, dpt, dir = points.prd, points.pts, points.rgb, points.occ, points.dpt, points.dir

# Align point cloud with the average camera, which is processed in memory, to make sure the stored files are consistent
if dataset.use_aligned_cameras and not args.skip_align: # match the visual hull implementation
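The per-view near-far culling deleted above (guarded by `--skip_near_far` before this commit) did not move into `filter_global_points`; for reference, its logic condenses to the sketch below. The `cameras` layout is hypothetical, but the projection math mirrors the removed block: keep a point if every view that actually sees it places it within the padded near-far range.

```python
import torch

def near_far_cull(pts, cameras, near, far, pad=0.0):
    # pts: P, 3 world-space points; cameras: iterable of dotdicts with H, W, K, R, T
    keep = torch.ones(pts.shape[0], dtype=torch.bool)
    for cam in cameras:
        pts_view = pts @ cam.R.mT + cam.T.mT                # world -> camera space
        pts_pix = pts_view @ cam.K.mT                       # camera -> pixel space
        pix = pts_pix[..., :2] / pts_pix[..., 2:]           # perspective divide
        pix = pix / pix.new_tensor([cam.W, cam.H]) * 2 - 1  # normalize to [-1, 1]
        outside = ((pix < -1) | (pix > 1)).any(dim=-1)      # invisible in this view
        in_range = (pts_view[..., -1] > near - pad) & (pts_view[..., -1] < far + pad)
        keep &= in_range | outside                          # cull only what this view sees
    return keep  # P, boolean mask of surviving points
```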
