Skip to content

Commit

Permalink
easyvolcap: better colorization
Browse files Browse the repository at this point in the history
  • Loading branch information
dendenxu committed Apr 5, 2024
1 parent c6445ec commit 0ff0732
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions scripts/points/ibr_colorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from easyvolcap.utils.easy_utils import read_camera, write_camera
from easyvolcap.utils.chunk_utils import multi_gather, multi_scatter
from easyvolcap.utils.data_utils import load_pts, export_pts, to_cuda, load_image
from easyvolcap.utils.image_utils import fill_nhwc_image
from easyvolcap.utils.ibr_utils import sample_geometry_feature_image


Expand All @@ -33,6 +34,7 @@ def main():
input='surfs6k',
output='surfs6k',
n_srcs=3,
ratio=0.25, # smaller ratio for blurrier points
sequential_image_loading=False,

# TODO: Use IBR models for this, for now, they require voxel feature input thus cannot be easily separated
Expand All @@ -49,9 +51,10 @@ def main():
camera_names = sorted(cameras)
cameras = dotdict({cam: cameras[cam] for cam in camera_names})
batches = dotdict({cam: to_cuda(Camera().from_easymocap(cameras[cam]).to_batch()) for cam in camera_names})
Rs = torch.stack([batches[cam].R for cam in batches]) # V,
Ts = torch.stack([batches[cam].T for cam in batches]) # V,
Ks = torch.stack([batches[cam].K for cam in batches]) # V,
Rs = torch.stack([batches[cam].R for cam in batches]) # V, 3, 3
Ts = torch.stack([batches[cam].T for cam in batches]) # V, 3, 1
Ks = torch.stack([batches[cam].K for cam in batches]) # V, 3, 3
Ks[..., :2, :] *= args.ratio
src_ixts = Ks # V, 3, 3
src_exts = torch.cat([Rs, Ts], dim=-1) # V, 3, 4

Expand All @@ -63,21 +66,29 @@ def main():
xyz = load_pts(join(args.data_root, args.input, file))[0]
xyz = to_cuda(xyz) # N, 3

# Load all images for this frame
log(f'Loading images for frame {idx}')
imgs = parallel_execution([join(args.data_root, args.images_dir, cam, images[idx]) for cam in camera_names], ratio=args.ratio, action=load_image, sequential=args.sequential_image_loading)
imgs = to_cuda(imgs) # V, Hx, Wx, 3
Hs = torch.as_tensor([img.shape[-3] for img in imgs], device=xyz.device) # V,
Ws = torch.as_tensor([img.shape[-2] for img in imgs], device=xyz.device) # V,
mh, mw = max(Hs), max(Ws)
imgs = [fill_nhwc_image(img, [mh, mw]) for img in imgs]
imgs = torch.stack(imgs).permute(0, 3, 1, 2) # V, H, W, 3 -> V, 3, H, W

# Perform the actual depth ranking depth.argsort() should have the same value as ranking.argsort()
log(f'Depth ranking for frame {idx}')
depth = (xyz[None] @ Rs.mT + Ts.mT)[..., -1] # V, N
argsort = depth.argsort(dim=-1) # V, N
depth = (xyz[None] @ Rs.mT + Ts.mT)[..., -1] # V, N
scr = (xyz[None] @ Rs.mT + Ts.mT) @ Ks.mT # V, N, 3
scr = scr[..., :2] / scr[..., 2:] # V, N, 2
depth = torch.where((scr[..., 0] < 0) | (scr[..., 0] > Ws[..., None]) | (scr[..., 1] < 0) | (scr[..., 1] > Hs[..., None]), torch.inf, depth)
argsort = depth.argsort(dim=-1) # V, N
rankings = torch.empty_like(argsort).scatter_(dim=-1, index=argsort, src=torch.arange(argsort.shape[-1], device=xyz.device)[None].repeat(len(depth), 1))

# Find the best k source views
src_inds = rankings.topk(args.n_srcs, dim=0, largest=False).indices # S, N
src_inds = src_inds.mT # N, S

# Load all images for this frame
log(f'Loading images for frame {idx}')
imgs = parallel_execution([join(args.data_root, args.images_dir, cam, images[idx]) for cam in camera_names], action=load_image, sequential=args.sequential_image_loading)
imgs = torch.stack(to_cuda(imgs)).permute(0, 3, 1, 2) # V, H, W, 3 -> V, 3, H, W

# Sample RGB color from them
log(f'Sampling rgb for frame {idx}')
rgbs = sample_geometry_feature_image(xyz[None], imgs[None], src_exts[None], src_ixts[None], torch.ones(2, 1, device=xyz.device))[0] # V, N, 3
Expand Down

0 comments on commit 0ff0732

Please sign in to comment.