Skip to content

Commit

Permalink
easyvolcap: fix depth fusion's typo bug
Browse files Browse the repository at this point in the history
  • Loading branch information
dendenxu committed Feb 28, 2024
1 parent 6752dc8 commit 3bfad2e
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 68 deletions.
73 changes: 16 additions & 57 deletions easyvolcap/dataloaders/datasets/volumetric_video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ def load_bytes(self):

# Precrop image to bytes
if self.immask_crop: # a little bit wasteful but acceptable for now
# TODO: Assert that mask crop is always more aggressive than bounds crop (intersection of interested area)
self.orig_hs, self.orig_ws = self.Hs, self.Ws
bounds = [self.get_bounds(i) for i in range(self.n_latents)] # N, 2, 3
bounds = torch.stack(bounds)[None].repeat(self.n_views, 1, 1, 1) # V, N, 2, 3
Expand Down Expand Up @@ -1080,17 +1081,9 @@ def get_ground_truth(self, index):
output = self.get_metadata(index)
rgb, msk, wet, dpt, bkg = self.get_image(output.view_index, output.latent_index) # H, W, 3
H, W = rgb.shape[:2]
output.rgb = rgb.view(-1, 3) # full image in case you need it
output.msk = msk.view(-1, 1) # full mask (weights)
output.wet = wet.view(-1, 1) # full mask (weights)
if dpt is not None: output.dpt = dpt.view(-1, 1) # full depth image
if bkg is not None: output.bkg = bkg.view(-1, 3) # full background image

# Maybe crop images
if self.imbound_crop: # crop_x has already been set by imbound_crop for ixts
output = self.crop_imgs_bounds(output) # only crop target imgs
H, W = output.H.item(), output.W.item()
elif self.immask_crop: # these variables are only available when loading gts
if self.immask_crop: # these variables are only available when loading gts
meta = dotdict()
meta.crop_x = self.crop_xs[output.view_index, output.latent_index]
meta.crop_y = self.crop_ys[output.view_index, output.latent_index]
Expand All @@ -1099,6 +1092,15 @@ def get_ground_truth(self, index):
output.update(meta)
output.meta.update(meta)

if self.imbound_crop and not self.immask_crop: # crop_x has already been set by imbound_crop for ixts
x, y, w, h = output.crop_x, output.crop_y, output.W, output.H
rgb = rgb[y:y + h, x:x + w]
msk = msk[y:y + h, x:x + w]
wet = wet[y:y + h, x:x + w]
if dpt is not None: dpt = dpt[y:y + h, x:x + w]
if bkg is not None: bkg = bkg[y:y + h, x:x + w]
H, W = h, w

# FIXME: Should add mutex to protect this, for now, multi-process and dataloading don't work well with each other
# If Moderators are used, should set num_workers to 0 for single-process data loading
n_rays = self.n_rays
Expand All @@ -1112,12 +1114,6 @@ def get_ground_truth(self, index):
render_ratio[output.view_index] != 1.0) or \
render_ratio != 1.0:
render_ratio = self.render_ratio[output.view_index] if len(self.render_ratio.shape) else self.render_ratio
H, W = output.H.item(), output.W.item()
rgb = output.rgb.view(H, W, 3)
msk = output.msk.view(H, W, 1)
wet = output.wet.view(H, W, 1)
if dpt is not None: dpt = output.dpt.view(H, W, 1)
if bkg is not None: bkg = output.bkg.view(H, W, 3)

output = self.scale_ixts(output, render_ratio)
H, W = output.H.item(), output.W.item()
Expand All @@ -1128,23 +1124,11 @@ def get_ground_truth(self, index):
if dpt is not None: as_torch_func(partial(cv2.resize, dsize=(W, H), interpolation=cv2.INTER_AREA))(dpt)
if bkg is not None: as_torch_func(partial(cv2.resize, dsize=(W, H), interpolation=cv2.INTER_AREA))(bkg)

output.rgb = rgb.reshape(-1, 3) # full image in case you need it
output.msk = msk.reshape(-1, 1) # full mask (weights)
output.wet = wet.reshape(-1, 1) # full mask (weights)
if dpt is not None: output.dpt = dpt.reshape(-1, 1)
if bkg is not None: output.bkg = bkg.reshape(-1, 1)

# Prepare for a different rendering center crop ratio
if (len(render_center_crop_ratio.shape) and # avoid length of 0-d tensor error, check length of shape
render_center_crop_ratio[output.view_index] != 1.0) or \
render_center_crop_ratio != 1.0:
render_center_crop_ratio = self.render_center_crop_ratio[output.view_index] if len(self.render_center_crop_ratio.shape) else self.render_center_crop_ratio
H, W = output.H.item(), output.W.item()
rgb = output.rgb.view(H, W, 3)
msk = output.msk.view(H, W, 1)
wet = output.wet.view(H, W, 1)
if dpt is not None: dpt = output.dpt.view(H, W, 1)
if bkg is not None: bkg = output.bkg.view(H, W, 3)

