Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SO3 and SE3 to be consistent with functorch #266

Merged
merged 25 commits into from
Sep 7, 2022
Merged
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6d75099
fixed some bugs in SO3.log_map
fantaosha Aug 8, 2022
71cc455
refactor SO3 to be consistent with functorch
fantaosha Aug 8, 2022
7bbf68d
fixed a bug in SO3._project_impl
fantaosha Aug 8, 2022
03a7b50
add more tests for SO3
fantaosha Aug 8, 2022
789d173
SE3 refactored to be consistent with functorch
fantaosha Aug 8, 2022
4d5e651
simplify SO3 and SE3 for functorch
fantaosha Aug 8, 2022
71b09f6
refactor so2 to be consistent with functorch
fantaosha Aug 9, 2022
90d0915
torch.zeros() -> tensor.new_zeros()
fantaosha Aug 11, 2022
4976540
simplify the code using new_zeros
fantaosha Aug 11, 2022
60bcdda
refactor se2
fantaosha Aug 11, 2022
ab4a240
refactor the projection map for SE3
fantaosha Aug 11, 2022
08ca208
fixed a bug in SO2._rotate_from_cos_sin
fantaosha Aug 11, 2022
6fae8bb
fixed a bug for functorch
fantaosha Aug 11, 2022
efa187a
refactor SO3.log_map_impl
fantaosha Aug 12, 2022
942fc71
refactor SO3 and remove functorch context for log_map_impl() and to_q…
fantaosha Aug 25, 2022
5417f50
refactor SE3._log_map_impl
fantaosha Aug 25, 2022
edec950
SO3 refactored
fantaosha Aug 31, 2022
32db4c9
functorch refactored
fantaosha Aug 31, 2022
1718ccc
add more warning info for functorch
fantaosha Aug 31, 2022
56a9331
fixed a bug in warnings message about tensor check for functorch
fantaosha Aug 31, 2022
65b3798
rename functorch context
fantaosha Sep 7, 2022
85495ab
rename lie_group_tensor to lie_group
fantaosha Sep 7, 2022
7911ff9
some changes are made
fantaosha Sep 7, 2022
f9d7407
rename lie_group_tensor_check to lie_group_check
fantaosha Sep 7, 2022
4770e88
fixed the logic bug
fantaosha Sep 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor SO3 and remove functorch context for log_map_impl() and to_q…
…uaternion()
  • Loading branch information
fantaosha committed Aug 25, 2022
commit 942fc711a535298c24cdb9a2ae60545a9c2889f7
302 changes: 213 additions & 89 deletions theseus/geometry/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,44 +281,23 @@ def _log_map_impl(
)
ret = sine_axis * scale.view(-1, 1)

