Commit 23b786b

easyvolcap: adding recording for timer and visualizer & other qol updates

1 parent de20b6d commit 23b786b

File tree: 13 files changed, +138 -32 lines

configs/datasets/NHR/NHR.yaml (+2)

@@ -5,6 +5,8 @@ dataloader_cfg: # we see the term "dataloader" as one word?
     vhull_thresh: 0.95
     count_thresh: 6 # common views

+    use_aligned_cameras: True
+
     vhull_thresh_factor: 0.75
     vhull_count_factor: 1.0

configs/datasets/mobile_stage/mobile_stage.yaml (+2)

@@ -8,6 +8,8 @@ dataloader_cfg: # we see the term "dataloader" as one word?
     view_sample: [0, null, 1]
     frame_sample: [0, null, 1] # only train for a thousand frames

+    use_aligned_cameras: True
+
     vhull_thresh: 0.85 # 21 cameras?
     count_thresh: 6 # more visibility
     vhull_thresh_factor: 0.9 # FIXME: 313 need 1.5, 390, 394 requires 1.0

configs/datasets/my_zjumocap/my_zjumocap.yaml (+2)

@@ -5,6 +5,8 @@ dataloader_cfg: # we see the term "dataloader" as one word?
     view_sample: [0, null, 1]
     frame_sample: [0, 200, 1] # only train for a thousand frames

+    use_aligned_cameras: True
+
     # MARK: This is for now the best vhull extraction setting
     vhull_thresh: 0.95 # 18 cameras?
     count_thresh: 16 # common views
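
All three datasets opt into the new use_aligned_cameras flag here. Assuming the key sits under dataloader_cfg.dataset_cfg (the hunk headers only show dataloader_cfg) and the dotted-key override syntax used in runtime_as_ply.py below, it could presumably also be toggled per run without editing the yaml; this command line is hypothetical:

    # hypothetical override, mirroring the config syntax used elsewhere in this commit
    evc -t test -c configs/datasets/NHR/NHR.yaml val_dataloader_cfg.dataset_cfg.use_aligned_cameras=False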

easyvolcap/engine/__init__.py (+2 -1)

@@ -115,7 +115,8 @@ def parse_cfg(args):
             )
         ) # empty config
     else:
-        raise FileNotFoundError(f"Config file {blue(args.config)} not found")
+        raise FileNotFoundError(f"Config file {args.config} not found")
+        # raise FileNotFoundError(f"Config file {markup_to_ansi(blue(args.config))} not found")


 parser = get_parser()

easyvolcap/models/samplers/gaussiant_sampler.py (+7 -4)

@@ -21,11 +21,11 @@

 from easyvolcap.utils.console_utils import *
 from easyvolcap.utils.console_utils import dotdict
-from easyvolcap.utils.gaussian_utils import GaussianModel
-from easyvolcap.utils.data_utils import load_pts, export_pts, to_x, to_cuda, to_cpu, to_tensor, remove_batch
-from easyvolcap.utils.net_utils import normalize, typed, update_optimizer_state
-from easyvolcap.utils.chunk_utils import multi_gather, multi_scatter
 from easyvolcap.utils.bound_utils import get_bounds
+from easyvolcap.utils.chunk_utils import multi_gather, multi_scatter
+from easyvolcap.utils.gaussian_utils import GaussianModel, in_frustrum
+from easyvolcap.utils.net_utils import normalize, typed, update_optimizer_state
+from easyvolcap.utils.data_utils import load_pts, export_pts, to_x, to_cuda, to_cpu, to_tensor, remove_batch

 from easyvolcap.models.cameras.optimizable_camera import OptimizableCamera
 from easyvolcap.models.samplers.point_planes_sampler import PointPlanesSampler

@@ -128,6 +128,9 @@ def render_gaussians(self, xyz: torch.Tensor, sh: torch.Tensor, scale3: torch.Te
         # Prepare the camera transformation for Gaussian
         gaussian_camera = to_x(prepare_gaussian_camera(batch), torch.float)

+        # is_in_frustrum = in_frustrum(xyz, gaussian_camera.full_proj_transform)
+        # print('Number of points to render:', is_in_frustrum.sum().item())
+
         # Prepare rasterization settings for gaussian
         raster_settings = GaussianRasterizationSettings(
             image_height=gaussian_camera.image_height,

easyvolcap/runners/evaluators/volumetric_video_evaluator.py (+9 -8)

@@ -38,16 +38,17 @@ def evaluate(self, output: dotdict, batch: dotdict):
         for compute in self.compute_metrics:
             metrics[compute.__name__] = compute(img, img_gt) # actual computation of the metrics

-        self.metrics.append(metrics)
+        if len(metrics):
+            self.metrics.append(metrics)

-        # For recording
-        c = batch.meta.camera_index.item()
-        f = batch.meta.frame_index.item()
-        log(f'camera: {c}', f'frame: {f}', metrics)
-        metrics.camera = c
-        metrics.frame = f
-        scalar_stats = dotdict({f'{k}_frame{f:04d}_cam{c:04d}': v for k, v in metrics.items()})
+            # For recording
+            c = batch.meta.camera_index.item()
+            f = batch.meta.frame_index.item()
+            log(f'camera: {c}', f'frame: {f}', metrics)
+            metrics.camera = c
+            metrics.frame = f

+        scalar_stats = dotdict({f'{k}_frame{f:04d}_cam{c:04d}': v for k, v in metrics.items()})
         return scalar_stats

     def summarize(self):
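
Note that scalar_stats now sits outside the new guard yet stays safe when metrics is empty: f and c are only referenced inside the dict comprehension's f-string, and the comprehension body never executes for zero items. A minimal illustration of that behavior:

    metrics = {}  # stands in for an empty dotdict: no metrics were computed
    scalar_stats = {f'{k}_frame{f:04d}_cam{c:04d}': v for k, v in metrics.items()}
    print(scalar_stats)  # {} -- f and c are undefined but never evaluated, so no NameError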

easyvolcap/runners/visualizers/volumetric_video_visualizer.py (+8 -2)

@@ -35,12 +35,13 @@ def __init__(self,
                     Visualization.ALPHA.name,
                 ],

-                stream_delay: int = 5, # after this number of pending copy, start synchronizing the stream and saving to disk
-                pool_limit: int = 5, # maximum number of pending tasks in the thread pool, keep this small to avoid using too much resource
+                stream_delay: int = 2, # after this number of pending copy, start synchronizing the stream and saving to disk
+                pool_limit: int = 10, # maximum number of pending tasks in the thread pool, keep this small to avoid using too much resource
                 video_fps: int = 60,
                 verbose: bool = True,

                 dpt_curve: str = 'normalize', # looks good
+                dpt_mult: float = 1.0,
                 dpt_cm: str = 'virdis' if args.type != 'gui' else 'linear', # looks good
                 ):
        super().__init__()

@@ -71,6 +72,7 @@ def __init__(self,
         self.video_fps = video_fps
         self.verbose = verbose
         self.dpt_curve = dpt_curve
+        self.dpt_mult = dpt_mult
         self.dpt_cm = dpt_cm

         if self.verbose:

@@ -102,11 +104,15 @@ def norm_curve_fn(norm):
             img = output.dpt_map
         else:
             img = depth_curve_fn(output.dpt_map, cm=self.dpt_cm)
+            # img = (img - 0.5) * self.dpt_mult + 0.5
+            img = img * self.dpt_mult
         if self.store_ground_truth and 'dpt' in batch:
             if self.dpt_curve == 'linear':
                 img_gt = batch.dpt
             else:
                 img_gt = depth_curve_fn(batch.dpt, cm=self.dpt_cm)
+                # img_gt = (img_gt - 0.5) * self.dpt_mult + 0.5
+                img_gt = img_gt * self.dpt_mult

         elif type == Visualization.FEATURE:
             # This visualizes the xyzt + xyz feature output
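
The kept variant scales the depth visualization towards black or white, while the commented-out variant would stretch contrast about mid-gray. A small worked example of the difference (dpt_mult and the [0, 1] value range follow the diff; the sample pixels are made up):

    import torch

    dpt_mult = 2.0
    img = torch.tensor([0.25, 0.50, 0.75])     # hypothetical normalized depth pixels

    scaled = img * dpt_mult                    # variant kept in the diff
    recentered = (img - 0.5) * dpt_mult + 0.5  # variant left commented out

    print(scaled)      # tensor([0.5000, 1.0000, 1.5000]): brightens uniformly, can exceed 1
    print(recentered)  # tensor([0.0000, 0.5000, 1.0000]): stretches contrast, mid-gray fixed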

easyvolcap/runners/volumetric_video_runner.py (+17 -1)

@@ -83,7 +83,8 @@ def __init__(self,

                 # Debugging
                 collect_timing: bool = False, # will lose 1 fps over copying
-                timer_sync_cuda: bool = True,
+                timer_sync_cuda: bool = True, # will explicitly call torch.cuda.synchronize() before collecting
+                timer_record_to_file: bool = False, # will write to a json file for collected analysis of the timing
                 ):
        self.model = model # possibly already a ddp model?

@@ -148,6 +149,7 @@ def __init__(self,
         # Debugging
         self.collect_timing = collect_timing # another fancy self.timer (different from fps counter)
         self.timer_sync_cuda = timer_sync_cuda # this enables accurate time recording for each section, but would slow down the programs
+        self.timer_record_to_file = timer_record_to_file

     @property
     def collect_timing(self):

@@ -157,6 +159,10 @@ def collect_timing(self):
     def timer_sync_cuda(self):
         return timer.sync_cuda

+    @property
+    def timer_record_to_file(self):
+        return timer.record_to_file
+
     @collect_timing.setter
     def collect_timing(self, val: bool):
         timer.disabled = not val

@@ -165,6 +171,16 @@ def collect_timing(self, val: bool):
     def timer_sync_cuda(self, val: bool):
         timer.sync_cuda = val

+    @timer_record_to_file.setter
+    def timer_record_to_file(self, val: bool):
+        timer.record_to_file = val
+        if timer.record_to_file:
+            log(yellow(f'Will record timing results to {blue(join(self.recorder.record_dir, f"{self.exp_name}.json"))}'))
+            timer.exp_name = self.exp_name
+            timer.record_dir = self.recorder.record_dir
+            if not hasattr(timer, 'timing_record'):
+                timer.timing_record = dotdict()
+
     @property
     def total_iter(self):
         return self.epochs * self.ep_iter
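
These property pairs make the runner a thin proxy over the module-level timer singleton, so the GUI toggles in the viewer and the config flags mutate the same state. A minimal self-contained sketch of the pattern (FakeTimer and Runner are illustrative stand-ins, not easyvolcap's classes):

    class FakeTimer:                        # stands in for the global timer object
        record_to_file = False

    timer = FakeTimer()

    class Runner:
        @property
        def timer_record_to_file(self):     # reads go straight to the singleton
            return timer.record_to_file

        @timer_record_to_file.setter
        def timer_record_to_file(self, val: bool):
            timer.record_to_file = val      # writes update the singleton in place

    r = Runner()
    r.timer_record_to_file = True
    assert timer.record_to_file             # the runner holds no timing state of its own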

easyvolcap/runners/volumetric_video_viewer.py (+12 -2)

@@ -71,6 +71,7 @@ def __init__(self,
                 update_mem_time: float = 0.1, # be less stressful
                 use_quad_draw: bool = False, # different rendering solution
                 use_quad_cuda: bool = True,
+                use_vsync: bool = False,

                 # This is important for works like K-planes or IBR (or stableenerf), since it's not easy to perform interpolation (slow motion)
                 # For point clouds, only a fixed number of point clouds are produces since we performed discrete training (no interpolation)

@@ -99,6 +100,7 @@ def __init__(self,
         self.fullscreen = fullscreen
         self.window_size = window_size
         self.window_title = window_title
+        self.use_vsync = use_vsync
         self.use_window_focal = use_window_focal

         # Quad related configurations

@@ -901,7 +903,15 @@ def draw_banner_gui(self, batch: dotdict = dotdict(), output: dotdict = dotdict(
         imgui.pop_font()

         # Full frame timings
-        timer.disabled = not imgui_toggle.toggle('Collect timings', not timer.disabled, config=self.static.toggle_ios_style)[1]
+        self.runner.collect_timing = imgui_toggle.toggle('Collect timing', self.runner.collect_timing, config=self.static.toggle_ios_style)[1]
+        changed, value = imgui_toggle.toggle('Record timing', self.runner.timer_record_to_file, config=self.static.toggle_ios_style)
+        if changed:
+            self.runner.timer_record_to_file = value
+        self.runner.timer_sync_cuda = imgui_toggle.toggle('Sync timing', self.runner.timer_sync_cuda, config=self.static.toggle_ios_style)[1]
+        changed, self.use_vsync = imgui_toggle.toggle('Enable VSync', self.use_vsync, config=self.static.toggle_ios_style)
+        if changed:
+            glfw.swap_interval(self.use_vsync)
+
         if not timer.disabled:
             if imgui.collapsing_header('Timing'):
                 imgui.text(f'gui : {batch.gui_time * 1000:7.3f}ms')

@@ -1417,7 +1427,7 @@ def init_glfw(self):
         # Create a windowed mode window and its OpenGL context
         window = glfw.create_window(self.W, self.H, self.window_title, None, None)
         glfw.make_context_current(window)
-        glfw.swap_interval(0) # disable vsync
+        glfw.swap_interval(self.use_vsync) # disable vsync

         icon = load_image(self.icon_file)
         pixels = (icon * 255).astype(np.uint8)
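
For context on the swap_interval change: an interval of 1 synchronizes buffer swaps with the display refresh (VSync on) while 0 swaps immediately, and since Python bools are ints, passing use_vsync directly works. A minimal sketch (window size and title are arbitrary):

    import glfw

    use_vsync = False
    assert glfw.init()                 # must succeed before creating windows
    window = glfw.create_window(640, 480, 'vsync demo', None, None)
    glfw.make_context_current(window)  # swap_interval applies to the current context
    glfw.swap_interval(use_vsync)      # int(False) == 0: uncapped frame rate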

easyvolcap/utils/console_utils.py (+20 -1)

@@ -599,13 +599,25 @@ def wrapper(func: Callable):


 class Timer:
-    def __init__(self, name='', disabled: bool = False, sync_cuda: bool = True):
+    def __init__(self,
+                 name='base',
+                 exp_name='',
+                 record_dir: str = 'data/timing',
+                 disabled: bool = False,
+                 sync_cuda: bool = True,
+                 record_to_file: bool = False,
+                 ):
         self.sync_cuda = sync_cuda
         self.disabled = disabled
         self.name = name
+        self.exp_name = exp_name
         self.start_time = time.perf_counter() # manually record another start time incase timer is disabled during initialization
         self.start() # you can always restart multiple times to reuse this timer

+        self.record_to_file = record_to_file
+        if self.record_to_file:
+            self.timing_record = dotdict()
+
     def __enter__(self):
         self.start()

@@ -636,6 +648,13 @@ def record(self, event: str = ''):
         if self.disabled: return 0
         self.name = event
         diff = self.stop(print=bool(event), back=3)
+        if self.record_to_file and event:
+            if event not in self.timing_record:
+                self.timing_record[event] = []
+            self.timing_record[event].append(diff)
+
+            with open(join(self.record_dir, f'{self.exp_name}.json'), 'w') as f:
+                json.dump(self.timing_record, f, indent=4)
         self.start()
         return diff
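
A minimal usage sketch of the extended Timer, following the API in this diff; note the shown __init__ stores only exp_name, so record_dir is attached as a plain attribute here, just as the runner's timer_record_to_file setter does above:

    import os
    import time
    from easyvolcap.utils.console_utils import Timer

    os.makedirs('data/timing', exist_ok=True)            # record() writes its json here
    timer = Timer(exp_name='demo', record_to_file=True)  # __init__ already calls start()
    timer.record_dir = 'data/timing'

    time.sleep(0.01)          # stands in for a forward pass
    timer.record('forward')   # appends the elapsed time to timing_record['forward']
    time.sleep(0.01)          # stands in for a backward pass
    timer.record('backward')  # ...and rewrites data/timing/demo.json with every record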

easyvolcap/utils/data_utils.py (+3 -2)

@@ -231,7 +231,7 @@ def video_to_numpy(input_filename):
         'ffmpeg',
         '-hwaccel', 'cuda',
         '-v', 'quiet', '-stats',
-        # '-vcodec', 'hevc_cuvid',
+        '-vcodec', 'hevc_cuvid',
         '-i', input_filename,
         '-f', 'image2pipe',
         '-pix_fmt', 'rgb24',

@@ -244,7 +244,8 @@

     # Convert the raw data to numpy array and reshape
     video_np = np.frombuffer(raw_data, dtype=np.uint8)
-    video_np = video_np.reshape(-1, H, W, 3)
+    H2, W2 = (H + 1) // 2 * 2, (W + 1) // 2 * 2
+    video_np = video_np.reshape(-1, H2, W2, 3)[:, :H, :W, :]
     return video_np
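
The reshape fix presumably accounts for hardware HEVC decoding emitting frames padded to even dimensions (4:2:0 chroma subsampling works on 2x2 pixel blocks), so for an odd-sized video the raw rgb24 stream carries H2 x W2 pixels per frame and a direct (H, W) reshape would fail or shear the image. A minimal sketch with made-up sizes:

    import numpy as np

    H, W = 719, 1279                                  # hypothetical odd-sized video
    H2, W2 = (H + 1) // 2 * 2, (W + 1) // 2 * 2       # 720, 1280: rounded up to even
    raw = np.zeros(2 * H2 * W2 * 3, dtype=np.uint8)   # stands in for the piped frames
    video = raw.reshape(-1, H2, W2, 3)[:, :H, :W, :]  # reshape to padded size, crop back
    assert video.shape == (2, H, W, 3)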

easyvolcap/utils/gaussian_utils.py (+39 -3)

@@ -8,9 +8,41 @@
 from easyvolcap.utils.console_utils import *
 from easyvolcap.utils.sh_utils import eval_sh
 from easyvolcap.utils.blend_utils import batch_rodrigues
-from easyvolcap.utils.math_utils import torch_inverse_2x2
 from easyvolcap.utils.data_utils import to_x, add_batch, load_pts
 from easyvolcap.utils.net_utils import make_buffer, make_params, typed
+from easyvolcap.utils.math_utils import torch_inverse_2x2, point_padding
+
+
+# def in_frustrum(xyz: torch.Tensor, ixt: torch.Tensor, ext: torch.Tensor):
+def in_frustrum(xyz: torch.Tensor, full_proj_matrix: torch.Tensor, padding: float = 0.01):
+    # __forceinline__ __device__ bool in_frustum(int idx,
+    #     const float* orig_points,
+    #     const float* viewmatrix,
+    #     const float* projmatrix,
+    #     bool prefiltered,
+    #     float3& p_view,
+    #     const float padding = 0.01f // padding in ndc space
+    # )
+    # {
+    #     float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
+
+    #     // Bring points to screen space
+    #     float4 p_hom = transformPoint4x4(p_orig, projmatrix);
+    #     float p_w = 1.0f / (p_hom.w + 0.0000001f);
+    #     float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
+    #     p_view = transformPoint4x3(p_orig, viewmatrix); // write this outside
+
+    #     // if (idx % 32768 == 0) printf("Viewspace point: %f, %f, %f\n", p_view.x, p_view.y, p_view.z);
+    #     // if (idx % 32768 == 0) printf("Projected point: %f, %f, %f\n", p_proj.x, p_proj.y, p_proj.z);
+    #     return (p_proj.z > -1 - padding) && (p_proj.z < 1 + padding) && (p_proj.x > -1 - padding) && (p_proj.x < 1. + padding) && (p_proj.y > -1 - padding) && (p_proj.y < 1. + padding);
+    # }
+
+    # xyz: N, 3
+    # ndc = (xyz @ R.mT + T)[..., :3] @ K # N, 3
+    # ndc[..., :2] = ndc[..., :2] / ndc[..., 2:] / torch.as_tensor([W, H], device=ndc.device) # N, 2, normalized x and y
+    ndc = point_padding(xyz) @ full_proj_matrix
+    ndc = ndc[..., :3] / ndc[..., 3:]
+    return (ndc[..., 2] > -1 - padding) & (ndc[..., 2] < 1 + padding) & (ndc[..., 0] > -1 - padding) & (ndc[..., 0] < 1. + padding) & (ndc[..., 1] > -1 - padding) & (ndc[..., 1] < 1. + padding)  # N,


 @torch.jit.script

@@ -199,7 +231,8 @@ def prepare_gaussian_camera(batch):
 def convert_to_gaussian_camera(K: torch.Tensor,
                                R: torch.Tensor,
                                T: torch.Tensor,
-                               H: int, W: int,
+                               H: int,
+                               W: int,
                                znear: float = 0.01,
                                zfar: float = 100.
                                ):

@@ -220,7 +253,7 @@

     output.world_view_transform = getWorld2View(output.R, output.T).transpose(0, 1)
     output.projection_matrix = getProjectionMatrix(output.K, output.image_height, output.image_width, znear, zfar).transpose(0, 1)
-    output.full_proj_transform = torch.matmul(output.world_view_transform, output.projection_matrix)
+    output.full_proj_transform = torch.matmul(output.world_view_transform, output.projection_matrix) # 4, 4
     output.camera_center = output.world_view_transform.inverse()[3:, :3]

     # Set up rasterization configuration

@@ -686,6 +719,9 @@ def render(self, batch: dotdict):
         # Prepare the camera transformation for Gaussian
         gaussian_camera = to_x(prepare_gaussian_camera(batch), torch.float)

+        # is_in_frustrum = in_frustrum(xyz, gaussian_camera.full_proj_transform)
+        # print('Number of points to render:', is_in_frustrum.sum().item())
+
         # Prepare rasterization settings for gaussian
         raster_settings = GaussianRasterizationSettings(
             image_height=gaussian_camera.image_height,
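
A self-contained sketch of the in_frustrum test added above, with point_padding (assumed to append a homogeneous 1 to each point) inlined via torch.cat; the row-vector convention xyzw @ M matches full_proj_transform, which is built from matrices transposed with .transpose(0, 1). Checking all three ndc components against +-(1 + padding) is equivalent to the six-way conjunction in the diff:

    import torch

    def in_frustrum_sketch(xyz: torch.Tensor,               # N, 3 world-space points
                           full_proj_matrix: torch.Tensor,  # 4, 4 row-vector projection
                           padding: float = 0.01):          # slack in ndc space
        xyzw = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1)  # N, 4 homogeneous
        ndc = xyzw @ full_proj_matrix                       # N, 4 clip-space coordinates
        ndc = ndc[..., :3] / ndc[..., 3:]                   # N, 3 perspective division
        return ((ndc > -1 - padding) & (ndc < 1 + padding)).all(dim=-1)  # N, bool mask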

scripts/tools/runtime_as_ply.py (+15 -8)

@@ -17,18 +17,17 @@
 def main():
     # fmt: off
     import sys
-    sys.path.append('.')

-    sep_ind = sys.argv.index('--')
+    sep_ind = sys.argv.index('--') if '--' in sys.argv else len(sys.argv)
     our_args = sys.argv[1:sep_ind]
     evv_args = sys.argv[sep_ind + 1:]
-    sys.argv = [sys.argv[0]] + ['-t','test'] + evv_args + ['val_dataloader_cfg.dataset_cfg.type=VolumetricVideoDataset'] # use default dataset
+    sys.argv = [sys.argv[0]] + ['-t','test'] + evv_args

-    parser = argparse.ArgumentParser()
-    parser.add_argument('--result_dir', type=str, default='data/geometry')
-    parser.add_argument('--frame_index', type=int, default=0)
-    parser.add_argument('--skip_align', action='store_true')
-    args = parser.parse_args(our_args)
+    args = dotdict()
+    args.result_dir = 'data/geometry'
+    args.frame_index = 0
+    args.skip_align = False
+    args = dotdict(vars(build_parser(args).parse_args(our_args)))

     sys.argv += [f'val_dataloader_cfg.dataset_cfg.frame_sample={args.frame_index},{args.frame_index+1},1']

@@ -47,11 +46,19 @@ def main():
     special_mapping = {
         f'sampler.pcds.{args.frame_index}': 'pts',
         f'sampler.rgbs.{args.frame_index}': 'color',
+        f'sampler.bg_sampler.pcds.{args.frame_index}': 'pts',
+        f'sampler.bg_sampler.rgbs.{args.frame_index}': 'color',
+        f'sampler.fg_sampler.pcds.{args.frame_index}': 'pts',
+        f'sampler.fg_sampler.rgbs.{args.frame_index}': 'color',
     }

     named_mapping = {
         f'sampler.rads.{args.frame_index}': 'radius',
         f'sampler.occs.{args.frame_index}': 'alpha',
+        f'sampler.bg_sampler.rads.{args.frame_index}': 'radius',
+        f'sampler.bg_sampler.occs.{args.frame_index}': 'alpha',
+        f'sampler.fg_sampler.rads.{args.frame_index}': 'radius',
+        f'sampler.fg_sampler.occs.{args.frame_index}': 'alpha',
     }

     # Save the model's registered parameters as numpy arrays in npz
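
The build_parser helper used here presumably turns the dotdict of defaults into an argparse parser; a rough sketch of what such a helper might do (an assumption about its behavior, not easyvolcap's actual implementation):

    import argparse

    def build_parser_sketch(defaults: dict) -> argparse.ArgumentParser:
        parser = argparse.ArgumentParser()
        for key, value in defaults.items():
            if isinstance(value, bool):  # booleans become store_true-style toggles
                parser.add_argument(f'--{key}', action='store_true', default=value)
            else:                        # everything else keeps its type and default
                parser.add_argument(f'--{key}', type=type(value), default=value)
        return parser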
