easyvolcap: add geometry consistency in the depth fusion process & lazy load perc loss modules
dendenxu committed Jan 20, 2024
1 parent cccda9d commit 2b910be
Showing 5 changed files with 281 additions and 57 deletions.
7 changes: 4 additions & 3 deletions easyvolcap/models/supervisors/volumetric_video_supervisor.py
@@ -32,9 +32,6 @@ def __init__(self,
self.perc_loss_weight = perc_loss_weight
self.ssim_loss_weight = ssim_loss_weight

# For computing perceptual loss & img_loss # HACK: Nobody is referencing this as a pytorch module
if self.perc_loss_weight > 0: self.perc_loss_reference = [VGGPerceptualLoss().cuda().to(self.dtype)] # move to specific location

@property
def perc_loss(self):
return self.perc_loss_reference[0]
@@ -51,6 +48,10 @@ def compute_image_loss(self, rgb_map: torch.Tensor, rgb_gt: torch.Tensor,
psnr = (1 / mse.clip(1e-10)).log() * 10 / np.log(10)

if type == ImgLossType.PERC:
if not hasattr(self, 'perc_loss_reference'):
# For computing perceptual loss & img_loss # HACK: Nobody is referencing this as a pytorch module
log('Initializing VGGPerceptualLoss')
self.perc_loss_reference = [VGGPerceptualLoss().cuda().to(self.dtype)] # move to specific location
rgb_gt = rgb_gt.view(-1, H, W, 3).permute(0, 3, 1, 2) # B, C, H, W
rgb_map = rgb_map.view(-1, H, W, 3).permute(0, 3, 1, 2) # B, C, H, W
img_loss = self.perc_loss(rgb_map, rgb_gt)
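# A minimal standalone sketch (not from the repository) of the lazy, list-wrapped module
# pattern the hunk above switches to, with nn.L1Loss standing in for VGGPerceptualLoss:
# wrapping the module in a plain Python list keeps nn.Module from registering it as a
# submodule (no state_dict entry, no .to()/.cuda() recursion), and building it on first
# use defers the heavy initialization until a perceptual loss is actually requested.
import torch
from torch import nn

class LazySupervisor(nn.Module):
    @property
    def heavy_loss(self):
        return self.heavy_loss_reference[0]  # unwrap the unregistered module

    def compute(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        if not hasattr(self, 'heavy_loss_reference'):
            self.heavy_loss_reference = [nn.L1Loss().to(pred.device)]  # stand-in for VGGPerceptualLoss
        return self.heavy_loss(pred, gt)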
10 changes: 6 additions & 4 deletions easyvolcap/runners/evaluators/volumetric_video_evaluator.py
@@ -30,13 +30,15 @@ def psnr(x: torch.Tensor, y: torch.Tensor):
def ssim(xs: torch.Tensor, ys: torch.Tensor):
return np.mean([compare_ssim(x.detach().cpu().numpy(), y.detach().cpu().numpy(), channel_axis=-1, data_range=2.0) for x, y in zip(xs, ys)])

import lpips as lpips_module
compute_lpips = lpips_module.LPIPS(net='vgg', verbose=False).cuda()

def lpips(x: torch.Tensor, y: torch.Tensor):
# B, H, W, 3
# B, H, W, 3
return compute_lpips(x.permute(0, 3, 1, 2) * 2 - 1, y.permute(0, 3, 1, 2) * 2 - 1).mean().item()
if not hasattr(self, 'compute_lpips'):
import lpips as lpips_module
log('Initializing LPIPS network')
self.compute_lpips = lpips_module.LPIPS(net='vgg', verbose=False).cuda()

return self.compute_lpips(x.permute(0, 3, 1, 2) * 2 - 1, y.permute(0, 3, 1, 2) * 2 - 1).mean().item()

self.compute_metrics = [psnr, ssim, lpips]
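# A hedged usage sketch, assuming the psnr/ssim/lpips closures defined above are in scope
# and that images are B, H, W, 3 tensors in [0, 1] on the same CUDA device as the lazily
# built LPIPS network; the real pipeline feeds rendered and ground-truth frames instead.
rendered = torch.rand(1, 128, 128, 3, device='cuda')
reference = torch.rand(1, 128, 128, 3, device='cuda')
for metric in (psnr, ssim, lpips):  # the same functions gathered into self.compute_metrics
    print(metric.__name__, metric(rendered, reference))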

19 changes: 19 additions & 0 deletions easyvolcap/utils/cam_utils.py
@@ -2,6 +2,7 @@
import os
import cv2
import json
import torch
import numpy as np
from typing import Union
from enum import Enum, auto
@@ -11,6 +12,24 @@

from easyvolcap.utils.console_utils import *
from easyvolcap.utils.data_utils import get_rays, get_near_far
from easyvolcap.utils.math_utils import affine_inverse, affine_padding


def compute_camera_similarity(tar_c2ws: torch.Tensor, src_c2ws: torch.Tensor):
# c2ws = affine_inverse(w2cs) # N, L, 3, 4
# src_exts = affine_padding(w2cs) # N, L, 4, 4

# tar_c2ws = c2ws
# src_c2ws = affine_inverse(src_exts)
centers_target = tar_c2ws[..., :3, 3] # N, L, 3
centers_source = src_c2ws[..., :3, 3] # N, L, 3

# Using distance between centers for camera selection
sims: torch.Tensor = 1 / (centers_source[None] - centers_target[:, None]).norm(dim=-1) # N, N, L,

# Source view index and their similarity
src_sims, src_inds = sims.sort(dim=1, descending=True) # similarity to source views # Target, Source, Latent
return src_sims, src_inds # N, N, L
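# A small usage sketch with random camera centers, assuming 8 target and 8 source cameras
# over a single latent frame; in practice both c2w stacks come from the dataset calibration.
import torch
tar_c2ws = torch.cat([torch.eye(3).expand(8, 1, 3, 3), torch.rand(8, 1, 3, 1)], dim=-1)  # N, L, 3, 4
src_c2ws = torch.cat([torch.eye(3).expand(8, 1, 3, 3), torch.rand(8, 1, 3, 1)], dim=-1)  # N, L, 3, 4
src_sims, src_inds = compute_camera_similarity(tar_c2ws, src_c2ws)  # N, N, L
nearest = src_inds[:, :4, 0]  # indices of the 4 closest source cameras for every target (frame 0)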


class Interpolation(Enum):
189 changes: 163 additions & 26 deletions easyvolcap/utils/fusion_utils.py
@@ -9,9 +9,11 @@
from easyvolcap.utils.console_utils import *
from easyvolcap.utils.chunk_utils import multi_gather
from easyvolcap.utils.fcds_utils import remove_outlier
from easyvolcap.utils.ray_utils import create_meshgrid
from easyvolcap.utils.math_utils import affine_inverse, affine_padding


def filter_global_points(points: dotdict[str, torch.Tensor] = dotdict()):
def filter_global_points(points: dotdict[str, torch.Tensor]):

def gather_from_inds(ind: torch.Tensor, scalars: dotdict()):
return dotdict({k: multi_gather(v, ind[..., None]) for k, v in scalars.items()})
@@ -30,6 +32,142 @@ def gather_from_inds(ind: torch.Tensor, scalars: dotdict()):

return points

# *************************************
# Our PyTorch implementation
# *************************************


def depth_reprojection(dpt_ref: torch.Tensor, # B, H, W
ixt_ref: torch.Tensor, # B, 3, 3
ext_ref: torch.Tensor, # B, 4, 4
dpt_src: torch.Tensor, # B, S, H, W
ixt_src: torch.Tensor, # B, S, 3, 3
ext_src: torch.Tensor, # B, S, 4, 4
):
sh = dpt_ref.shape[:-2]
H, W = dpt_ref.shape[-2:]

# step1. project reference pixels to the source view
# reference view x, y
ij_ref = create_meshgrid(H, W, device=dpt_ref.device, dtype=dpt_ref.dtype)
xy_ref = ij_ref.flip(-1)
for _ in range(len(sh)): xy_ref = xy_ref[None] # add dimension
xy_ref = xy_ref.view(*sh, -1, 2) # B, HW, 2
x_ref, y_ref = xy_ref.split(1, dim=-1) # B, HW, 1

xy1_ref = torch.cat([xy_ref, torch.ones_like(xy_ref[..., :1])], dim=-1) # B, HW, 3
# reference 3D space
xyz_ref = xy1_ref @ ixt_ref.inverse().mT * dpt_ref.view(*sh, H * W, 1) # B, HW, 3

# source 3D space
xyz1_ref = torch.cat([xyz_ref, torch.ones_like(xyz_ref[..., :1])], dim=-1) # B, HW, 4
xyz_src = ((xyz1_ref @ affine_inverse(ext_ref).mT).unsqueeze(-3) @ ext_src.mT)[..., :3] # B, S, HW, 3
# source view x, y
K_xyz_src = xyz_src @ ixt_src.mT # B, S, HW, 3
xy_src = K_xyz_src[..., :2] / K_xyz_src[..., 2:3] # homography reprojection, B, S, HW, 2

# step2. reproject the source view points with source view depth estimation
# find the depth estimation of the source view
x_src = xy_src[..., 0].view(*sh, -1, H, W) # B, S, H, W
y_src = xy_src[..., 1].view(*sh, -1, H, W) # B, S, H, W

dpt_input = dpt_src.unsqueeze(-3) # B, S, 1, H, W
xy_grid = torch.stack([x_src / W * 2 - 1, y_src / H * 2 - 1], dim=-1) # B, S, H, W, 2
bs = dpt_input.shape[:-3] # BS
dpt_input = dpt_input.view(-1, *dpt_input.shape[-3:]) # BS, 1, H, W
xy_grid = xy_grid.view(-1, *xy_grid.shape[-3:]) # BS, H, W, 2
sampled_depth_src = F.grid_sample(dpt_input, xy_grid, padding_mode='border') # BS, 1, H, W # sampled src depth map
sampled_depth_src = sampled_depth_src.view(*bs, H, W) # B, S, H, W
# mask = sampled_depth_src > 0

# source 3D space
# NOTE that we should use sampled source-view depth_here to project back
xy1_src = torch.cat([xy_src, torch.ones_like(xy_src[..., :1])], dim=-1) # B, S, HW, 3
xyz_src = (xy1_src @ ixt_src.inverse().mT) * sampled_depth_src.view(*sh, -1, H * W, 1) # B, S, HW, 3

# reference 3D space
xyz1_src = torch.cat([xyz_src, torch.ones_like(xyz_src[..., :1])], dim=-1) # B, S, HW, 4
xyz_reprojected = ((xyz1_src @ affine_inverse(ext_src).mT) @ ext_ref.mT.unsqueeze(-3))[..., :3] # B, S, HW, 3

# source view x, y, depth
depth_reprojected = xyz_reprojected[..., 2].view(*sh, -1, H, W) # source depth in ref view space, B, S, H, W
K_xyz_reprojected = xyz_reprojected @ ixt_ref.unsqueeze(-3).mT # B, S, HW, 3 # source xyz in ref screen space
xy_reprojected = K_xyz_reprojected[..., :2] / K_xyz_reprojected[..., 2:3] # homography
x_reprojected = xy_reprojected[..., 0].view(*sh, -1, H, W) # source point in ref screen space, x: B, S, H, W
y_reprojected = xy_reprojected[..., 1].view(*sh, -1, H, W) # source point in ref screen space, y: B, S, H, W

return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src
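# A shape-and-sanity sketch, assuming the utilities imported at the top of this file
# (create_meshgrid, affine_inverse, F.grid_sample) are available; with identical
# reference/source cameras and a constant depth of 2.0, the reprojected depth comes back
# as 2.0 everywhere, which makes the round trip easy to eyeball.
import torch
B, S, H, W = 1, 2, 32, 32
ixt = torch.eye(3).expand(B, 3, 3).clone(); ixt[:, 0, 2], ixt[:, 1, 2] = W / 2, H / 2
ext = torch.eye(4).expand(B, 4, 4).clone()
dpt_ref = torch.full((B, H, W), 2.0)
dpt_src = torch.full((B, S, H, W), 2.0)
dpt_reproj, x_reproj, y_reproj, x_src, y_src = depth_reprojection(
    dpt_ref, ixt, ext,
    dpt_src, ixt[:, None].expand(B, S, 3, 3).clone(), ext[:, None].expand(B, S, 4, 4).clone())
# every returned tensor is B, S, H, W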


def depth_geometry_consistency(
dpt_ref: torch.Tensor, # B, H, W
ixt_ref: torch.Tensor, # B, 3, 3
ext_ref: torch.Tensor, # B, 4, 4
dpt_src: torch.Tensor, # B, S, H, W
ixt_src: torch.Tensor, # B, S, 3, 3
ext_src: torch.Tensor, # B, S, 4, 4

geo_abs_thresh: float = 0.5,
geo_rel_thresh: float = 0.01,
):
# Assumption: same sized image
# Assumption: zero depth -> no depth, should be masked out as photometrically inconsistent
sh = dpt_ref.shape[:-2]
H, W = dpt_ref.shape[-2:]

# step1. project reference pixels to the source view
# reference view x, y
ij_ref = create_meshgrid(H, W, device=dpt_ref.device, dtype=dpt_ref.dtype)
xy_ref = ij_ref.flip(-1)
for _ in range(len(sh)): xy_ref = xy_ref[None] # add dimension
xy_ref = xy_ref.view(*sh, H, W, 2) # B, H, W, 2
x_ref, y_ref = xy_ref.unbind(-1) # B, H, W
x_ref: torch.Tensor # add type annotation
y_ref: torch.Tensor # add type annotation

depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = depth_reprojection(
dpt_ref, ixt_ref, ext_ref,
dpt_src, ixt_src, ext_src)
# check |p_reproj-p_1| < 1
dist = torch.sqrt((x2d_reprojected - x_ref.unsqueeze(-3)) ** 2 + (y2d_reprojected - y_ref.unsqueeze(-3)) ** 2) # B, S, H, W

# check |d_reproj-d_1| / d_1 < 0.01
depth_diff = torch.abs(depth_reprojected - dpt_ref.unsqueeze(-3)) # unprojected depth difference
relative_depth_diff = depth_diff / dpt_ref # relative unprojected depth difference

mask = torch.logical_and(dist < geo_abs_thresh, relative_depth_diff < geo_rel_thresh) # smaller than 0.5 pix, relative smaller than 0.01
depth_reprojected[~mask] = 0 # zero out depths at geometrically inconsistent points

return mask, depth_reprojected, x2d_src, y2d_src


def compute_consistency(
dpt_ref: torch.Tensor, # B, H, W
ixt_ref: torch.Tensor, # B, 3, 3
ext_ref: torch.Tensor, # B, 4, 4
dpt_src: torch.Tensor, # B, S, H, W
ixt_src: torch.Tensor, # B, S, 3, 3
ext_src: torch.Tensor, # B, S, 4, 4

geo_abs_thresh: float = 0.5,
geo_rel_thresh: float = 0.01,
pho_abs_thresh: float = 0.0,
):
# Perform actual geometry consistency check
geo_mask, depth_reprojected, x2d_src, y2d_src = depth_geometry_consistency(
dpt_ref, ixt_ref, ext_ref, dpt_src, ixt_src, ext_src,
geo_abs_thresh, geo_rel_thresh
) # each of geo_mask, depth_reprojected, x2d_src, y2d_src: B, S, H, W

# Aggregate the projected mask
geo_mask_sum = geo_mask.sum(-3) # H, W
depth_est_averaged = (depth_reprojected.sum(-3) + dpt_ref) / (geo_mask_sum + 1) # average depth values, H, W
# at least 3 source views matched
geo_mask = geo_mask_sum >= 3 # a pixel is considered valid when at least 3 source views match up
photo_mask = dpt_ref >= pho_abs_thresh
final_mask = torch.logical_and(photo_mask, geo_mask) # H, W

return depth_est_averaged, photo_mask, geo_mask, final_mask
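# A hedged end-to-end sketch of the fusion check above, with random depths and a tiny
# horizontal baseline between otherwise identical cameras; real inputs are the per-view
# depth estimates plus their calibrated intrinsics and w2c extrinsics.
import torch
V, H, W = 4, 32, 32
dpts = torch.rand(V, H, W) + 1.0                              # fake per-view depth maps
ixts = torch.eye(3).expand(V, 3, 3).clone()
ixts[:, 0, 2], ixts[:, 1, 2] = W / 2, H / 2
exts = torch.eye(4).expand(V, 4, 4).clone()
exts[:, 0, 3] = torch.linspace(0, 0.05, V)                    # small camera baseline
ref, src = 0, [1, 2, 3]
depth_avg, photo_mask, geo_mask, final_mask = compute_consistency(
    dpts[ref][None], ixts[ref][None], exts[ref][None],        # B=1 reference view
    dpts[src][None], ixts[src][None], exts[src][None],        # B=1, S=3 source views
    geo_abs_thresh=0.5, geo_rel_thresh=0.01, pho_abs_thresh=0.0)
fused = depth_avg[0] * final_mask[0]                          # H, W, zero where views disagree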

# *************************************
# Original implementation in cvp-mvsnet
@@ -185,38 +323,38 @@ def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, i
width, height = depth_ref.shape[1], depth_ref.shape[0]
# step1. project reference pixels to the source view
# reference view x, y
x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])
x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) # int
x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) # int
# reference 3D space
xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref),
np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))
np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1])) # view space ref xyz 3, HW
# source 3D space
xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)),
np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]
np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] # ref xyz in src view space 3, HW
# source view x, y
K_xyz_src = np.matmul(intrinsics_src, xyz_src)
xy_src = K_xyz_src[:2] / K_xyz_src[2:3]
K_xyz_src = np.matmul(intrinsics_src, xyz_src) # ref xyz in src screen space
xy_src = K_xyz_src[:2] / K_xyz_src[2:3] # homography reprojection, 2, HW

# step2. reproject the source view points with source view depth estimation
# find the depth estimation of the source view
x_src = xy_src[0].reshape([height, width]).astype(np.float32)
y_src = xy_src[1].reshape([height, width]).astype(np.float32)
sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)
x_src = xy_src[0].reshape([height, width]).astype(np.float32) # H, W
y_src = xy_src[1].reshape([height, width]).astype(np.float32) # H, W
sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR) # sampled src depth map
# mask = sampled_depth_src > 0

# source 3D space
# NOTE that we should use sampled source-view depth_here to project back
xyz_src = np.matmul(np.linalg.inv(intrinsics_src),
np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1]))
np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) # unproject back to src view space, 3, HW
# reference 3D space
xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),
np.vstack((xyz_src, np.ones_like(x_ref))))[:3]
np.vstack((xyz_src, np.ones_like(x_ref))))[:3] # unprojected points back to ref view space, 3, HW
# source view x, y, depth
depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32)
K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected)
xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3]
x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32)
y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32)
depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32) # source depth in ref view space, H, W
K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected) # source xyz in ref screen space, 3, HW
xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3] # homography
x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32) # source point in ref screen space, x: H, W
y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32) # source point in ref screen space, y: H, W

return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src

@@ -230,11 +368,11 @@ def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth
dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2)

# check |d_reproj-d_1| / d_1 < 0.01
depth_diff = np.abs(depth_reprojected - depth_ref)
relative_depth_diff = depth_diff / depth_ref
depth_diff = np.abs(depth_reprojected - depth_ref) # unprojected depth difference
relative_depth_diff = depth_diff / depth_ref # relative unprojected depth difference

mask = np.logical_and(dist < 0.5, relative_depth_diff < 0.01)
depth_reprojected[~mask] = 0
mask = np.logical_and(dist < 0.5, relative_depth_diff < 0.01) # smaller than 0.5 pix, relative smaller than 0.01
depth_reprojected[~mask] = 0 # zero out depths at geometrically inconsistent points

return mask, depth_reprojected, x2d_src, y2d_src

@@ -282,18 +420,17 @@ def filter_depth(dataset_root, scan, out_folder, plyfilename):
src_depth_est, scale = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))

geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics,
src_depth_est,
src_intrinsics, src_extrinsics)
src_depth_est, src_intrinsics, src_extrinsics)

geo_mask_sum += geo_mask.astype(np.int32)
all_srcview_depth_ests.append(depth_reprojected)
all_srcview_x.append(x2d_src)
all_srcview_y.append(y2d_src)
all_srcview_geomask.append(geo_mask)

depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1)
depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1) # average depth values
# at least 3 source views matched
geo_mask = geo_mask_sum >= 3
geo_mask = geo_mask_sum >= 3 # a pixel is considered valid when at least 3 source views match up
final_mask = np.logical_and(photo_mask, geo_mask)

os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True)
Expand All @@ -318,7 +455,7 @@ def filter_depth(dataset_root, scan, out_folder, plyfilename):
xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics),
np.vstack((x, y, np.ones_like(x))) * depth)
xyz_world = np.matmul(np.linalg.inv(ref_extrinsics),
np.vstack((xyz_ref, np.ones_like(x))))[:3]
np.vstack((xyz_ref, np.ones_like(x))))[:3] # convert valid depths to world space
vertexs.append(xyz_world.transpose((1, 0)))
vertex_colors.append((color).astype(np.uint8))