if _FunctorchContext.get_context():
# # theta ~ pi
ddiag = torch.diagonal(self.tensor, dim1=1, dim2=2)
# Find the index of major coloumns and diagonals
major = torch.logical_and(
ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
) + 2 * torch.logical_and(
ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
)
aux = torch.ones(self.shape[0], dtype=torch.bool)
sel_rows = 0.5 * (self[aux, major] + self[aux, :, major])
sel_rows[aux, major] -= cosine
axis = sel_rows / torch.where(
near_zero,
non_zero,
sel_rows.norm(dim=1),
).view(-1, 1)
sign_tmp = sine_axis[aux, major].sign()
sign = torch.where(sign_tmp != 0, sign_tmp, torch.ones_like(sign_tmp))
ret = torch.where(
near_pi.view(-1, 1), axis * (theta * sign).view(-1, 1), ret
)
else:
if near_pi.any():
ddiag = torch.diagonal(self[near_pi], dim1=1, dim2=2)
# Find the index of major coloumns and diagonals
major = torch.logical_and(
ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
) + 2 * torch.logical_and(
ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
)
sel_rows = 0.5 * (self[near_pi, major] + self[near_pi, :, major])
aux = torch.ones(sel_rows.shape[0], dtype=torch.bool)
sel_rows[aux, major] -= cosine[near_pi]
axis = sel_rows / sel_rows.norm(dim=1, keepdim=True)
sign_tmp = sine_axis[near_pi, major].sign()
sign = torch.where(sign_tmp != 0, sign_tmp, torch.ones_like(sign_tmp))
ret[near_pi] = axis * (theta[near_pi] * sign).view(-1, 1)
# # theta ~ pi
ddiag = torch.diagonal(self.tensor, dim1=1, dim2=2)
# Find the index of major coloumns and diagonals
major = torch.logical_and(
ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
) + 2 * torch.logical_and(ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1])
aux = torch.ones(self.shape[0], dtype=torch.bool)
sel_rows = 0.5 * (self[aux, major] + self[aux, :, major])
sel_rows[aux, major] -= cosine
axis = sel_rows / torch.where(
near_zero,
non_zero,
sel_rows.norm(dim=1),
).view(-1, 1)
sign_tmp = sine_axis[aux, major].sign()
sign = torch.where(sign_tmp != 0, sign_tmp, torch.ones_like(sign_tmp))
ret = torch.where(near_pi.view(-1, 1), axis * (theta * sign).view(-1, 1), ret)

if jacobians is not None:
SO3._check_jacobians_list(jacobians)
Expand Down Expand Up @@ -388,58 +367,29 @@ def to_quaternion(self) -> torch.Tensor:
ret[:, 0] = w
ret[:, 1:] = 0.5 * sine_axis / torch.where(near_pi, non_zero, w).view(-1, 1)

if _FunctorchContext.get_context():
# theta ~ pi
ddiag = torch.diagonal(self.tensor, dim1=1, dim2=2)
# Find the index of major coloumns and diagonals
major = torch.logical_and(
ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
) + 2 * torch.logical_and(
ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
)
aux = torch.ones(self.shape[0], dtype=torch.bool)
sel_rows = 0.5 * (self[aux, major] + self[aux, :, major])
cosine_near_pi = 0.5 * (self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2] - 1)
sel_rows[aux, major] -= cosine_near_pi
axis = (
sel_rows
/ torch.where(
near_zero.view(-1, 1),
non_zero.view(-1, 1),
sel_rows.norm(dim=1, keepdim=True),
)
* sine_axis[aux, major].sign().view(-1, 1)
)
sine_half_theta = (
(0.5 * (1 - cosine_near_pi)).clamp(0, 1).sqrt().view(-1, 1)
)
ret[:, 1:] = torch.where(
near_pi.view(-1, 1), axis * sine_half_theta, ret[:, 1:]
)
else:
# theta ~ pi
ddiag = torch.diagonal(self[near_pi], dim1=1, dim2=2)
# Find the index of major coloumns and diagonals
major = torch.logical_and(
ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
) + 2 * torch.logical_and(
ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
# theta ~ pi
ddiag = torch.diagonal(self.tensor, dim1=1, dim2=2)
# Find the index of major coloumns and diagonals
major = torch.logical_and(
ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
) + 2 * torch.logical_and(ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1])
aux = torch.ones(self.shape[0], dtype=torch.bool)
sel_rows = 0.5 * (self[aux, major] + self[aux, :, major])
cosine_near_pi = 0.5 * (self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2] - 1)
sel_rows[aux, major] -= cosine_near_pi
axis = (
sel_rows
/ torch.where(
near_zero.view(-1, 1),
non_zero.view(-1, 1),
sel_rows.norm(dim=1, keepdim=True),
)
sel_rows = 0.5 * (self[near_pi, major] + self[near_pi, :, major])
aux = torch.ones(sel_rows.shape[0], dtype=torch.bool)
cosine_near_pi = 0.5 * (
self[near_pi, 0, 0] + self[near_pi, 1, 1] + self[near_pi, 2, 2] - 1
)
sel_rows[aux, major] -= cosine_near_pi
axis = (
sel_rows
/ sel_rows.norm(dim=1, keepdim=True)
* sine_axis[near_pi, major].sign().view(-1, 1)
)
sine_half_theta = (
(0.5 * (1 - cosine_near_pi)).clamp(0, 1).sqrt().view(-1, 1)
)
ret[near_pi, 1:] = axis * sine_half_theta
* sine_axis[aux, major].sign().view(-1, 1)
)
sine_half_theta = (0.5 * (1 - cosine_near_pi)).clamp(0, 1).sqrt().view(-1, 1)
ret[:, 1:] = torch.where(
near_pi.view(-1, 1), axis * sine_half_theta, ret[:, 1:]
)

