Refactor SO3 and SE3 to be consistent with functorch #266

Merged Sep 7, 2022 · 25 commits

Changes from 1 commit

Commits
6d75099
fixed some bugs in SO3.log_map
fantaosha Aug 8, 2022
71cc455
refactor SO3 to be consistent with functorch
fantaosha Aug 8, 2022
7bbf68d
fixed a bug in SO3._project_impl
fantaosha Aug 8, 2022
03a7b50
add more tests for SO3
fantaosha Aug 8, 2022
789d173
SE3 refactored to be consistent with functorch
fantaosha Aug 8, 2022
4d5e651
simplify SO3 and SE3 for functorch
fantaosha Aug 8, 2022
71b09f6
refactor so2 to be consistent with functorch
fantaosha Aug 9, 2022
90d0915
torch.zeros() -> tensor.new_zeros()
fantaosha Aug 11, 2022
4976540
simplify the code using new_zeros
fantaosha Aug 11, 2022
60bcdda
refactor se2
fantaosha Aug 11, 2022
ab4a240
refactor the projection map for SE3
fantaosha Aug 11, 2022
08ca208
fixed a bug in SO2._rotate_from_cos_sin
fantaosha Aug 11, 2022
6fae8bb
fixed a bug for functorch
fantaosha Aug 11, 2022
efa187a
refactor SO3.log_map_impl
fantaosha Aug 12, 2022
942fc71
refactor SO3 and remove functorch context for log_map_impl() and to_q…
fantaosha Aug 25, 2022
5417f50
refactor SE3._log_map_impl
fantaosha Aug 25, 2022
edec950
SO3 refactored
fantaosha Aug 31, 2022
32db4c9
functorch refactored
fantaosha Aug 31, 2022
1718ccc
add more warning info for functorch
fantaosha Aug 31, 2022
56a9331
fixed a bug in warnings message about tensor check for functorch
fantaosha Aug 31, 2022
65b3798
rename functorch context
fantaosha Sep 7, 2022
85495ab
rename lie_group_tensor to lie_group
fantaosha Sep 7, 2022
7911ff9
some changes are made
fantaosha Sep 7, 2022
f9d7407
rename lie_group_tensor_check to lie_group_check
fantaosha Sep 7, 2022
4770e88
fixed the logic bug
fantaosha Sep 7, 2022
SE3 refactored to be consistent with functorch
fantaosha committed Aug 8, 2022
commit 789d173b227e39b8bd9d6a83ae771b11cc9c2540
208 changes: 147 additions & 61 deletions theseus/geometry/se3.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Union, cast
import warnings

import torch

@@ -13,6 +14,7 @@
from .lie_group import LieGroup
from .point_types import Point3
from .so3 import SO3
from .functorch import _FunctorchContext


class SE3(LieGroup):
@@ -110,6 +112,10 @@ def __str__(self) -> str:

    def _adjoint_impl(self) -> torch.Tensor:
        ret = torch.zeros(self.shape[0], 6, 6).to(dtype=self.dtype, device=self.device)

        if _FunctorchContext.get_context():
            ret = ret * self[0, 0, 0]

        ret[:, :3, :3] = self[:, :3, :3]
        ret[:, 3:, 3:] = self[:, :3, :3]
        ret[:, :3, 3:] = SO3.hat(self[:, :3, 3]) @ self[:, :3, :3]
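
Note on the `ret = ret * self[0, 0, 0]` pattern: under functorch's vmap, a buffer freshly created by torch.zeros() is not batched, so the in-place slice writes that follow would try to store a batched value into an unbatched tensor, which vmap rejects. Multiplying by one element of the (possibly batched) input promotes the buffer first. A minimal standalone sketch, assuming functorch-style vmap (torch.func.vmap on PyTorch >= 2.0) and illustrative names, not the theseus API:

import torch
from torch.func import vmap  # `from functorch import vmap` on 2022-era PyTorch

def adjoint_like(g: torch.Tensor) -> torch.Tensor:
    ret = torch.zeros(6, 6, dtype=g.dtype)
    ret = ret * g[0, 0]  # tie the buffer to the batched input
    ret[:3, :3] = g[:3, :3]  # per-sample writes are now legal under vmap
    ret[3:, 3:] = g[:3, :3]
    return ret

g = torch.randn(4, 3, 4)  # a batch of 3x4 [R | t] tensors
print(vmap(adjoint_like)(g).shape)  # torch.Size([4, 6, 6])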
@@ -126,6 +132,9 @@ def _project_impl(
            device=self.device,
        )

        if _FunctorchContext.get_context():
            ret = ret * self.tensor.view(-1)[0] * euclidean_grad.view(-1)[0]

        if is_sparse:
            temp = torch.einsum(
                "i...jk,i...jl->i...lk", euclidean_grad, self.tensor[:, :, :3]
Expand Down Expand Up @@ -176,15 +185,22 @@ def _hat_matrix_check(matrix: torch.Tensor):
if matrix.ndim != 3 or matrix.shape[1:] != (4, 4):
raise ValueError("Hat matrices of SE3 can only be 4x4 matrices")

if matrix[:, 3].abs().max().item() > HAT_EPS:
raise ValueError("The last row of hat matrices of SE3 can only be zero.")

if (
matrix[:, :3, :3].transpose(1, 2) + matrix[:, :3, :3]
).abs().max().item() > HAT_EPS:
raise ValueError(
"The 3x3 top-left corner of hat matrices of SE3 can only be skew-symmetric."
if _FunctorchContext.get_context():
warnings.warn(
"functorch is enabled and the skew-symmetry of hat matrices are not checked."
mhmukadam marked this conversation as resolved.
Show resolved Hide resolved
)
else:
if matrix[:, 3].abs().max().item() > HAT_EPS:
raise ValueError(
"The last row of hat matrices of SE3 can only be zero."
)

if (
matrix[:, :3, :3].transpose(1, 2) + matrix[:, :3, :3]
).abs().max().item() > HAT_EPS:
raise ValueError(
"The 3x3 top-left corner of hat matrices of SE3 can only be skew-symmetric."
)

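Note on the warning branch above: the strict validation branches on tensor data via .item(), which cannot be evaluated on a vmap-batched tensor, so under functorch the check degrades to a warning. A small sketch of the failure mode, with illustrative names rather than the theseus API:

import torch
from torch.func import vmap

HAT_EPS = 1e-5

def check_last_row(matrix: torch.Tensor) -> torch.Tensor:
    # Data-dependent Python branch: needs a concrete value from the tensor.
    if matrix[3].abs().max().item() > HAT_EPS:
        raise ValueError("last row must be zero")
    return matrix

batch = torch.zeros(2, 4, 4)
check_last_row(batch[0])  # fine in eager mode
try:
    vmap(check_last_row)(batch)
except RuntimeError as err:  # vmap cannot call .item() on a batched tensor
    print(err)
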
    @staticmethod
    def exp_map(
@@ -349,46 +365,102 @@ def _log_map_impl(
        self, jacobians: Optional[List[torch.Tensor]] = None
    ) -> torch.Tensor:

        sine_axis = torch.zeros(self.shape[0], 3, dtype=self.dtype, device=self.device)
        sine_axis[:, 0] = 0.5 * (self[:, 2, 1] - self[:, 1, 2])
        sine_axis[:, 1] = 0.5 * (self[:, 0, 2] - self[:, 2, 0])
        sine_axis[:, 2] = 0.5 * (self[:, 1, 0] - self[:, 0, 1])
        cosine = 0.5 * (self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2] - 1)
        sine = sine_axis.norm(dim=1)
        theta = torch.atan2(sine, cosine)
        theta2 = theta**2
        non_zero = torch.ones(1, dtype=self.dtype, device=self.device)

        near_zero = theta < self._NEAR_ZERO_EPS

        # Compute the rotation
        not_near_pi = 1 + cosine > self._NEAR_PI_EPS
        # theta is not near pi
        near_zero_not_near_pi = near_zero[not_near_pi]
        # Compute the approximation of theta / sin(theta) when theta is near 0
        sine_nz = torch.where(near_zero_not_near_pi, non_zero, sine[not_near_pi])
        scale = torch.where(
            near_zero_not_near_pi,
            1 + sine[not_near_pi] ** 2 / 6,
            theta[not_near_pi] / sine_nz,
        )
        ret_ang = torch.zeros_like(sine_axis)
        ret_ang[not_near_pi] = sine_axis[not_near_pi] * scale.view(-1, 1)

        # theta is near pi
        near_pi = ~not_near_pi
        ddiag = torch.diagonal(self[near_pi], dim1=1, dim2=2)
        # Find the index of major columns and diagonals
        major = torch.logical_and(
            ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
        ) + 2 * torch.logical_and(ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1])
        sel_rows = 0.5 * (self[near_pi, major, :3] + self[near_pi, :3, major])
        aux = torch.ones(sel_rows.shape[0], dtype=torch.bool)
        sel_rows[aux, major] -= cosine[near_pi]
        axis = sel_rows / sel_rows.norm(dim=1, keepdim=True)
        sign_tmp = sine_axis[near_pi, major].sign()
        sign = torch.where(sign_tmp != 0, sign_tmp, torch.ones_like(sign_tmp))
        ret_ang[near_pi] = axis * (theta[near_pi] * sign).view(-1, 1)
        if _FunctorchContext.get_context():
            sine_axis = 0.5 * torch.cat(
                [
                    (self[:, 2, 1] - self[:, 1, 2]).view(-1, 1),
                    (self[:, 0, 2] - self[:, 2, 0]).view(-1, 1),
                    (self[:, 1, 0] - self[:, 0, 1]).view(-1, 1),
                ],
                dim=1,
            )
            cosine = 0.5 * (self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2] - 1)
            sine = sine_axis.norm(dim=1)
            theta = torch.atan2(sine, cosine)
            theta2 = theta**2
            non_zero = torch.ones(1, dtype=self.dtype, device=self.device)

            near_zero = theta < self._NEAR_ZERO_EPS

            # Compute the rotation
            near_pi = 1 + cosine <= self._NEAR_PI_EPS
            near_zero_or_near_pi = torch.logical_or(near_zero, near_pi)
            # Compute the approximation of theta / sin(theta) when theta is near 0
            sine_nz = torch.where(near_zero_or_near_pi, non_zero, sine)
            scale = torch.where(
                near_zero_or_near_pi,
                1 + sine**2 / 6,
                theta / sine_nz,
            )
            ret_ang = sine_axis * scale.view(-1, 1)

            # theta is near pi
            ddiag = torch.diagonal(self.tensor, dim1=1, dim2=2)
            # Find the index of major columns and diagonals
            major = torch.logical_and(
                ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
            ) + 2 * torch.logical_and(
                ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
            )
            aux = torch.ones(self.shape[0], dtype=torch.bool)
            sel_rows = 0.5 * (self[aux, major, :3] + self[aux, :3, major])
            sel_rows[aux, major] -= cosine
            axis = sel_rows / torch.where(
                near_zero.view(-1, 1),
                non_zero.view(-1, 1),
                sel_rows.norm(dim=1, keepdim=True),
            )
            sign_tmp = sine_axis[aux, major].sign()
            sign = torch.where(sign_tmp != 0, sign_tmp, torch.ones_like(sign_tmp))
            ret_ang = torch.where(
                near_pi.view(-1, 1), axis * (theta * sign).view(-1, 1), ret_ang
            )
        else:
            sine_axis = torch.zeros(
                self.shape[0], 3, dtype=self.dtype, device=self.device
            )
            sine_axis[:, 0] = 0.5 * (self[:, 2, 1] - self[:, 1, 2])
            sine_axis[:, 1] = 0.5 * (self[:, 0, 2] - self[:, 2, 0])
            sine_axis[:, 2] = 0.5 * (self[:, 1, 0] - self[:, 0, 1])
            cosine = 0.5 * (self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2] - 1)
            sine = sine_axis.norm(dim=1)
            theta = torch.atan2(sine, cosine)
            theta2 = theta**2
            non_zero = torch.ones(1, dtype=self.dtype, device=self.device)

            near_zero = theta < self._NEAR_ZERO_EPS

            # Compute the rotation
            not_near_pi = 1 + cosine > self._NEAR_PI_EPS
            # theta is not near pi
            near_zero_not_near_pi = near_zero[not_near_pi]
            # Compute the approximation of theta / sin(theta) when theta is near 0
            sine_nz = torch.where(near_zero_not_near_pi, non_zero, sine[not_near_pi])
            scale = torch.where(
                near_zero_not_near_pi,
                1 + sine[not_near_pi] ** 2 / 6,
                theta[not_near_pi] / sine_nz,
            )
            ret_ang = torch.zeros_like(sine_axis)
            ret_ang[not_near_pi] = sine_axis[not_near_pi] * scale.view(-1, 1)

            # theta is near pi
            near_pi = ~not_near_pi
            ddiag = torch.diagonal(self[near_pi], dim1=1, dim2=2)
            # Find the index of major columns and diagonals
            major = torch.logical_and(
                ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
            ) + 2 * torch.logical_and(
                ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
            )
            sel_rows = 0.5 * (self[near_pi, major, :3] + self[near_pi, :3, major])
            aux = torch.ones(sel_rows.shape[0], dtype=torch.bool)
            sel_rows[aux, major] -= cosine[near_pi]
            axis = sel_rows / sel_rows.norm(dim=1, keepdim=True)
            sign_tmp = sine_axis[near_pi, major].sign()
            sign = torch.where(sign_tmp != 0, sign_tmp, torch.ones_like(sign_tmp))
            ret_ang[near_pi] = axis * (theta[near_pi] * sign).view(-1, 1)

        # Compute the translation
        sine_theta = sine * theta
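
Note on the rewrite above: the eager path selects subsets with boolean masks (e.g. sine[not_near_pi]), which yields data-dependent shapes that vmap cannot batch. The functorch path therefore evaluates both cases at full batch size and selects with torch.where, guarding the denominator before dividing. A standalone sketch of the `scale` term, with an illustrative epsilon:

import torch

NEAR_ZERO_EPS = 5e-3  # illustrative threshold, not the theseus constant

def theta_over_sine(theta: torch.Tensor, sine: torch.Tensor) -> torch.Tensor:
    near_zero = theta < NEAR_ZERO_EPS
    # Replace near-zero denominators so theta / sine is always finite, then
    # select the Taylor approximation 1 + sin(theta)^2 / 6 near theta = 0.
    sine_nz = torch.where(near_zero, torch.ones_like(sine), sine)
    return torch.where(near_zero, 1 + sine**2 / 6, theta / sine_nz)

theta = torch.tensor([0.0, 1e-4, 1.0])
print(theta_over_sine(theta, theta.sin()))  # ~[1.0000, 1.0000, 1.1884]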
@@ -473,23 +545,31 @@ def _compose_impl(self, se3_2: LieGroup) -> "SE3":
        se3_2 = cast(SE3, se3_2)
        batch_size = max(self.shape[0], se3_2.shape[0])
        ret = SE3()
        ret.tensor = torch.zeros(batch_size, 3, 4, dtype=self.dtype, device=self.device)
        ret[:, :, :3] = self[:, :, :3] @ se3_2[:, :, :3]
        ret[:, :, 3] = self[:, :, 3]
        if _FunctorchContext.get_context():
            ret.tensor = torch.cat(
                [self[:, :, :3] @ se3_2[:, :, :3], self[:, :, 3].view(-1, 3, 1)], dim=2
            )
        else:
            ret.tensor = torch.zeros(
                batch_size, 3, 4, dtype=self.dtype, device=self.device
            )
            ret[:, :, :3] = self[:, :, :3] @ se3_2[:, :, :3]
            ret[:, :, 3] = self[:, :, 3]
        ret[:, :, 3:] += self[:, :, :3] @ se3_2[:, :, 3:]

        return ret
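
Note: in the functorch branch the composition is assembled entirely out of place as [R1 @ R2 | t1], after which the shared `+=` line adds R1 @ t2, giving [R1 @ R2 | t1 + R1 @ t2]. A minimal sketch of the same result, assuming the (N, 3, 4) [R | t] layout used in this file:

import torch

def se3_compose(g1: torch.Tensor, g2: torch.Tensor) -> torch.Tensor:
    rot = g1[:, :, :3] @ g2[:, :, :3]
    t = g1[:, :, 3:] + g1[:, :, :3] @ g2[:, :, 3:]
    return torch.cat([rot, t], dim=2)

g = torch.eye(4)[:3].unsqueeze(0)  # identity element, shape (1, 3, 4)
assert torch.allclose(se3_compose(g, g), g)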

    def _inverse_impl(self, get_jacobian: bool = False) -> "SE3":
        ret = torch.zeros(self.shape[0], 3, 4).to(dtype=self.dtype, device=self.device)
        rotT = self.tensor[:, :3, :3].transpose(1, 2)
        ret[:, :, :3] = rotT
        ret[:, :, 3] = -(rotT @ self.tensor[:, :3, 3].unsqueeze(2)).view(-1, 3)
        ret_rot = self.tensor[:, :3, :3].transpose(1, 2)
        ret_t = -(ret_rot @ self.tensor[:, :3, 3].unsqueeze(2)).view(-1, 3, 1)
        ret = torch.cat([ret_rot, ret_t], dim=2)
        # if self.tensor is a valid SE3, so is the inverse
        return SE3(tensor=ret, strict=False)

    def to_matrix(self) -> torch.Tensor:
        ret = torch.zeros(self.shape[0], 4, 4).to(dtype=self.dtype, device=self.device)
        if _FunctorchContext.get_context():
            ret = ret * self[0, 0, 0]
        ret[:, :3] = self.tensor
        ret[:, 3, 3] = 1
        return ret
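
Note: the refactored _inverse_impl needs no _FunctorchContext branch at all, because [R | t]^-1 = [R^T | -R^T t] can be built purely from out-of-place ops. A standalone sketch in the same (N, 3, 4) layout:

import torch

def se3_inverse(g: torch.Tensor) -> torch.Tensor:
    rotT = g[:, :3, :3].transpose(1, 2)  # R^T
    t = -(rotT @ g[:, :3, 3].unsqueeze(2))  # -R^T t, shape (N, 3, 1)
    return torch.cat([rotT, t], dim=2)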
@@ -516,13 +596,19 @@ def hat(tangent_vector: torch.Tensor) -> torch.Tensor:
        _check = tangent_vector.ndim == 2 and tangent_vector.shape[1] == 6
        if not _check:
            raise ValueError("Invalid vee matrix for SE3.")
        matrix = torch.zeros(tangent_vector.shape[0], 4, 4).to(
            dtype=tangent_vector.dtype, device=tangent_vector.device
        )
        matrix[:, :3, :3] = SO3.hat(tangent_vector[:, 3:])
        matrix[:, :3, 3] = tangent_vector[:, :3]

        return matrix
        matrix = torch.cat(
            [SO3.hat(tangent_vector[:, 3:]), tangent_vector[:, :3].view(-1, 3, 1)],
            dim=2,
        )
        zeros = torch.zeros(
            tangent_vector.shape[0],
            1,
            4,
            dtype=tangent_vector.dtype,
            device=tangent_vector.device,
        )
        return torch.cat([matrix, zeros], dim=1)

    @staticmethod
    def vee(matrix: torch.Tensor) -> torch.Tensor:
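
Note: SE3.hat likewise switches from filling a preallocated 4x4 buffer to concatenating the [SO3.hat(w) | v] block over a zero bottom row. A standalone sketch, with so3_hat as an illustrative stand-in for theseus's SO3.hat:

import torch

def so3_hat(w: torch.Tensor) -> torch.Tensor:  # w: (N, 3) rotation part
    z = torch.zeros_like(w[:, :1])
    rows = torch.cat(
        [z, -w[:, 2:3], w[:, 1:2],
         w[:, 2:3], z, -w[:, 0:1],
         -w[:, 1:2], w[:, 0:1], z],
        dim=1,
    )
    return rows.view(-1, 3, 3)

def se3_hat(tangent_vector: torch.Tensor) -> torch.Tensor:  # (N, 6) as (v, w)
    top = torch.cat(
        [so3_hat(tangent_vector[:, 3:]), tangent_vector[:, :3].view(-1, 3, 1)],
        dim=2,
    )
    bottom = tangent_vector.new_zeros(tangent_vector.shape[0], 1, 4)
    return torch.cat([top, bottom], dim=1)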