diff --git a/evaluations/time_local_cost_backward.py b/evaluations/time_local_cost_backward.py
new file mode 100644
index 000000000..7c0e756b3
--- /dev/null
+++ b/evaluations/time_local_cost_backward.py
@@ -0,0 +1,100 @@
+import argparse
+
+import numpy as np
+import theseus as th
+import torch
+import tqdm
+import torchlie.functional as lieF
+from theseus.global_params import set_global_params as set_th_global_params
+from torchlie.functional.lie_group import LieGroupFns
+from torchlie.global_params import set_global_params as set_lie_global_params
+from theseus.utils import Timer
+
+
+def run(
+    backward: bool,
+    group_type: str,
+    dev: str,
+    batch_size: int,
+    rng: torch.Generator,
+    verbose_level: int,
+    timer: Timer,
+    timer_label: str,
+):
+    theseus_cls = getattr(th, group_type)
+    lieF_cls: LieGroupFns = getattr(lieF, group_type)
+    p = torch.nn.Parameter(lieF_cls.rand(batch_size, device=dev, generator=rng))
+    adam = torch.optim.Adam([p], lr={"SO3": 0.1, "SE3": 0.01}[group_type])
+    a = theseus_cls(name="a")
+    b = theseus_cls(
+        tensor=lieF_cls.rand(batch_size, device=dev, generator=rng), name="b"
+    )
+    o = th.Objective()
+    o.add(th.Local(a, b, th.ScaleCostWeight(1.0), name="d"))
+    layer = th.TheseusLayer(th.LevenbergMarquardt(o, max_iterations=3, step_size=0.1))
+    layer.to(dev)
+    timer.start(timer_label)
+    for i in range(10):
+
+        def _do():
+            layer.forward(
+                input_tensors={"a": p.clone()},
+                optimizer_kwargs={"damping": 0.1, "verbose": verbose_level > 1},
+            )
+
+        if backward:
+            adam.zero_grad()
+            _do()
+            loss = o.error_metric().sum()
+            if verbose_level > 0:
+                print(loss.item())
+            loss.backward()
+            adam.step()
+        else:
+            with torch.no_grad():
+                _do()
+    timer.end()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("reps", type=int, default=1)
+    parser.add_argument("--b", type=int, default=1, help="batch size")
+    parser.add_argument("--v", type=int, help="verbosity_level", default=0)
+    parser.add_argument("--g", choices=["SO3", "SE3"], default="SE3", help="group type")
+    parser.add_argument("--w", type=int, default=1, help="warmup iters")
+    parser.add_argument("--dev", type=str, default="cpu", help="device")
+    args = parser.parse_args()
+
+    rng = torch.Generator(device=args.dev)
+    rng.manual_seed(0)
+    timer = Timer(args.dev)
+    print(f"Timing device {timer.device}")
+
+    for backward in [True, False]:
+        for p in [True, False]:
+            label = f"b{backward:1d}-p{p:1d}"
+            set_lie_global_params({"_allow_passthrough_ops": p})
+            set_lie_global_params({"_faster_log_maps": p})
+            set_th_global_params({"fast_approx_local_jacobians": p})
+            for i in tqdm.tqdm(range(args.reps + args.w)):
+                run(
+                    backward,
+                    args.g,
+                    args.dev,
+                    args.b,
+                    rng,
+                    args.v,
+                    timer,
+                    f"run-{label}" if i > args.w else f"warmup-{label}",
+                )
+    time_stats = timer.stats()
+    results = {}
+    for k, v in time_stats.items():
+        results[k] = (np.mean(v), np.std(v) / np.sqrt(len(v)))
+        print(k, results[k])
+        print([f"{x:.3f}" for x in v])
+        print("...............")
+    print("With backward pass", 1 - results["run-b1-p1"][0] / results["run-b1-p0"][0])
+    print("Only forward pass", 1 - results["run-b0-p1"][0] / results["run-b0-p0"][0])
+    print("-----------------------------")
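For context, the experimental flags swept by this benchmark can also be toggled directly in user code through the setters imported at the top of the script. A minimal sketch (the flag names come from this patch; the example command line, batch size, and device are assumptions):

# Hypothetical invocation of the benchmark above, with arguments as declared via argparse:
#   python evaluations/time_local_cost_backward.py 10 --b 128 --g SE3 --w 2 --dev cuda:0
from theseus.global_params import set_global_params as set_th_global_params
from torchlie.global_params import set_global_params as set_lie_global_params

set_lie_global_params({"_allow_passthrough_ops": True})  # reuse jacobian output inside log()
set_lie_global_params({"_faster_log_maps": True})  # faster sine-axis path on CUDA
set_th_global_params({"fast_approx_local_jacobians": True})  # identity jacobian for th.Local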
"cpu" torch.random.manual_seed(1) random.seed(1) np.random.seed(1) diff --git a/tests/torchlie_tests/functional/common.py b/tests/torchlie_tests/functional/common.py index 2c787c3f6..11b8fa159 100644 --- a/tests/torchlie_tests/functional/common.py +++ b/tests/torchlie_tests/functional/common.py @@ -6,6 +6,8 @@ import torch +from torchlie.global_params import set_global_params + BATCH_SIZES_TO_TEST = [1, 20, (1, 2), (3, 4, 5), tuple()] TEST_EPS = 5e-7 @@ -64,6 +66,7 @@ def get_test_cfg(op_name, dtype, dim, data_shape, module=None): # # `batch_size` can be a tuple. def sample_inputs(input_types, batch_size, dtype, rng): + dev = "cuda:0" if torch.cuda.is_available() else "cpu" if isinstance(batch_size, int): batch_size = (batch_size,) @@ -71,17 +74,19 @@ def _sample(input_type): type_str, param = input_type def _quat_sample(): - q = torch.rand(*batch_size, param, dtype=dtype, generator=rng) + q = torch.rand(*batch_size, param, device=dev, dtype=dtype, generator=rng) return q / torch.norm(q, dim=-1, keepdim=True) sample_fns = { "tangent": lambda: torch.rand( - *batch_size, param, dtype=dtype, generator=rng + *batch_size, param, device=dev, dtype=dtype, generator=rng + ), + "group": lambda: param.rand( + *batch_size, device=dev, generator=rng, dtype=dtype ), - "group": lambda: param.rand(*batch_size, generator=rng, dtype=dtype), "quat": lambda: _quat_sample(), "matrix": lambda: torch.rand( - (*batch_size,) + param, generator=rng, dtype=dtype + (*batch_size,) + param, device=dev, generator=rng, dtype=dtype ), } return sample_fns[type_str]() @@ -321,3 +326,19 @@ def check_left_project_broadcasting( torch.autograd.gradcheck( lie_group_fns.left_project, (g, t, out_dim), raise_exception=True ) + + +def check_log_map_passt(lie_group_fns, impl_module): + set_global_params({"_allow_passthrough_ops": True}) + group = lie_group_fns.rand( + 4, device="cuda:0" if torch.cuda.is_available() else "cpu", requires_grad=True + ) + jlist = [] + log_map_pt = lie_group_fns.log(group, jacobians=jlist) + grad_pt = torch.autograd.grad(log_map_pt.sum(), group) + log_map_ref = impl_module._log_autograd_fn(group) + jac_ref = impl_module._jlog_impl(group)[0][0] + grad_ref = torch.autograd.grad(log_map_ref.sum(), group) + torch.testing.assert_close(log_map_pt, log_map_ref) + torch.testing.assert_close(jlist[0], jac_ref) + torch.testing.assert_close(grad_pt, grad_ref) diff --git a/tests/torchlie_tests/functional/test_se3.py b/tests/torchlie_tests/functional/test_se3.py index c3af91c3b..6fe86cd7b 100644 --- a/tests/torchlie_tests/functional/test_se3.py +++ b/tests/torchlie_tests/functional/test_se3.py @@ -18,6 +18,7 @@ check_lie_group_function, check_jacrev_binary, check_jacrev_unary, + check_log_map_passt, run_test_op, ) @@ -43,7 +44,7 @@ @pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_op(op_name, batch_size, dtype): - rng = torch.Generator() + rng = torch.Generator(device="cuda:0" if torch.cuda.is_available() else "cpu") rng.manual_seed(0) run_test_op(op_name, batch_size, dtype, rng, 6, (3, 4), se3_impl) @@ -101,3 +102,7 @@ def test_left_project_broadcasting(): rng.manual_seed(0) batch_sizes = [tuple(), (1, 2), (1, 1, 2), (2, 1), (2, 2), (2, 2, 2)] check_left_project_broadcasting(SE3, batch_sizes, [0, 1, 2], (3, 4), rng) + + +def test_log_map_passt(): + check_log_map_passt(SE3, se3_impl) diff --git a/tests/torchlie_tests/functional/test_so3.py b/tests/torchlie_tests/functional/test_so3.py index 8f8c077b8..7ceab2d15 100644 --- 
diff --git a/tests/torchlie_tests/functional/test_se3.py b/tests/torchlie_tests/functional/test_se3.py
index c3af91c3b..6fe86cd7b 100644
--- a/tests/torchlie_tests/functional/test_se3.py
+++ b/tests/torchlie_tests/functional/test_se3.py
@@ -18,6 +18,7 @@
     check_lie_group_function,
     check_jacrev_binary,
     check_jacrev_unary,
+    check_log_map_passt,
     run_test_op,
 )
 
@@ -43,7 +44,7 @@
 @pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
 def test_op(op_name, batch_size, dtype):
-    rng = torch.Generator()
+    rng = torch.Generator(device="cuda:0" if torch.cuda.is_available() else "cpu")
     rng.manual_seed(0)
     run_test_op(op_name, batch_size, dtype, rng, 6, (3, 4), se3_impl)
 
@@ -101,3 +102,7 @@ def test_left_project_broadcasting():
     rng.manual_seed(0)
     batch_sizes = [tuple(), (1, 2), (1, 1, 2), (2, 1), (2, 2), (2, 2, 2)]
     check_left_project_broadcasting(SE3, batch_sizes, [0, 1, 2], (3, 4), rng)
+
+
+def test_log_map_passt():
+    check_log_map_passt(SE3, se3_impl)
diff --git a/tests/torchlie_tests/functional/test_so3.py b/tests/torchlie_tests/functional/test_so3.py
index 8f8c077b8..7ceab2d15 100644
--- a/tests/torchlie_tests/functional/test_so3.py
+++ b/tests/torchlie_tests/functional/test_so3.py
@@ -9,7 +9,7 @@
 
 import torchlie.functional.so3_impl as so3_impl
 from torchlie.functional import SO3
-
+from torchlie.global_params import set_global_params
 
 from .common import (
     BATCH_SIZES_TO_TEST,
@@ -19,6 +19,7 @@
     check_lie_group_function,
     check_jacrev_binary,
     check_jacrev_unary,
+    check_log_map_passt,
     run_test_op,
 )
 
@@ -45,7 +46,8 @@
 @pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
 def test_op(op_name, batch_size, dtype):
-    rng = torch.Generator()
+    set_global_params({"_faster_log_maps": True})
+    rng = torch.Generator(device="cuda:0" if torch.cuda.is_available() else "cpu")
     rng.manual_seed(0)
     run_test_op(op_name, batch_size, dtype, rng, 3, (3, 3), so3_impl)
 
@@ -102,3 +104,22 @@ def test_left_project_broadcasting():
     rng.manual_seed(0)
     batch_sizes = [tuple(), (1, 2), (1, 1, 2), (2, 1), (2, 2), (2, 2, 2)]
     check_left_project_broadcasting(SO3, batch_sizes, [0, 1, 2], (3, 3), rng)
+
+
+def test_log_map_passt():
+    check_log_map_passt(SO3, so3_impl)
+
+
+# This tests that the CUDA implementation of sine axis returns the same result
+# as the CPU implementation
+@pytest.mark.parametrize("batch_size", [[1], [10], [2, 10]])
+def test_sine_axis(batch_size):
+    set_global_params({"_faster_log_maps": True})
+    if not torch.cuda.is_available():
+        return
+    for _ in range(10):
+        g = so3_impl.rand(*batch_size)
+        g_cuda = g.to("cuda:0")
+        sa_1 = so3_impl._sine_axis_fn(g, g.shape[:-2])
+        sa_2 = so3_impl._sine_axis_fn(g_cuda, g.shape[:-2])
+        torch.testing.assert_close(sa_1, sa_2.cpu())
diff --git a/theseus/_version.py b/theseus/_version.py
index 8a55317cc..c8c83297d 100644
--- a/theseus/_version.py
+++ b/theseus/_version.py
@@ -33,4 +33,4 @@ def _as_tuple(s: str) -> Tuple[int, int, int]:
         FutureWarning,
     )
 
-__version__ = "0.2.1"
+__version__ = "0.2.2.dev0"
diff --git a/theseus/embodied/misc/local_cost_fn.py b/theseus/embodied/misc/local_cost_fn.py
index 02b6391c5..c64666a51 100644
--- a/theseus/embodied/misc/local_cost_fn.py
+++ b/theseus/embodied/misc/local_cost_fn.py
@@ -9,6 +9,7 @@
 
 from theseus import CostFunction, CostWeight
 from theseus.geometry import LieGroup
+from theseus.global_params import _THESEUS_GLOBAL_PARAMS
 
 
 class Local(CostFunction):
@@ -33,13 +34,31 @@ def __init__(
         self.register_optim_vars(["var"])
         self.register_aux_vars(["target"])
 
+        self._jac_cache: torch.Tensor = None
+
     def error(self) -> torch.Tensor:
         return self.target.local(self.var)
 
     def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
-        Jlist: List[torch.Tensor] = []
-        error = self.target.local(self.var, jacobians=Jlist)
-        return [Jlist[1]], error
+        if _THESEUS_GLOBAL_PARAMS.fast_approx_local_jacobians:
+            if (
+                self._jac_cache is not None
+                and self._jac_cache.shape[0] == self.var.shape[0]
+            ):
+                jacobian = self._jac_cache
+            else:
+                jacobian = torch.eye(
+                    self.dim(), device=self.var.device, dtype=self.var.dtype
+                ).repeat(self.var.shape[0], 1, 1)
+                self._jac_cache = jacobian
+            return (
+                [jacobian],
+                self.target.local(self.var),
+            )
+        else:
+            Jlist: List[torch.Tensor] = []
+            error = self.target.local(self.var, jacobians=Jlist)
+            return [Jlist[1]], error
 
     def dim(self) -> int:
         return self.var.dof()
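A rough sketch of what the new fast_approx_local_jacobians flag changes for th.Local, mirroring the setup used in the benchmark script in this patch (the variable names and batch size are illustrative; this is not part of the patch):

import theseus as th
import torchlie.functional as lieF
from theseus.global_params import set_global_params

a = th.SE3(tensor=lieF.SE3.rand(2), name="a")
b = th.SE3(tensor=lieF.SE3.rand(2), name="b")
cost = th.Local(a, b, th.ScaleCostWeight(1.0), name="d")

# Default behavior: exact jacobian of target.local(var), obtained via the jacobians list.
jacs_exact, err = cost.jacobians()

# Experimental approximation: a batch of identity matrices, cached per batch size.
set_global_params({"fast_approx_local_jacobians": True})
jacs_approx, err = cost.jacobians()
assert jacs_approx[0].shape == (2, cost.dim(), cost.dim())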
diff --git a/theseus/global_params.py b/theseus/global_params.py
index e0ad06346..d533363af 100644
--- a/theseus/global_params.py
+++ b/theseus/global_params.py
@@ -28,6 +28,7 @@ class _TheseusGlobalParams:
     so2_norm_eps_float64: float = 0
     so2_matrix_eps_float64: float = 0
    se2_near_zero_eps_float64: float = 0
+    fast_approx_local_jacobians: bool = False
 
     def __init__(self):
         self.reset()
@@ -50,6 +51,7 @@ def reset(self) -> None:
         self.so2_matrix_eps_float64 = 4e-7
         self.se2_near_zero_eps_float64 = 1e-6
         self.se2_d_near_zero_eps_float64 = 1e-3
+        self.fast_approx_local_jacobians = False
 
 
 _THESEUS_GLOBAL_PARAMS = _TheseusGlobalParams()
diff --git a/torchlie/torchlie/__init__.py b/torchlie/torchlie/__init__.py
index f5f980425..05da66b84 100644
--- a/torchlie/torchlie/__init__.py
+++ b/torchlie/torchlie/__init__.py
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-__version__ = "0.1.0"
+__version__ = "0.1.1.dev0"
 
 from .global_params import reset_global_params, set_global_params
 from .lie_tensor import (  # usort: skip
diff --git a/torchlie/torchlie/functional/lie_group.py b/torchlie/torchlie/functional/lie_group.py
index 1fb56aa41..0ed28d445 100644
--- a/torchlie/torchlie/functional/lie_group.py
+++ b/torchlie/torchlie/functional/lie_group.py
@@ -7,6 +7,8 @@
 
 import torch
 
+from torchlie.global_params import _TORCHLIE_GLOBAL_PARAMS as LIE_PARAMS
+
 from .constants import DeviceType
 from .utils import check_jacobians_list
 
@@ -18,6 +20,10 @@
 # _jxxx_autograd_fn: simply equivalent to _jxxx_impl for now
 # ----------------------------------------------------------------------------------
 # Note that _jxxx_impl might not exist for some operators.
+#
+# Some operators support a _xxx_passthrough_fn, which returns the same values as
+# _xxx_autograd_fn in forward pass, but takes the output of _jxxx_autograd_fn as
+# extra non-differentiable inputs to avoid computing operators twice.
 
 
 def JInverseImplFactory(module):
@@ -42,6 +48,42 @@ def _left_project_impl(
     return _left_project_impl
 
 
+# This class is used by `UnaryOperatorFactory` to
+# avoid computing the operator twice in function calls of the form
+# op(group, jacobians_list=jlist).
+# This is functionally equivalent to `UnaryOperator` objects, but
+# it receives the operator's result and jacobian as extra inputs.
+# Usage is then:
+#     jac, res = _jop_impl(group)
+#     op_result = passthrough_fn(group, res, jac)
+# This connects `op_result` to the compute graph with custom
+# backward implementation, while `jac` uses torch default autograd.
+class _UnaryPassthroughFn(torch.autograd.Function):
+    generate_vmap_rule = True
+
+    @classmethod
+    @abc.abstractmethod
+    def _backward_impl(
+        cls, group: torch.Tensor, jacobian: torch.Tensor, grad_output: torch.Tensor
+    ) -> torch.Tensor:
+        pass
+
+    @classmethod
+    def forward(cls, group, op_result, jacobian):
+        return op_result
+
+    @classmethod
+    def setup_context(cls, ctx, inputs, outputs):
+        ctx.save_for_backward(inputs[0], inputs[2])
+
+    @classmethod
+    def backward(cls, ctx, grad_output):
+        grad = cls._backward_impl(
+            ctx.saved_tensors[0], ctx.saved_tensors[1], grad_output
+        )
+        return grad, None, None
+
+
 class UnaryOperator(torch.autograd.Function):
     generate_vmap_rule = True
 
@@ -95,8 +137,9 @@ def UnaryOperatorFactory(
     module, op_name
 ) -> Tuple[UnaryOperatorOpFnType, UnaryOperatorJOpFnType]:
     # Get autograd.Function wrapper of op and its jacobian
-    op_autograd_fn = getattr(module, "_" + op_name + "_autograd_fn")
-    jop_autograd_fn = getattr(module, "_j" + op_name + "_autograd_fn")
+    op_autograd_fn = getattr(module, f"_{op_name}_autograd_fn")
+    jop_autograd_fn = getattr(module, f"_j{op_name}_autograd_fn")
+    op_passthrough_fn = getattr(module, f"_{op_name}_passthrough_fn", None)
 
     def op(
         input: torch.Tensor,
@@ -105,8 +148,10 @@ def op(
         if jacobians is not None:
             _check_jacobians_supported(jop_autograd_fn, module.NAME, op_name)
             check_jacobians_list(jacobians)
-            jacobians_op = jop_autograd_fn(input)[0]
+            jacobians_op, ret = jop_autograd_fn(input)
             jacobians.append(jacobians_op[0])
+            if LIE_PARAMS._allow_passthrough_ops and op_passthrough_fn is not None:
+                return op_passthrough_fn(input, ret, jacobians_op[0])
         return op_autograd_fn(input)
 
     def jop(input: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
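A rough usage sketch of the passthrough path wired up above, following the jacobians-list calling convention exercised by check_log_map_passt (a sketch under the assumption that the impl module defines _log_passthrough_fn, as SE3/SO3 do later in this patch; not part of the patch itself):

import torchlie.functional as lieF
from torchlie.global_params import set_global_params

set_global_params({"_allow_passthrough_ops": True})

g = lieF.SE3.rand(4, requires_grad=True)
jlist = []
# With the flag enabled, the returned value reuses the output of the jacobian
# computation instead of evaluating the log map a second time; its backward
# goes through the wrapper's _backward_impl.
out = lieF.SE3.log(g, jacobians=jlist)
out.sum().backward()  # gradients match the default autograd path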
diff --git a/torchlie/torchlie/functional/se3_impl.py b/torchlie/torchlie/functional/se3_impl.py
index 15d46b8bb..7c51ac8f3 100644
--- a/torchlie/torchlie/functional/se3_impl.py
+++ b/torchlie/torchlie/functional/se3_impl.py
@@ -484,6 +484,15 @@ def _jlog_impl(group: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
     return [jac], tangent_vector
 
 
+def _log_backward(
+    group: torch.Tensor, jacobian: torch.Tensor, grad_output: torch.Tensor
+) -> torch.Tensor:
+    jac_by_g = (jacobian.transpose(-1, -2) @ grad_output.unsqueeze(-1)).squeeze(-1)
+    jac_by_g[..., 3:] *= 0.5
+    temp2: torch.Tensor = lift(jac_by_g)
+    return group[..., :3] @ temp2
+
+
 class Log(lie_group.UnaryOperator):
     @classmethod
     def _forward_impl(cls, group):
@@ -499,17 +508,21 @@ def setup_context(cls, ctx, inputs, outputs):
     @classmethod
     def backward(cls, ctx, grad_output):
         group: torch.Tensor = ctx.saved_tensors[1]
-        jacobians = _jlog_impl(group)[0][0]
-        jacobians[..., 3:] *= 0.5
-        temp: torch.Tensor = lift(
-            (jacobians.transpose(-1, -2) @ grad_output.unsqueeze(-1)).squeeze(-1)
-        )
-        return group[..., :3] @ temp
+        return _log_backward(group, _jlog_impl(group)[0][0], grad_output)
+
+
+class _LogPassthroughWrapper(lie_group._UnaryPassthroughFn):
+    @classmethod
+    def _backward_impl(
+        cls, group: torch.Tensor, jacobian: torch.Tensor, grad_output: torch.Tensor
+    ) -> torch.Tensor:
+        return _log_backward(group, jacobian, grad_output)
 
 
 # TODO: Implement analytic backward for _jlog_impl
 _log_autograd_fn = Log.apply
 _jlog_autograd_fn = _jlog_impl
+_log_passthrough_fn = _LogPassthroughWrapper.apply
 
 
 # -----------------------------------------------------------------------------
diff --git a/torchlie/torchlie/functional/so3_impl.py b/torchlie/torchlie/functional/so3_impl.py
index ba75023af..538af8630 100644
--- a/torchlie/torchlie/functional/so3_impl.py
+++ b/torchlie/torchlie/functional/so3_impl.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Optional, Tuple, cast
+from typing import Dict, List, Optional, Tuple, cast
 
 import torch
 
@@ -355,15 +355,41 @@ def backward(cls, ctx, grad_output):
 _jexp_autograd_fn = _jexp_impl
 
 
+_UPPER_IDX_3x3_CUDA: Dict[str, torch.Tensor] = None
+if torch.cuda.is_available():
+    _UPPER_IDX_3x3_CUDA = {
+        f"cuda:{i}": torch.triu_indices(3, 3, offset=1).to(f"cuda:{i}").flip(-1)
+        for i in range(torch.cuda.device_count())
+    }
+
+
+def _sine_axis_fn(group: torch.Tensor, size: torch.Size) -> torch.Tensor:
+    if LIE_PARAMS._faster_log_maps:
+        if group.is_cuda:
+            g_minus_gt = 0.5 * (group.adjoint() - group)
+            upper_idx = _UPPER_IDX_3x3_CUDA[str(group.device)]
+            sine_axis = g_minus_gt[..., upper_idx[0], upper_idx[1]]
+            sine_axis[..., 1] *= -1
+        else:
+            sine_axis = group.new_zeros(*size, 3)
+            sine_axis[..., 0] = group[..., 2, 1] - group[..., 1, 2]
+            sine_axis[..., 1] = group[..., 0, 2] - group[..., 2, 0]
+            sine_axis[..., 2] = group[..., 1, 0] - group[..., 0, 1]
+            sine_axis *= 0.5
+    else:
+        sine_axis = group.new_zeros(*size, 3)
+        sine_axis[..., 0] = 0.5 * (group[..., 2, 1] - group[..., 1, 2])
+        sine_axis[..., 1] = 0.5 * (group[..., 0, 2] - group[..., 2, 0])
+        sine_axis[..., 2] = 0.5 * (group[..., 1, 0] - group[..., 0, 1])
+    return sine_axis
+
+
 # -----------------------------------------------------------------------------
 # Logarithm Map
 # -----------------------------------------------------------------------------
 def _log_impl_helper(group: torch.Tensor):
     size = get_group_size(group)
-    sine_axis = group.new_zeros(*size, 3)
-    sine_axis[..., 0] = 0.5 * (group[..., 2, 1] - group[..., 1, 2])
-    sine_axis[..., 1] = 0.5 * (group[..., 0, 2] - group[..., 2, 0])
-    sine_axis[..., 2] = 0.5 * (group[..., 1, 0] - group[..., 0, 1])
+    sine_axis = _sine_axis_fn(group, size)
     cosine = 0.5 * (group.diagonal(dim1=-1, dim2=-2).sum(dim=-1) - 1)
     sine = sine_axis.norm(dim=-1)
     theta = torch.atan2(sine, cosine)
@@ -460,6 +486,16 @@ def _jlog_impl(group: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
     return [jac], tangent_vector
 
 
+def _log_backward(
+    group: torch.Tensor, jacobian: torch.Tensor, grad_output: torch.Tensor
+) -> torch.Tensor:
+    jacobian = 0.5 * jacobian
+    temp = _lift_autograd_fn(
+        (jacobian.transpose(-2, -1) @ grad_output.unsqueeze(-1)).squeeze(-1)
+    )
+    return group @ temp
+
+
 class Log(lie_group.UnaryOperator):
     @classmethod
     def _forward_impl(cls, group):
@@ -475,16 +511,21 @@ def setup_context(cls, ctx, inputs, outputs):
     @classmethod
     def backward(cls, ctx, grad_output):
         group: torch.Tensor = ctx.saved_tensors[1]
-        jacobians = 0.5 * _jlog_impl(group)[0][0]
-        temp: torch.Tensor = _lift_autograd_fn(
-            (jacobians.transpose(-2, -1) @ grad_output.unsqueeze(-1)).squeeze(-1)
-        )
-        return group @ temp
+        return _log_backward(group, _jlog_impl(group)[0][0], grad_output)
+
+
+class _LogPassthroughWrapper(lie_group._UnaryPassthroughFn):
+    @classmethod
+    def _backward_impl(
+        cls, group: torch.Tensor, jacobian: torch.Tensor, grad_output: torch.Tensor
+    ) -> torch.Tensor:
+        return _log_backward(group, jacobian, grad_output)
 
 
 # TODO: Implement analytic backward for _jlog_impl
 _log_autograd_fn = Log.apply
 _jlog_autograd_fn = _jlog_impl
+_log_passthrough_fn = _LogPassthroughWrapper.apply
 
 
 # -----------------------------------------------------------------------------
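A quick sanity-check sketch mirroring the new test_sine_axis test: with _faster_log_maps enabled, the CUDA fast path for the sine axis should agree with the CPU path (only meaningful when CUDA is available; not part of the patch):

import torch
import torchlie.functional.so3_impl as so3_impl
from torchlie.global_params import set_global_params

set_global_params({"_faster_log_maps": True})
if torch.cuda.is_available():
    g = so3_impl.rand(10)
    # Compare the fast CUDA implementation against the CPU implementation.
    sa_cpu = so3_impl._sine_axis_fn(g, g.shape[:-2])
    sa_cuda = so3_impl._sine_axis_fn(g.to("cuda:0"), g.shape[:-2])
    torch.testing.assert_close(sa_cpu, sa_cuda.cpu())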
diff --git a/torchlie/torchlie/global_params.py b/torchlie/torchlie/global_params.py
index eafc54856..00e7e8e9e 100644
--- a/torchlie/torchlie/global_params.py
+++ b/torchlie/torchlie/global_params.py
@@ -30,6 +30,8 @@ class _TorchLieGlobalParams:
     so3_quat_eps_float64: float = 0
     so3_hat_eps_float64: float = 0
     se3_hat_eps_float64: float = 0
+    _allow_passthrough_ops: bool = False
+    _faster_log_maps: bool = False
 
     def __init__(self):
         self.reset()
diff --git a/tutorials/05_differentiable_motion_planning.ipynb b/tutorials/05_differentiable_motion_planning.ipynb
index 763673adb..434977747 100644
--- a/tutorials/05_differentiable_motion_planning.ipynb
+++ b/tutorials/05_differentiable_motion_planning.ipynb
@@ -41,7 +41,7 @@
     "\n",
     "torch.set_default_dtype(torch.double)\n",
     "\n",
-    "device = \"cuda:0\" if torch.cuda.is_available else \"cpu\"\n",
+    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
     "torch.random.manual_seed(1)\n",
     "random.seed(1)\n",
     "np.random.seed(1)\n",