Skip to content

Commit

Permalink
Add theseus.geometry.functional.so3.inverse() (facebookresearch#374)
Browse files Browse the repository at this point in the history
* add SO3.inverse()

* add tests for SO3.inverse

* fixed the output bug for jinverse

* add a comment about op_autograd_fn

* add comments that _jxxx_impl might not exist for some operators
  • Loading branch information
fantaosha authored Nov 26, 2022
1 parent e4fa7aa commit f843071
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
10 changes: 10 additions & 0 deletions tests/geometry/functional/test_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,13 @@ def test_adjoint(batch_size: int, dtype: torch.dtype):
tangent_vector = torch.rand(batch_size, 3, dtype=dtype, generator=rng)
group = so3.exp(tangent_vector)
check_lie_group_function(so3, "adjoint", TEST_EPS, group)


@pytest.mark.parametrize("batch_size", [1, 20, 100])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_inverse(batch_size: int, dtype: torch.dtype):
rng = torch.Generator()
rng.manual_seed(0)
tangent_vector = torch.rand(batch_size, 3, dtype=dtype, generator=rng)
group = so3.exp(tangent_vector)
check_lie_group_function(so3, "inverse", TEST_EPS, group)
11 changes: 11 additions & 0 deletions theseus/geometry/functional/lie_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,20 @@
from .utils import check_jacobians_list

# There are four functions associated with each Lie group operator xxx.
# ----------------------------------------------------------------------------------
# _xxx_impl: analytic implementation of the operator, return xxx
# _jxxx_impl: analytic implementation of the operator jacobian, return jxxx and xxx
# _xxx_autograd_fn: a torch.autograd.Function wrapper of _xxx_impl
# _jxxx_autograd_fn: simply equivalent to _jxxx_impl for now
# ----------------------------------------------------------------------------------
# Note that _jxxx_impl might not exist for some operators.


def JInverseImplFactory(module):
def _jinverse_impl(group: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
return [-module._adjoint_autograd_fn(group)], module._inverse_autograd_fn(group)

return _jinverse_impl


class UnaryOperator(torch.autograd.Function):
Expand All @@ -24,6 +34,7 @@ def forward(cls, ctx, input):


def UnaryOperatorFactory(module, op_name):
# Get autograd.Function wrapper of op and its jacobian
op_autograd_fn = getattr(module, "_" + op_name + "_autograd_fn")
jop_autograd_fn = getattr(module, "_j" + op_name + "_autograd_fn")

Expand Down
34 changes: 32 additions & 2 deletions theseus/geometry/functional/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
DIM: int = 3


_module = get_module(__name__)


def check_group_tensor(tensor: torch.Tensor) -> bool:
with torch.no_grad():
if tensor.ndim != 3 or tensor.shape[1:] != (3, 3):
Expand Down Expand Up @@ -166,8 +169,6 @@ def backward(cls, ctx, grad_output):
return grad_input.view(-1, 3)


_module = get_module(__name__)

# TODO: Implement analytic backward for _jexp_impl
_exp_autograd_fn = Exp.apply
_jexp_autograd_fn = _jexp_impl
Expand Down Expand Up @@ -203,3 +204,32 @@ def backward(cls, ctx, grad_output):
_jadjoint_autograd_fn = None

adjoint = lie_group.UnaryOperatorFactory(_module, "adjoint")


# -----------------------------------------------------------------------------
# Inverse
# -----------------------------------------------------------------------------
def _inverse_impl(group: torch.Tensor) -> torch.Tensor:
if not check_group_tensor(group):
raise ValueError("Invalid data tensor for SO3.")
return group.transpose(1, 2)


_jinverse_impl = lie_group.JInverseImplFactory(_module)


class Inverse(lie_group.UnaryOperator):
@classmethod
def forward(cls, ctx, group):
group: torch.Tensor = cast(torch.Tensor, group)
return _inverse_impl(group)

@classmethod
def backward(cls, ctx, grad_output):
return grad_output.transpose(1, 2)


_inverse_autograd_fn = Inverse.apply
_jinverse_autograd_fn = _jinverse_impl

inverse, jinverse = lie_group.UnaryOperatorFactory(_module, "inverse")

0 comments on commit f843071

Please sign in to comment.