Improvements to Timer and Profiler utils (#628)
* Timer utility supports labels and collecting timing histories.

* Added create_stats and dump_stats methods to Profiler.

* Fixed a bug in Timer and added a unit test for it.

* Added a deactivate keyword to the Timer context manager and start() method.
luisenp authored Nov 30, 2023
1 parent 322807e commit 8b10e97
Showing 2 changed files with 110 additions and 7 deletions.
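The commit message above describes the new Timer behavior; as a quick orientation before the diff, here is a minimal usage sketch. It assumes Timer is importable from theseus.utils (the tests below access it through a thutils alias), and the labels "forward" and "backward" are purely illustrative:

import torch
from theseus.utils import Timer  # assumed import path

timer = Timer("cpu")  # or Timer("cuda:0"); pass active=False to disable all timing
with timer("forward"):  # label this measurement "forward"
    torch.randn(100, 100) @ torch.randn(100, 100)
timer.start("backward")  # equivalent start()/end() style, labeled "backward"
torch.randn(100, 100).sum()
timer.end()
with timer("forward", deactivate=True):  # skipped: nothing is timed or recorded
    torch.randn(1)
print(timer.elapsed_time)  # seconds for the last completed measurement
print(timer.stats())  # {"forward": [...], "backward": [...]}, one entry per run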
43 changes: 43 additions & 0 deletions tests/theseus_tests/utils/test_utils.py
@@ -168,3 +168,46 @@ def test_jacobians_check():

    cf = th.eb.DoubleIntegrator(se3s[0], th.Vector(6), se3s[1], th.Vector(6), 1.0, w)
    thutils.check_jacobians(cf, 1)


def test_timer():
    # Check that the different ways of instantiating a Timer work correctly
    with thutils.Timer("cpu") as timer:
        torch.randn(1)
    assert timer.elapsed_time > 0

    with thutils.Timer("cpu", active=False) as timer:
        torch.randn(1)
    assert timer.elapsed_time == 0

    timer = thutils.Timer("cpu")
    with timer:
        torch.randn(1)
    assert timer.elapsed_time > 0

    timer = thutils.Timer("cpu", active=False)
    with timer:
        torch.randn(1)
    assert timer.elapsed_time == 0

    # Check that the deactivate keyword works correctly
    timer = thutils.Timer("cpu")
    with timer("randn", deactivate=True):
        torch.randn(1)
    assert timer.elapsed_time == 0
    timer.start("randn", deactivate=True)
    torch.randn(1)
    timer.end()
    assert timer.elapsed_time == 0
    # Check that stats accumulation works correctly
    with timer("randn"):
        torch.randn(1)
    timer.start("randn")
    torch.randn(1)
    timer.end()
    with timer("mult"):
        torch.ones(1) * torch.zeros(1)
    stats = timer.stats()
    assert "randn" in stats and "mult" in stats
    assert len(stats["randn"]) == 2
    assert len(stats["mult"]) == 1
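Since stats() returns a plain Dict[str, List[float]] (see the implementation in theseus/utils/utils.py below), summaries can be computed outside the Timer; a hypothetical helper, shown only for illustration:

from typing import Dict, List

def summarize(stats: Dict[str, List[float]]) -> Dict[str, float]:
    # Mean elapsed time (in seconds) per label, e.g. {"randn": 1.2e-05, "mult": 3.0e-06}
    return {label: sum(times) / len(times) for label, times in stats.items()}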
74 changes: 67 additions & 7 deletions theseus/utils/utils.py
@@ -6,7 +6,8 @@
import io
import pstats
import time
from typing import Any, Callable, List, Optional, Type
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Type

import numpy as np
import torch
@@ -188,8 +189,8 @@ def autograd_fn(*optim_var_tensors):
)


# A basic timer utility that adapts to the device. Useful for removing
# boilerplate code when benchmarking tasks.
# A timer utility that adapts to the device. Useful for removing
# boilerplate code when benchmarking tasks and collecting statistics.
# For CPU it uses time.perf_counter_ns()
# For GPU it uses torch.cuda.Event()
#
@@ -200,27 +201,78 @@ def autograd_fn(*optim_var_tensors):
# with Timer("cuda:0") as timer:
# do_some_stuff()
# print(timer.elapsed_time)
#
# timer = Timer("cuda:0")
# timer.start()
# do_some_stuff()
# timer.end()
# print(timer.elapsed_time)
#
# The timer can also optionally collect a history of times in a dictionary,
# keyed by the caller ids passed to each context or start() call.
# For example,
#
# timer.start("f1")
# f1()
# timer.end()
# with timer("f2"):
# f2()
# print(timer.stats())
#
# The timer can also be set inactive, either via the constructor
# (i.e., Timer(device, active=False))
# or via the start method (i.e., timer.start(call_id, deactivate=True)).
class Timer:
    def __init__(self, device: th.DeviceType) -> None:
    def __init__(self, device: th.DeviceType, active: bool = True) -> None:
        self.active = active
        self.device = torch.device(device)
        self.elapsed_time = 0.0

    def __enter__(self) -> "Timer":
        self._stats: Dict[str, List[float]] = defaultdict(list)
        self._caller: Optional[str] = None
        self._tmp_deactivated = False

    def start(self, caller: Optional[str] = None, deactivate: bool = False) -> "Timer":
        if not self.active or deactivate:
            self._tmp_deactivated = deactivate
            return self
        if self.device.type == "cuda":
            self._start_event = torch.cuda.Event(enable_timing=True)
            self._end_event = torch.cuda.Event(enable_timing=True)
            self._start_event.record()
        else:
            self._start_time = time.perf_counter_ns()
        self._caller = caller
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
    def end(self) -> None:
        if not self.active or self._tmp_deactivated:
            self._tmp_deactivated = False
            return
        if self.device.type == "cuda":
            self._end_event.record()
            torch.cuda.synchronize()
            self.elapsed_time = self._start_event.elapsed_time(self._end_event) / 1e3
        else:
            self.elapsed_time = (time.perf_counter_ns() - self._start_time) / 1e9
        if self._caller is not None:
            self._stats[self._caller].append(self.elapsed_time)

    def __call__(
        self, caller: Optional[str] = None, deactivate: bool = False
    ) -> "Timer":
        self._caller = caller
        self._tmp_deactivated = deactivate
        return self

    def __enter__(self) -> "Timer":
        self.start(caller=self._caller, deactivate=self._tmp_deactivated)
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.end()

    def stats(self) -> Dict[str, List[float]]:
        return self._stats


# Wrapper around cProfile.Profile that makes profiling easy to turn on/off and print
@@ -244,3 +296,11 @@ def print(self):
            ps = pstats.Stats(self.c_profiler, stream=s).sort_stats(sortby)
            ps.print_stats()
            print(s.getvalue())

    def create_stats(self):
        if self.active:
            self.c_profiler.create_stats()

    def dump_stats(self, filename: str):
        if self.active:
            self.c_profiler.dump_stats(filename)
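A brief usage sketch for the new Profiler methods. The Profiler constructor and its on/off controls are not shown in this diff, so the Profiler(active=True) call below is an assumption; create_stats() and dump_stats() delegate to the wrapped cProfile.Profile, and the dumped file is a regular cProfile dump, so it can be read back with the standard pstats module:

import pstats
from theseus.utils import Profiler  # assumed import path, mirroring Timer

profiler = Profiler(active=True)  # assumed constructor; methods are no-ops when inactive
# ... turn the profiler on, run the workload of interest, turn it off ...
profiler.create_stats()  # snapshot the collected statistics
profiler.dump_stats("theseus.prof")  # write them to disk via cProfile's dump_stats

# The dump can be inspected later with the standard library:
ps = pstats.Stats("theseus.prof")
ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(10)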