return ret

Expand Down Expand Up @@ -579,6 +529,180 @@ def unrotate(

return ret

def _deprecated_log_map_impl(
    self, jacobians: Optional[List[torch.Tensor]] = None
) -> torch.Tensor:
    """Deprecated log map for SO3 (kept for reference after the functorch refactor).

    Converts the batch of rotation matrices in ``self.tensor`` to axis-angle
    tangent vectors of shape (batch, 3), with dedicated numerics for the two
    singular regimes theta ~ 0 and theta ~ pi.

    Args:
        jacobians: optional list; if given, the (batch, 3, 3) Jacobian of the
            log map is appended to it.

    Returns:
        A (batch, 3) tensor of axis-angle vectors.
    """
    # NOTE(review): two GitHub review-thread artifact lines that had been
    # scraped into the middle of this signature were removed; the code logic
    # is otherwise unchanged.
    # sine_axis = sin(theta) * axis, extracted from the skew-symmetric part
    # of the rotation matrix.
    sine_axis = self.tensor.new_zeros(self.shape[0], 3)
    sine_axis[:, 0] = 0.5 * (self[:, 2, 1] - self[:, 1, 2])
    sine_axis[:, 1] = 0.5 * (self[:, 0, 2] - self[:, 2, 0])
    sine_axis[:, 2] = 0.5 * (self[:, 1, 0] - self[:, 0, 1])
    # cos(theta) from the trace identity tr(R) = 1 + 2 cos(theta).
    cosine = 0.5 * (self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2] - 1)
    sine = sine_axis.norm(dim=1)
    theta = torch.atan2(sine, cosine)

    near_zero = theta < self._NEAR_ZERO_EPS

    near_pi = 1 + cosine <= self._NEAR_PI_EPS
    # theta != pi
    near_zero_or_near_pi = torch.logical_or(near_zero, near_pi)
    # Compute the approximation of theta / sin(theta) when theta is near to 0
    # (Taylor expansion 1 + theta^2 / 6, using sine ~ theta there); sine_nz
    # guards the division against sin(theta) ~ 0.
    non_zero = torch.ones(1, dtype=self.dtype, device=self.device)
    sine_nz = torch.where(near_zero_or_near_pi, non_zero, sine)
    scale = torch.where(
        near_zero_or_near_pi,
        1 + sine**2 / 6,
        theta / sine_nz,
    )
    ret = sine_axis * scale.view(-1, 1)

    if _FunctorchContext.get_context():
        # theta ~ pi, branch-free variant: under functorch/vmap,
        # data-dependent control flow (the near_pi.any() / boolean-mask
        # indexing in the else branch) is not supported, so the whole batch
        # is processed and torch.where selects the near-pi result.
        ddiag = torch.diagonal(self.tensor, dim1=1, dim2=2)
        # Find the index of major columns and diagonals
        major = torch.logical_and(
            ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
        ) + 2 * torch.logical_and(
            ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
        )
        aux = torch.ones(self.shape[0], dtype=torch.bool)
        sel_rows = 0.5 * (self[aux, major] + self[aux, :, major])
        sel_rows[aux, major] -= cosine
        # Normalize to a unit axis; the where() guards against a zero norm
        # on rows that are actually in the near-zero regime.
        axis = sel_rows / torch.where(
            near_zero,
            non_zero,
            sel_rows.norm(dim=1),
        ).view(-1, 1)
        # sign() of 0 is 0, which would zero out the result; substitute 1.
        sign_tmp = sine_axis[aux, major].sign()
        sign = torch.where(sign_tmp != 0, sign_tmp, torch.ones_like(sign_tmp))
        ret = torch.where(
            near_pi.view(-1, 1), axis * (theta * sign).view(-1, 1), ret
        )
    else:
        if near_pi.any():
            # theta ~ pi: recover the axis from R + R^T, using only the rows
            # that actually hit the near-pi regime (cheaper than the
            # branch-free variant above).
            ddiag = torch.diagonal(self[near_pi], dim1=1, dim2=2)
            # Find the index of major columns and diagonals
            major = torch.logical_and(
                ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
            ) + 2 * torch.logical_and(
                ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
            )
            sel_rows = 0.5 * (self[near_pi, major] + self[near_pi, :, major])
            aux = torch.ones(sel_rows.shape[0], dtype=torch.bool)
            sel_rows[aux, major] -= cosine[near_pi]
            axis = sel_rows / sel_rows.norm(dim=1, keepdim=True)
            sign_tmp = sine_axis[near_pi, major].sign()
            sign = torch.where(sign_tmp != 0, sign_tmp, torch.ones_like(sign_tmp))
            ret[near_pi] = axis * (theta[near_pi] * sign).view(-1, 1)

    if jacobians is not None:
        SO3._check_jacobians_list(jacobians)

        theta2 = theta**2
        sine_theta = sine * theta
        two_cosine_minus_two = 2 * cosine - 2
        # Guard the denominators (both -> 0 as theta -> 0) before dividing;
        # the near-zero rows use the Taylor series instead.
        two_cosine_minus_two_nz = torch.where(
            near_zero, non_zero, two_cosine_minus_two
        )
        theta2_nz = torch.where(near_zero, non_zero, theta2)

        a = torch.where(
            near_zero, 1 - theta2 / 12, -sine_theta / two_cosine_minus_two_nz
        )
        b = torch.where(
            near_zero,
            1.0 / 12 + theta2 / 720,
            (sine_theta + two_cosine_minus_two)
            / (theta2_nz * two_cosine_minus_two_nz),
        )

        # jac = a * I + b * ret ret^T - 0.5 * hat(ret), assembled in place.
        jac = (b.view(-1, 1) * ret).view(-1, 3, 1) * ret.view(-1, 1, 3)

        half_ret = 0.5 * ret
        jac[:, 0, 1] -= half_ret[:, 2]
        jac[:, 1, 0] += half_ret[:, 2]
        jac[:, 0, 2] += half_ret[:, 1]
        jac[:, 2, 0] -= half_ret[:, 1]
        jac[:, 1, 2] -= half_ret[:, 0]
        jac[:, 2, 1] += half_ret[:, 0]

        diag_jac = torch.diagonal(jac, dim1=1, dim2=2)
        diag_jac += a.view(-1, 1)

        jacobians.append(jac)

    return ret

