Faster log map implementation (facebookresearch#629)
* Add a passthrough autograd operator for SO3 log map.

* Add private param to allow passthrough ops if available (default=False).

* Add a base UnaryPassthroughFn class.

* Add passthrough operator for SE3.

* Add an evaluation script for local cost differentiation timings.

* Reduce number of mults in SO3 log map.

* Improvements to timing script.

* Add other options to compute sine axis for SO3 log map.

* Lie group op tests can run on CUDA if available.

* Add a separate _faster_log_maps global param.

* Add unit tests for passthrough ops and fix some bugs.

* Fix torch.cuda.is_available call bug.

* Add forward pass only measurements to timing script.

* Add theseus option for fast approximate log maps.

* Add verbosity level to timing script.

* Update version numbers.

* Add test for SO3 sine_axis function.

* Rename fast_approx_log_maps as fast_approx_local_jacobians.
luisenp authored Nov 30, 2023
1 parent 8b10e97 commit c68c7f5
Showing 14 changed files with 302 additions and 33 deletions.
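For context on the "passthrough" operators added in this commit: they follow a common PyTorch pattern in which the forward pass runs a fast, non-differentiable implementation and the backward pass multiplies the incoming gradient by a stored analytical Jacobian instead of re-tracing the forward math. The snippet below is a minimal, self-contained sketch of that pattern only; it is not the library's UnaryPassthroughFn, and sin/cos stand in for the actual SO3/SE3 log-map kernels.

import torch

class PassthroughSketch(torch.autograd.Function):
    # Illustration of the passthrough idea, not torchlie's implementation.
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            out = torch.sin(x)                    # stand-in "fast" forward
            jac = torch.diag_embed(torch.cos(x))  # analytical Jacobian, shape (B, n, n)
        ctx.save_for_backward(jac)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (jac,) = ctx.saved_tensors
        # Chain rule with the stored Jacobian: grad_x = J^T @ grad_out.
        return torch.einsum("bmn,bm->bn", jac, grad_output)

x = torch.randn(4, 3, requires_grad=True)
y = PassthroughSketch.apply(x)
y.sum().backward()
torch.testing.assert_close(x.grad, torch.cos(x).detach())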
100 changes: 100 additions & 0 deletions evaluations/time_local_cost_backward.py
@@ -0,0 +1,100 @@
import argparse

import numpy as np
import theseus as th
import torch
import tqdm
import torchlie.functional as lieF
from theseus.global_params import set_global_params as set_th_global_params
from torchlie.functional.lie_group import LieGroupFns
from torchlie.global_params import set_global_params as set_lie_global_params
from theseus.utils import Timer


def run(
backward: bool,
group_type: str,
dev: str,
batch_size: int,
rng: torch.Generator,
verbose_level: int,
timer: Timer,
timer_label: str,
):
theseus_cls = getattr(th, group_type)
lieF_cls: LieGroupFns = getattr(lieF, group_type)
p = torch.nn.Parameter(lieF_cls.rand(batch_size, device=dev, generator=rng))
adam = torch.optim.Adam([p], lr={"SO3": 0.1, "SE3": 0.01}[group_type])
a = theseus_cls(name="a")
b = theseus_cls(
tensor=lieF_cls.rand(batch_size, device=dev, generator=rng), name="b"
)
o = th.Objective()
o.add(th.Local(a, b, th.ScaleCostWeight(1.0), name="d"))
layer = th.TheseusLayer(th.LevenbergMarquardt(o, max_iterations=3, step_size=0.1))
layer.to(dev)
timer.start(timer_label)
for i in range(10):

def _do():
layer.forward(
input_tensors={"a": p.clone()},
optimizer_kwargs={"damping": 0.1, "verbose": verbose_level > 1},
)

if backward:
adam.zero_grad()
_do()
loss = o.error_metric().sum()
if verbose_level > 0:
print(loss.item())
loss.backward()
adam.step()
else:
with torch.no_grad():
_do()
timer.end()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("reps", type=int, default=1)
parser.add_argument("--b", type=int, default=1, help="batch size")
parser.add_argument("--v", type=int, help="verbosity_level", default=0)
parser.add_argument("--g", choices=["SO3", "SE3"], default="SE3", help="group type")
parser.add_argument("--w", type=int, default=1, help="warmup iters")
parser.add_argument("--dev", type=str, default="cpu", help="device")
args = parser.parse_args()

rng = torch.Generator(device=args.dev)
rng.manual_seed(0)
timer = Timer(args.dev)
print(f"Timing device {timer.device}")

for backward in [True, False]:
for p in [True, False]:
label = f"b{backward:1d}-p{p:1d}"
set_lie_global_params({"_allow_passthrough_ops": p})
set_lie_global_params({"_faster_log_maps": p})
set_th_global_params({"fast_approx_local_jacobians": p})
for i in tqdm.tqdm(range(args.reps + args.w)):
run(
backward,
args.g,
args.dev,
args.b,
rng,
args.v,
timer,
f"run-{label}" if i > args.w else f"warmup-{label}",
)
time_stats = timer.stats()
results = {}
for k, v in time_stats.items():
results[k] = (np.mean(v), np.std(v) / np.sqrt(len(v)))
print(k, results[k])
print([f"{x:.3f}" for x in v])
print("...............")
print("With backward pass", 1 - results["run-b1-p1"][0] / results["run-b1-p0"][0])
print("Only forward pass", 1 - results["run-b0-p1"][0] / results["run-b0-p0"][0])
print("-----------------------------")
2 changes: 1 addition & 1 deletion examples/se2_planning.py
@@ -17,7 +17,7 @@

torch.set_default_dtype(torch.double)

device = "cuda:0" if torch.cuda.is_available else "cpu"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.random.manual_seed(1)
random.seed(1)
np.random.seed(1)
29 changes: 25 additions & 4 deletions tests/torchlie_tests/functional/common.py
@@ -6,6 +6,8 @@

import torch

from torchlie.global_params import set_global_params


BATCH_SIZES_TO_TEST = [1, 20, (1, 2), (3, 4, 5), tuple()]
TEST_EPS = 5e-7
Expand Down Expand Up @@ -64,24 +66,27 @@ def get_test_cfg(op_name, dtype, dim, data_shape, module=None):
#
# `batch_size` can be a tuple.
def sample_inputs(input_types, batch_size, dtype, rng):
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
if isinstance(batch_size, int):
batch_size = (batch_size,)

def _sample(input_type):
type_str, param = input_type

def _quat_sample():
q = torch.rand(*batch_size, param, dtype=dtype, generator=rng)
q = torch.rand(*batch_size, param, device=dev, dtype=dtype, generator=rng)
return q / torch.norm(q, dim=-1, keepdim=True)

sample_fns = {
"tangent": lambda: torch.rand(
*batch_size, param, dtype=dtype, generator=rng
*batch_size, param, device=dev, dtype=dtype, generator=rng
),
"group": lambda: param.rand(
*batch_size, device=dev, generator=rng, dtype=dtype
),
"group": lambda: param.rand(*batch_size, generator=rng, dtype=dtype),
"quat": lambda: _quat_sample(),
"matrix": lambda: torch.rand(
(*batch_size,) + param, generator=rng, dtype=dtype
(*batch_size,) + param, device=dev, generator=rng, dtype=dtype
),
}
return sample_fns[type_str]()
@@ -321,3 +326,19 @@ def check_left_project_broadcasting(
torch.autograd.gradcheck(
lie_group_fns.left_project, (g, t, out_dim), raise_exception=True
)


def check_log_map_passt(lie_group_fns, impl_module):
set_global_params({"_allow_passthrough_ops": True})
group = lie_group_fns.rand(
4, device="cuda:0" if torch.cuda.is_available() else "cpu", requires_grad=True
)
jlist = []
log_map_pt = lie_group_fns.log(group, jacobians=jlist)
grad_pt = torch.autograd.grad(log_map_pt.sum(), group)
log_map_ref = impl_module._log_autograd_fn(group)
jac_ref = impl_module._jlog_impl(group)[0][0]
grad_ref = torch.autograd.grad(log_map_ref.sum(), group)
torch.testing.assert_close(log_map_pt, log_map_ref)
torch.testing.assert_close(jlist[0], jac_ref)
torch.testing.assert_close(grad_pt, grad_ref)
7 changes: 6 additions & 1 deletion tests/torchlie_tests/functional/test_se3.py
@@ -18,6 +18,7 @@
check_lie_group_function,
check_jacrev_binary,
check_jacrev_unary,
check_log_map_passt,
run_test_op,
)

@@ -43,7 +44,7 @@
@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_op(op_name, batch_size, dtype):
rng = torch.Generator()
rng = torch.Generator(device="cuda:0" if torch.cuda.is_available() else "cpu")
rng.manual_seed(0)
run_test_op(op_name, batch_size, dtype, rng, 6, (3, 4), se3_impl)

@@ -101,3 +102,7 @@ def test_left_project_broadcasting():
rng.manual_seed(0)
batch_sizes = [tuple(), (1, 2), (1, 1, 2), (2, 1), (2, 2), (2, 2, 2)]
check_left_project_broadcasting(SE3, batch_sizes, [0, 1, 2], (3, 4), rng)


def test_log_map_passt():
check_log_map_passt(SE3, se3_impl)
25 changes: 23 additions & 2 deletions tests/torchlie_tests/functional/test_so3.py
@@ -9,7 +9,7 @@

import torchlie.functional.so3_impl as so3_impl
from torchlie.functional import SO3

from torchlie.global_params import set_global_params

from .common import (
BATCH_SIZES_TO_TEST,
@@ -19,6 +19,7 @@
check_lie_group_function,
check_jacrev_binary,
check_jacrev_unary,
check_log_map_passt,
run_test_op,
)

@@ -45,7 +46,8 @@
@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_op(op_name, batch_size, dtype):
rng = torch.Generator()
set_global_params({"_faster_log_maps": True})
rng = torch.Generator(device="cuda:0" if torch.cuda.is_available() else "cpu")
rng.manual_seed(0)
run_test_op(op_name, batch_size, dtype, rng, 3, (3, 3), so3_impl)

@@ -102,3 +104,22 @@ def test_left_project_broadcasting():
rng.manual_seed(0)
batch_sizes = [tuple(), (1, 2), (1, 1, 2), (2, 1), (2, 2), (2, 2, 2)]
check_left_project_broadcasting(SO3, batch_sizes, [0, 1, 2], (3, 3), rng)


def test_log_map_passt():
check_log_map_passt(SO3, so3_impl)


# This tests that the CUDA implementation of sine axis returns the same result
# as the CPU implementation
@pytest.mark.parametrize("batch_size", [[1], [10], [2, 10]])
def test_sine_axis(batch_size):
set_global_params({"_faster_log_maps": True})
if not torch.cuda.is_available():
return
for _ in range(10):
g = so3_impl.rand(*batch_size)
g_cuda = g.to("cuda:0")
sa_1 = so3_impl._sine_axis_fn(g, g.shape[:-2])
sa_2 = so3_impl._sine_axis_fn(g_cuda, g.shape[:-2])
torch.testing.assert_close(sa_1, sa_2.cpu())
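As background for the sine_axis test above: for a rotation R = exp(theta * [w]x) with unit axis w, the skew-symmetric part satisfies (R - R^T) / 2 = sin(theta) * [w]x, so sin(theta) * w can be read off three matrix entries. Assuming that this is what _sine_axis_fn computes (the exact torchlie implementation may differ, e.g. in how it handles batch shapes and numerical precision), a reference version looks like:

import torch

def sine_axis_reference(R: torch.Tensor) -> torch.Tensor:
    # vee of the skew-symmetric part of R: returns sin(theta) * axis, shape (..., 3).
    skew = 0.5 * (R - R.transpose(-2, -1))
    return torch.stack((skew[..., 2, 1], skew[..., 0, 2], skew[..., 1, 0]), dim=-1)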
2 changes: 1 addition & 1 deletion theseus/_version.py
@@ -33,4 +33,4 @@ def _as_tuple(s: str) -> Tuple[int, int, int]:
FutureWarning,
)

__version__ = "0.2.1"
__version__ = "0.2.2.dev0"
25 changes: 22 additions & 3 deletions theseus/embodied/misc/local_cost_fn.py
@@ -9,6 +9,7 @@

from theseus import CostFunction, CostWeight
from theseus.geometry import LieGroup
from theseus.global_params import _THESEUS_GLOBAL_PARAMS


class Local(CostFunction):
@@ -33,13 +34,31 @@ def __init__(
self.register_optim_vars(["var"])
self.register_aux_vars(["target"])

self._jac_cache: torch.Tensor = None

def error(self) -> torch.Tensor:
return self.target.local(self.var)

def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
Jlist: List[torch.Tensor] = []
error = self.target.local(self.var, jacobians=Jlist)
return [Jlist[1]], error
if _THESEUS_GLOBAL_PARAMS.fast_approx_local_jacobians:
if (
self._jac_cache is not None
and self._jac_cache.shape[0] == self.var.shape[0]
):
jacobian = self._jac_cache
else:
jacobian = torch.eye(
self.dim(), device=self.var.device, dtype=self.var.dtype
).repeat(self.var.shape[0], 1, 1)
self._jac_cache = jacobian
return (
[jacobian],
self.target.local(self.var),
)
else:
Jlist: List[torch.Tensor] = []
error = self.target.local(self.var, jacobians=Jlist)
return [Jlist[1]], error

def dim(self) -> int:
return self.var.dof()
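With the new flag enabled, Local.jacobians() skips the exact local-coordinates Jacobian and instead returns a cached identity of shape (batch_size, dof, dof), which is what the timing script above exercises. A minimal sketch of enabling it from user code (the flag defaults to False):

from theseus.global_params import set_global_params

set_global_params({"fast_approx_local_jacobians": True})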
2 changes: 2 additions & 0 deletions theseus/global_params.py
@@ -28,6 +28,7 @@ class _TheseusGlobalParams:
so2_norm_eps_float64: float = 0
so2_matrix_eps_float64: float = 0
se2_near_zero_eps_float64: float = 0
fast_approx_local_jacobians: bool = False

def __init__(self):
self.reset()
@@ -50,6 +51,7 @@ def reset(self) -> None:
self.so2_matrix_eps_float64 = 4e-7
self.se2_near_zero_eps_float64 = 1e-6
self.se2_d_near_zero_eps_float64 = 1e-3
self.fast_approx_local_jacobians = False


_THESEUS_GLOBAL_PARAMS = _TheseusGlobalParams()
2 changes: 1 addition & 1 deletion torchlie/torchlie/__init__.py
@@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "0.1.0"
__version__ = "0.1.1.dev0"

from .global_params import reset_global_params, set_global_params
from .lie_tensor import ( # usort: skip