w, h = int(W * render_center_crop_ratio), int(H * render_center_crop_ratio)
x, y = w // 2, h // 2
Expand All @@ -1156,12 +1140,6 @@ def get_ground_truth(self, index):
if dpt is not None: dpt[y: y + h, x: x + w, :]
if bkg is not None: bkg[y: y + h, x: x + w, :]

output.rgb = rgb.reshape(-1, 3) # full image in case you need it
output.msk = msk.reshape(-1, 1) # full mask
output.wet = wet.reshape(-1, 1) # full weights
if dpt is not None: output.dpt = dpt.reshape(-1, 1)
if bkg is not None: output.bkg = bkg.reshape(-1, 1)

# Crop the intrinsics
self.crop_ixts(output, x, y, w, h)

Expand All @@ -1186,12 +1164,6 @@ def get_ground_truth(self, index):

if should_sample_patch:
assert n_rays == -1, 'When performing patch sampling, do not resample rays on it'
# Prepare images for patch sampling
rgb = output.rgb.view(H, W, 3)
msk = output.msk.view(H, W, 1)
wet = output.wet.view(H, W, 1)
if dpt is not None: dpt = output.dpt.view(H, W, 1)
if bkg is not None: bkg = output.bkg.view(H, W, 3)

# Find the Xp Yp Wp Hp to be used for random patch sampling
# x = 0 if W - Wp <= 0 else np.random.randint(0, W - Wp + 1)
Expand All @@ -1215,11 +1187,11 @@ def get_ground_truth(self, index):
if dpt is not None: dpt = dpt[y: y + h, x: x + w, :]
if bkg is not None: bkg = bkg[y: y + h, x: x + w, :]

output.rgb = rgb.reshape(-1, 3) # full image in case you need it
output.msk = msk.reshape(-1, 1) # full mask
output.wet = wet.reshape(-1, 1) # full weights
if dpt is not None: output.dpt = dpt.reshape(-1, 1)
if bkg is not None: output.bkg = bkg.reshape(-1, 1)
output.rgb = rgb.reshape(-1, 3) # full image in case you need it
output.msk = msk.reshape(-1, 1) # full mask
output.wet = wet.reshape(-1, 1) # full weights
if dpt is not None: output.dpt = dpt.reshape(-1, 1)
if bkg is not None: output.bkg = bkg.reshape(-1, 1)

if should_crop_ixt:
# Prepare the resized ixts
Expand Down Expand Up @@ -1305,19 +1277,6 @@ def crop_ixts_bounds(output: dotdict):
x, y, w, h = get_bound_2d_bound(output.bounds, output.K, output.R, output.T, output.meta.H, output.meta.W)
return VolumetricVideoDataset.crop_ixts(output, x, y, w, h)

@staticmethod
def crop_imgs_bounds(output: dotdict):
"""
Crops target images using a xywh computed from a bounds and stored in the metadata
"""
x, y, w, h = output.crop_x, output.crop_y, output.W, output.H
H, W = output.orig_h, output.orig_w

output.rgb = output.rgb.view(H, W, -1)[y:y + h, x:x + w].reshape(-1, 3)
output.msk = output.msk.view(H, W, -1)[y:y + h, x:x + w].reshape(-1, 1)
output.wet = output.wet.view(H, W, -1)[y:y + h, x:x + w].reshape(-1, 1)
return output

def get_viewer_batch(self, output: dotdict):
# Source indices
t = output.t
Expand Down
4 changes: 2 additions & 2 deletions scripts/mvsnet/mvsnet_to_easyvolcap_depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
@catch_throw
def main():
args = dotdict()
args.mvsnet_dir = '../cvp-mvsnet/outputs_pretrained/0013_01/depth_est'
args.volcap_dir = 'data/renbody/0013_01/depths'
args.mvsnet_dir = '../cvp-mvsnet/outputs_pretrained/0008_01/depth_est'
args.volcap_dir = 'data/renbody/0008_01/depths'
args.image = '000000.jpg'
args.convert = True
args.convert_ext = '.exr'
Expand Down
14 changes: 8 additions & 6 deletions scripts/tools/depth_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@catch_throw
def main():
args = dotdict()
args.data_root = 'data/renbody/0013_01'
args.data_root = 'data/renbody/0008_01'
args.depth_dir = 'depths'
args.depth = '000000.exr' # camera + postfix = args.depth file name
args.images_dir = 'images_calib'
Expand All @@ -40,8 +40,8 @@ def main():
names = sorted(os.listdir(join(args.data_root, args.depth_dir)))
cameras = dotdict({k: cameras[k] for k in names})

c2ws = torch.stack([torch.cat([cameras[k].R, cameras[k].T], dim=-1) for k in cameras]) # V, 4, 4
w2cs = affine_inverse(c2ws)
w2cs = torch.stack([torch.cat([cameras[k].R, cameras[k].T], dim=-1) for k in cameras]) # V, 4, 4
c2ws = affine_inverse(w2cs)
_, src_inds = compute_camera_similarity(c2ws, c2ws) # V, V

dpts = []
Expand All @@ -51,7 +51,7 @@ def main():
rgbs = []
Ks = []