def _deprecated_to_quaternion(self) -> torch.Tensor:
    """Deprecated quaternion conversion for SO3 (kept after the functorch refactor).

    Converts the batch of rotation matrices in ``self.tensor`` to unit
    quaternions of shape (batch, 4) in (w, x, y, z) order, with dedicated
    numerics for theta ~ pi where w ~ 0.

    Returns:
        A (batch, 4) tensor of quaternions.
    """
    # NOTE(review): two GitHub review-thread artifact lines that had been
    # scraped into the middle of this signature were removed; the code logic
    # is otherwise unchanged.
    # sine_axis = sin(theta) * axis, extracted from the skew-symmetric part.
    sine_axis = self.tensor.new_zeros(self.shape[0], 3)
    sine_axis[:, 0] = 0.5 * (self[:, 2, 1] - self[:, 1, 2])
    sine_axis[:, 1] = 0.5 * (self[:, 0, 2] - self[:, 2, 0])
    sine_axis[:, 2] = 0.5 * (self[:, 1, 0] - self[:, 0, 1])
    # w = cos(theta / 2), from tr(R) = 1 + 2 cos(theta); clamp guards the
    # sqrt against tiny negative values from round-off.
    w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(0, 4).sqrt()

    near_zero = w > 1 - self._NEAR_ZERO_EPS
    near_pi = w <= self._NEAR_PI_EPS
    non_zero = self.tensor.new_ones([1])

    ret = self.tensor.new_zeros(self.shape[0], 4)
    # theta != pi: (x, y, z) = sin(theta) * axis / (2w); the where() guards
    # the division where w ~ 0 (those rows are overwritten below).
    ret[:, 0] = w
    ret[:, 1:] = 0.5 * sine_axis / torch.where(near_pi, non_zero, w).view(-1, 1)

    if _FunctorchContext.get_context():
        # theta ~ pi, branch-free variant: under functorch/vmap the
        # boolean-mask indexing in the else branch is not supported, so the
        # whole batch is processed and torch.where selects the near-pi rows.
        ddiag = torch.diagonal(self.tensor, dim1=1, dim2=2)
        # Find the index of major columns and diagonals
        major = torch.logical_and(
            ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
        ) + 2 * torch.logical_and(
            ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
        )
        aux = torch.ones(self.shape[0], dtype=torch.bool)
        sel_rows = 0.5 * (self[aux, major] + self[aux, :, major])
        cosine_near_pi = 0.5 * (self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2] - 1)
        sel_rows[aux, major] -= cosine_near_pi
        # Unit axis with a guarded norm; orient it by the sign of the
        # corresponding sine_axis component.
        axis = (
            sel_rows
            / torch.where(
                near_zero.view(-1, 1),
                non_zero.view(-1, 1),
                sel_rows.norm(dim=1, keepdim=True),
            )
            * sine_axis[aux, major].sign().view(-1, 1)
        )
        # sin(theta / 2) via the half-angle identity; clamp guards round-off.
        sine_half_theta = (
            (0.5 * (1 - cosine_near_pi)).clamp(0, 1).sqrt().view(-1, 1)
        )
        ret[:, 1:] = torch.where(
            near_pi.view(-1, 1), axis * sine_half_theta, ret[:, 1:]
        )
    else:
        # theta ~ pi: recover the axis from R + R^T, restricted to the rows
        # in the near-pi regime.
        ddiag = torch.diagonal(self[near_pi], dim1=1, dim2=2)
        # Find the index of major columns and diagonals
        major = torch.logical_and(
            ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
        ) + 2 * torch.logical_and(
            ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1]
        )
        sel_rows = 0.5 * (self[near_pi, major] + self[near_pi, :, major])
        aux = torch.ones(sel_rows.shape[0], dtype=torch.bool)
        cosine_near_pi = 0.5 * (
            self[near_pi, 0, 0] + self[near_pi, 1, 1] + self[near_pi, 2, 2] - 1
        )
        sel_rows[aux, major] -= cosine_near_pi
        axis = (
            sel_rows
            / sel_rows.norm(dim=1, keepdim=True)
            * sine_axis[near_pi, major].sign().view(-1, 1)
        )
        sine_half_theta = (
            (0.5 * (1 - cosine_near_pi)).clamp(0, 1).sqrt().view(-1, 1)
        )
        ret[near_pi, 1:] = axis * sine_half_theta

    return ret


# Module-level convenience aliases for the SO3 random constructors.
rand_so3 = SO3.rand
randn_so3 = SO3.randn