diff --git a/tests/theseus_tests/utils/test_utils.py b/tests/theseus_tests/utils/test_utils.py index 34b50ca03..2b0b2fcdf 100644 --- a/tests/theseus_tests/utils/test_utils.py +++ b/tests/theseus_tests/utils/test_utils.py @@ -168,3 +168,46 @@ def test_jacobians_check(): cf = th.eb.DoubleIntegrator(se3s[0], th.Vector(6), se3s[1], th.Vector(6), 1.0, w) thutils.check_jacobians(cf, 1) + + +def test_timer(): + # Check different ways of instantiating work correctly + with thutils.Timer("cpu") as timer: + torch.randn(1) + assert timer.elapsed_time > 0 + + with thutils.Timer("cpu", active=False) as timer: + torch.randn(1) + assert timer.elapsed_time == 0 + + timer = thutils.Timer("cpu") + with timer: + torch.randn(1) + assert timer.elapsed_time > 0 + + timer = thutils.Timer("cpu", active=False) + with timer: + torch.randn(1) + assert timer.elapsed_time == 0 + + # Checking that deactivate keyword works correctly + timer = thutils.Timer("cpu") + with timer("randn", deactivate=True): + torch.randn(1) + assert timer.elapsed_time == 0 + timer.start("randn", deactivate=True) + torch.randn(1) + timer.end() + assert timer.elapsed_time == 0 + # Checking that stats accumulation works correctly + with timer("randn"): + torch.randn(1) + timer.start("randn") + torch.randn(1) + timer.end() + with timer("mult"): + torch.ones(1) * torch.zeros(1) + stats = timer.stats() + assert "randn" in stats and "mult" in stats + assert len(stats["randn"]) == 2 + assert len(stats["mult"]) == 1 diff --git a/theseus/utils/utils.py b/theseus/utils/utils.py index 5537612fb..af74c496d 100644 --- a/theseus/utils/utils.py +++ b/theseus/utils/utils.py @@ -6,7 +6,8 @@ import io import pstats import time -from typing import Any, Callable, List, Optional, Type +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Type import numpy as np import torch @@ -188,8 +189,8 @@ def autograd_fn(*optim_var_tensors): ) -# A basic timer utility that adapts to the device. Useful for removing -# boilerplate code when benchmarking tasks. +# A timer utility that adapts to the device. Useful for removing +# boilerplate code when benchmarking tasks, and collect statistics. # For CPU it uses time.perf_counter_ns() # For GPU it uses torch.cuda.Event() # @@ -200,27 +201,78 @@ def autograd_fn(*optim_var_tensors): # with Timer("cuda:0") as timer: # do_some_stuff() # print(timer.elapsed_time) +## +# timer = Timer("cuda:0") +# timer.start() +# do_some_stuff() +# timer.end() +# print(timer.elapsed_time) +# +# The timer can also optionally collect history of times in a dictionary +# by adding caller ids to each context or start() calls. +# For example, +# +# timer.start("f1") +# f1() +# timer.end() +# with timer("f2"): +# f2() +# print(timer.stats()) +# +# The timer can also be set inactive either via constructor +# (i.e., Timer(device, active=False)) +# or via start method (i.e., timer.start(call_id, deactivate=True)). class Timer: - def __init__(self, device: th.DeviceType) -> None: + def __init__(self, device: th.DeviceType, active: bool = True) -> None: + self.active = active self.device = torch.device(device) self.elapsed_time = 0.0 - - def __enter__(self) -> "Timer": + self._stats: Dict[str, List[float]] = defaultdict(list) + self._caller: Optional[str] = None + self._tmp_deactivated = False + + def start(self, caller: Optional[str] = None, deactivate: bool = False) -> "Timer": + if not self.active or deactivate: + self._tmp_deactivated = deactivate + return self if self.device.type == "cuda": self._start_event = torch.cuda.Event(enable_timing=True) self._end_event = torch.cuda.Event(enable_timing=True) self._start_event.record() else: self._start_time = time.perf_counter_ns() + self._caller = caller return self - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + def end(self) -> None: + if not self.active or self._tmp_deactivated: + self._tmp_deactivated = False + return if self.device.type == "cuda": self._end_event.record() torch.cuda.synchronize() self.elapsed_time = self._start_event.elapsed_time(self._end_event) / 1e3 else: self.elapsed_time = (time.perf_counter_ns() - self._start_time) / 1e9 + if self._caller is not None: + self._stats[self._caller].append(self.elapsed_time) + + def __call__( + self, caller: Optional[str] = None, deactivate: bool = False + ) -> "Timer": + self._caller = caller + self._tmp_deactivated = deactivate + return self + + def __enter__(self) -> "Timer": + self.start(caller=self._caller, deactivate=self._tmp_deactivated) + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.end() + + def stats(self) -> Dict[str, List[float]]: + return self._stats # Wrapper for cProfile.Profile for easily make optional, turn on/off and printing @@ -244,3 +296,11 @@ def print(self): ps = pstats.Stats(self.c_profiler, stream=s).sort_stats(sortby) ps.print_stats() print(s.getvalue()) + + def create_stats(self): + if self.active: + self.c_profiler.create_stats() + + def dump_stats(self, filename: str): + if self.active: + self.c_profiler.dump_stats(filename)