for cam in tqdm(cameras, desc='Loading depths & images'):
for cam in tqdm(cameras, desc='Loading depths & images & rays'):
depth_file = join(args.data_root, args.depth_dir, cam, args.depth)
image_file = join(args.data_root, args.images_dir, cam, args.image)
rgb = to_cuda(to_tensor(load_image(image_file)).float()) # H, W, 3
Expand All @@ -64,8 +64,10 @@ def main():
K[0:1] *= int(W * args.ratio) / W
K[1:2] *= int(H * args.ratio) / H
H, W = int(H * args.ratio), int(W * args.ratio)
rgb = resize_image(rgb, size=(H, W))
dpt = resize_image(dpt, size=(H, W))
if rgb.shape[0] != H or rgb.shape[1] != W:
rgb = resize_image(rgb, size=(H, W))
if dpt.shape[0] != H or dpt.shape[1] != W:
dpt = resize_image(dpt, size=(H, W))

ray_o, ray_d = get_rays(H, W, K, R, T, z_depth=True)
dpts.append(dpt)
Expand Down
48 changes: 48 additions & 0 deletions scripts/tools/prediction_to_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
This file will convert network prediction format to the actual dataset format
Network prediction format:
data/<result_dir>/<exp_name>/<save_tag>/<type>:
frame<frame>_camera<camera><ext>
frame<frame>_camera<camera><ext>
Dataset format:
data/<dataset>/<sequence>/<data>:
<camera>/<frame><ext>
<camera>/<frame><ext>
"""

from glob import glob
from easyvolcap.utils.console_utils import *


@catch_throw
def main():
    """
    Copy network predictions into the dataset directory layout.

    Every ``*<ext>`` file found under
    ``data/<result_dir>/<exp_name>/<save_tag>/<type>`` is expected to be named
    ``frame<frame>_camera<camera><ext>`` and is copied to
    ``data/<dataset>/<sequence>/<data>/<camera>/<frame><ext>``, with the camera
    zero-padded to 2 digits and the frame zero-padded to 6 digits.

    The defaults below are handed to ``build_parser``; presumably each one can
    be overridden from the command line (confirm against ``build_parser``).

    Raises:
        ValueError: if a matched file name does not contain
            ``frame<digits>_camera<digits>``.
    """
    import re  # local import: only this entry point needs it

    args = dotdict(
        result_dir='result',
        exp_name='gsmap_dtu_dpt_rgb_exr',
        save_tag='0008_01_obj',
        type='DEPTH',
        ext='.exr',
        dataset='renbody',
        sequence='0008_01',
        data='depths',
    )

    # ``__doc__`` here resolves to the module docstring (global lookup), not this function's
    args = dotdict(vars(build_parser(args, description=__doc__).parse_args()))
    pred_dir = join('data', args.result_dir, args.exp_name, args.save_tag, args.type)
    data_dir = join('data', args.dataset, args.sequence, args.data)

    # Accept any number of digits instead of assuming exactly four
    name_pattern = re.compile(r'frame(\d+)_camera(\d+)')

    files = glob(join(pred_dir, f'*{args.ext}'))
    for f in tqdm(files):
        # Only inspect the file name: directory components (e.g. an experiment
        # name containing "camera" or "frame") must not influence the parsing
        match = name_pattern.search(os.path.basename(f))
        if match is None:
            raise ValueError(f'Unexpected prediction file name (want frame<frame>_camera<camera>{args.ext}): {f}')
        frame, camera = int(match.group(1)), int(match.group(2))

        dst = join(data_dir, f'{camera:02d}', f'{frame:06d}' + args.ext)
        os.makedirs(dirname(dst), exist_ok=True)
        shutil.copy(f, dst)


if __name__ == '__main__':
main()
3 changes: 0 additions & 3 deletions scripts/tools/volume_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('--result_dir', type=str, default='data/geometry')
parser.add_argument('--n_srcs', type=int, default=4, help='Number of source views to use for the fusion process')
# parser.add_argument('--view_sample', nargs=3, default=[0, None, 1], type=int)
# parser.add_argument('--frame_sample', nargs=3, default=[0, None, 1], type=int)
parser.add_argument('--msk_abs_thresh', type=float, default=0.5, help='If mask exists, filter points with too low a mask value')
parser.add_argument('--geo_abs_thresh', type=float, default=1.0, help='The threshold for MSE in reprojection, unit: squared pixels') # aiming for a denser reconstruction
parser.add_argument('--geo_rel_thresh', type=float, default=0.01, help='The difference in relative depth values, unit: one')
Expand All @@ -63,7 +61,6 @@ def fuse(runner: "VolumetricVideoRunner", args: argparse.Namespace):

dataset = runner.val_dataloader.dataset
inds = get_inds(dataset)
# inds = get_inds(dataset, view_sample=args.view_sample, frame_sample=args.frame_sample)
nv, nl = inds.shape[:2]
prefix = 'frame'

Expand Down

0 comments on commit 3bfad2e

Please sign in to comment.