diff --git a/tests/labs/lie/functional/common.py b/tests/labs/lie/functional/common.py index 97bcf69c0..db4ef9b0e 100644 --- a/tests/labs/lie/functional/common.py +++ b/tests/labs/lie/functional/common.py @@ -47,6 +47,10 @@ def get_test_cfg(op_name, dtype, dim, data_shape, module=None): ((torch.randint(1, 5, ()).item(),) + data_shape), ]: all_input_types.append((("group", module), ("matrix", shape))) + if op_name == "normalize": + all_input_types.append((("matrix", data_shape),)) + if dtype == torch.float32: + atol = 2.5e-4 return all_input_types, atol diff --git a/tests/labs/lie/functional/test_se3.py b/tests/labs/lie/functional/test_se3.py index b5938542a..309b4c27f 100644 --- a/tests/labs/lie/functional/test_se3.py +++ b/tests/labs/lie/functional/test_se3.py @@ -33,6 +33,7 @@ "project", "left_act", "left_project", + "normalize", ], ) @pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST) diff --git a/tests/labs/lie/functional/test_so3.py b/tests/labs/lie/functional/test_so3.py index 995ef2717..2ea01eea6 100644 --- a/tests/labs/lie/functional/test_so3.py +++ b/tests/labs/lie/functional/test_so3.py @@ -34,6 +34,7 @@ "project", "left_act", "left_project", + "normalize", ], ) @pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST) diff --git a/theseus/labs/lie/functional/lie_group.py b/theseus/labs/lie/functional/lie_group.py index fbcd45fca..c08fd7cdd 100644 --- a/theseus/labs/lie/functional/lie_group.py +++ b/theseus/labs/lie/functional/lie_group.py @@ -2,12 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -import torch import abc +from typing import Any, Callable, List, Tuple, Optional, Protocol +import torch from .constants import DeviceType -from typing import Callable, List, Tuple, Optional, Protocol from .utils import check_jacobians_list # There are four functions associated with each Lie group operator xxx. @@ -45,7 +44,7 @@ class UnaryOperator(torch.autograd.Function): @classmethod @abc.abstractmethod - def _forward_impl(cls, tensor: torch.Tensor) -> torch.Tensor: + def _forward_impl(cls, tensor: torch.Tensor) -> Any: pass @classmethod @@ -217,6 +216,7 @@ def __init__(self, module): self.exp, self.jexp = UnaryOperatorFactory(module, "exp") self.log, self.jlog = UnaryOperatorFactory(module, "log") self.adj = UnaryOperatorFactory(module, "adjoint")[0] + self.normalize = UnaryOperatorFactory(module, "normalize")[0] self.inv, self.jinv = UnaryOperatorFactory(module, "inverse") self.hat = UnaryOperatorFactory(module, "hat")[0] self.vee = UnaryOperatorFactory(module, "vee")[0] diff --git a/theseus/labs/lie/functional/se3_impl.py b/theseus/labs/lie/functional/se3_impl.py index 631ee75ff..4374ece23 100644 --- a/theseus/labs/lie/functional/se3_impl.py +++ b/theseus/labs/lie/functional/se3_impl.py @@ -30,6 +30,16 @@ def _impl(t_: torch.Tensor): checks_base(tensor, _impl) +def check_matrix_tensor(tensor: torch.Tensor): + def _impl(t_): + if t_.ndim != 3 or t_.shape[-2:] != (3, 4): + raise ValueError( + f"SE3 data tensors can only be 3x4 matrices, but got shape {t_.shape}." 
+            )
+
+    checks_base(tensor, _impl)
+
+
 def check_transform_tensor(tensor: torch.Tensor):
     SO3.check_transform_tensor(tensor)

@@ -1065,4 +1075,51 @@

 left_project, jleft_project = lie_group.BinaryOperatorFactory(_module, "left_project")

+
+# -----------------------------------------------------------------------------
+# Normalize
+# -----------------------------------------------------------------------------
+def _normalize_impl(matrix: torch.Tensor) -> torch.Tensor:
+    check_matrix_tensor(matrix)
+    rotation = SO3._normalize_impl_helper(matrix[..., :, :3])[0]
+    translation = matrix[..., :, 3:]
+    return torch.cat((rotation, translation), dim=-1)
+
+
+class Normalize(lie_group.UnaryOperator):
+    @classmethod
+    def _forward_impl(cls, matrix):
+        check_matrix_tensor(matrix)
+        matrix: torch.Tensor = matrix
+        rotation, svd_info = SO3._normalize_impl_helper(matrix[..., :, :3])
+        translation = matrix[..., :, 3:]
+        output = torch.cat((rotation, translation), dim=-1)
+        return output, svd_info
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        # outputs is (normalized_out, svd_info)
+        svd_info = outputs[1]
+        ctx.save_for_backward(
+            svd_info["u"], svd_info["s"], svd_info["v"], svd_info["sign"]
+        )
+
+    @classmethod
+    def backward(cls, ctx, grad_output, _):
+        u, s, v, sign = ctx.saved_tensors
+        grad_input1 = SO3._normalize_backward_helper(
+            u, s, v, sign, grad_output[..., :, :3]
+        )
+        grad_input2 = grad_output[..., :, 3:]
+        grad_input = torch.cat((grad_input1, grad_input2), dim=-1)
+        return grad_input, None
+
+
+def _normalize_autograd_fn(matrix: torch.Tensor):
+    return Normalize.apply(matrix)[0]
+
+
+_jnormalize_autograd_fn = None
+
+
 _fns = lie_group.LieGroupFns(_module)
diff --git a/theseus/labs/lie/functional/so3_impl.py b/theseus/labs/lie/functional/so3_impl.py
index 84fd0910b..67d40e0b7 100644
--- a/theseus/labs/lie/functional/so3_impl.py
+++ b/theseus/labs/lie/functional/so3_impl.py
@@ -40,6 +40,14 @@ def _impl(t_):
     checks_base(tensor, _impl)

+def check_matrix_tensor(tensor: torch.Tensor):
+    def _impl(t_: torch.Tensor):
+        if t_.ndim != 3 or t_.shape[-2:] != (3, 3):
+            raise ValueError("Matrix tensors can only be 3x3 matrices.")
+
+    checks_base(tensor, _impl)
+
+
 def check_tangent_vector(tangent_vector: torch.Tensor):
     def _impl(t_: torch.Tensor):
         _check = t_.ndim == 3 and t_.shape[1:] == (3, 1)
@@ -975,4 +983,85 @@ def backward(cls, ctx, grad_output):
 _jleft_project_autograd_fn = _jleft_project_impl

+# -----------------------------------------------------------------------------
+# Normalize
+# -----------------------------------------------------------------------------
+def _normalize_impl_helper(matrix: torch.Tensor):
+    check_matrix_tensor(matrix)
+    u, s, v = torch.svd(matrix)
+    sign = torch.det(u @ v).view(-1, 1, 1)
+    vt = torch.cat(
+        (v[:, :, :2], torch.where(sign > 0, v[:, :, 2:], -v[:, :, 2:])), dim=-1
+    ).transpose(1, 2)
+    return u @ vt, {"u": u, "s": s, "v": v, "sign": sign}
+
+
+def _normalize_impl(matrix: torch.Tensor) -> torch.Tensor:
+    return _normalize_impl_helper(matrix)[0]
+
+
+def _normalize_backward_helper(
+    u: torch.Tensor,
+    s: torch.Tensor,
+    v: torch.Tensor,
+    sign: torch.Tensor,
+    grad_output: torch.Tensor,
+) -> torch.Tensor:
+    def _skew_symm(matrix: torch.Tensor) -> torch.Tensor:
+        return matrix - matrix.transpose(-1, -2)
+
+    ut = u.transpose(1, 2)
+    vt = v.transpose(1, 2)
+    grad_u: torch.Tensor = grad_output @ torch.cat(
+        (v[:, :, :2], v[:, :, 2:] @ sign), dim=-1
+    )
+    grad_v: torch.Tensor = grad_output.transpose(1, 2) @ torch.cat(
+        (u[:, :, :2], u[:, :, 2:] @ sign), dim=-1
+    )
+    s_squared: torch.Tensor = s.pow(2)
+    F = s_squared.view(-1, 1, 3).expand(-1, 3, 3) - s_squared.view(-1, 3, 1).expand(
+        -1, 3, 3
+    )
+    F = torch.where(F == 0, grad_output.new_ones(1) * torch.inf, F)
+    F = F.pow(-1)
+
+    u_term: torch.Tensor = u @ (F * _skew_symm(ut @ grad_u))
+    u_term = torch.einsum("n...ij, nj->n...ij", u_term, s)
+    u_term = u_term @ vt
+
+    v_term: torch.Tensor = (F * _skew_symm(vt @ grad_v)) @ vt
+    v_term = torch.einsum("ni, n...ij->n...ij", s, v_term)
+    v_term = u @ v_term
+
+    return u_term + v_term
+
+
+class Normalize(lie_group.UnaryOperator):
+    @classmethod
+    def _forward_impl(cls, matrix):
+        matrix: torch.Tensor = matrix
+        output, svd_info = _normalize_impl_helper(matrix)
+        return output, svd_info
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        # outputs is (normalized_out, svd_info)
+        svd_info = outputs[1]
+        ctx.save_for_backward(
+            svd_info["u"], svd_info["s"], svd_info["v"], svd_info["sign"]
+        )
+
+    @classmethod
+    def backward(cls, ctx, grad_output, _):
+        u, s, v, sign = ctx.saved_tensors
+        return _normalize_backward_helper(u, s, v, sign, grad_output), None
+
+
+def _normalize_autograd_fn(matrix: torch.Tensor):
+    return Normalize.apply(matrix)[0]
+
+
+_jnormalize_autograd_fn = None
+
+
 _fns = lie_group.LieGroupFns(_module)
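
Note on the math: `normalize` implements the standard SVD-based projection onto SO(3). For a matrix M with SVD M = U S V^T, the rotation nearest to M in the Frobenius norm is R = U diag(1, 1, det(U V^T)) V^T; flipping the last column of V when det(U V^T) < 0 is what keeps the result in SO(3) (determinant +1) rather than O(3). Below is a minimal standalone sketch of the forward projection, for reference only; it is not part of this diff, `project_to_so3` is a hypothetical name, and it uses `torch.linalg.svd` where the code above calls the older `torch.svd`:

import torch

def project_to_so3(matrix: torch.Tensor) -> torch.Tensor:
    # matrix: (batch, 3, 3). Returns the nearest rotation in Frobenius norm.
    u, _, vh = torch.linalg.svd(matrix)  # matrix = u @ diag(s) @ vh
    sign = torch.det(u @ vh).view(-1, 1, 1)  # det(U V^T), either +1 or -1
    # Flip the last row of V^T (i.e., the last column of V) when the
    # determinant is negative, so that det(result) == +1.
    vh = torch.cat(
        (vh[:, :2, :], torch.where(sign > 0, vh[:, 2:, :], -vh[:, 2:, :])), dim=-2
    )
    return u @ vh

# Sanity check: output is orthogonal with determinant +1.
m = torch.randn(4, 3, 3, dtype=torch.float64)
r = project_to_so3(m)
assert torch.allclose(
    r @ r.transpose(-1, -2),
    torch.eye(3, dtype=torch.float64).expand(4, -1, -1),
    atol=1e-9,
)
assert torch.allclose(torch.det(r), torch.ones(4, dtype=torch.float64))

The SE3 variant applies the same projection to the rotation block matrix[..., :, :3] and passes the translation column matrix[..., :, 3:] through unchanged, which is exactly what `_normalize_impl` in se3_impl.py does above.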