
Implement SE3.normalize() #506

Merged
merged 12 commits on May 2, 2023
add a helper to SO3.normalize
fantaosha committed May 1, 2023
commit d1b7faa56a949d65e4d81c1d8c29b9194c7635ef
16 changes: 11 additions & 5 deletions theseus/labs/lie/functional/so3_impl.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-import math
 from typing import cast, List, Tuple, Optional

 from . import constants
@@ -987,7 +986,7 @@ def backward(cls, ctx, grad_output):
 # -----------------------------------------------------------------------------
 # Normalize
 # -----------------------------------------------------------------------------
-def _normalize_impl(matrix: torch.Tensor):
+def _normalize_impl_helper(matrix: torch.Tensor):
     check_matrix_tensor(matrix)
     u, s, v = torch.svd(matrix)
     sign = torch.det(u @ v).view(-1, 1, 1)
@@ -997,11 +996,15 @@ def _normalize_impl(matrix: torch.Tensor):
     return u @ vt, {"u": u, "s": s, "v": v, "sign": sign}


+def _normalize_impl(matrix: torch.Tensor):
+    return _normalize_impl_helper(matrix)[0]
+
+
 class Normalize(lie_group.UnaryOperator):
     @classmethod
     def _forward_impl(cls, matrix):
         matrix: torch.Tensor = matrix
-        output, svd_info = _normalize_impl(matrix)
+        output, svd_info = _normalize_impl_helper(matrix)
         return output, svd_info

     @staticmethod
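For intuition, here is a minimal standalone sketch of the projection the helper performs: take the SVD of each 3x3 matrix and flip the sign of the last column of V when needed so the result lands in SO(3) (det = +1) rather than O(3). The function name project_to_so3 and the exact sign handling are illustrative assumptions, since the middle of the helper is collapsed in this diff; only torch.svd and the sign = torch.det(u @ v) line come from the code above.

import torch

def project_to_so3(matrix: torch.Tensor) -> torch.Tensor:
    # matrix: (B, 3, 3). torch.svd returns V (not V^T): matrix = u @ diag(s) @ v^T.
    u, s, v = torch.svd(matrix)
    # det(u @ v) = det(u) * det(v) is +1 or -1 for orthogonal u, v.
    d = torch.sign(torch.det(u @ v)).view(-1, 1, 1)
    # Flip the last column of v when the determinant is negative, so that
    # u @ diag(1, 1, d) @ v^T is the nearest rotation with det = +1.
    vt = torch.cat((v[..., :2], v[..., 2:] * d), dim=-1).transpose(1, 2)
    return u @ vt

m = torch.randn(4, 3, 3)
r = project_to_so3(m)
assert torch.allclose(r @ r.transpose(1, 2), torch.eye(3).expand(4, 3, 3), atol=1e-5)
assert torch.allclose(torch.det(r), torch.ones(4), atol=1e-5)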
@@ -1027,7 +1030,7 @@ def backward(cls, ctx, grad_output, _):
         F = s_squared.view(-1, 1, 3).expand(-1, 3, 3) - s_squared.view(-1, 3, 1).expand(
             -1, 3, 3
         )
-        F = torch.where(F == 0, (torch.ones(1) * math.inf).expand(F.shape), F)
+        F = torch.where(F == 0, grad_output.new_ones(1) * torch.inf, F)
         F = F.pow(-1)

         u_term: torch.Tensor = u @ (F * (ut @ grad_u - grad_u.transpose(1, 2) @ u))
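A note on the change above: grad_output.new_ones(1) * torch.inf keeps the fill value on the same device and dtype as the incoming gradient, which torch.ones(1) * math.inf (always CPU, default dtype) did not. F is the usual factor in SVD gradients, F_ij = 1 / (s_j^2 - s_i^2) for i != j and 0 where the difference vanishes (the diagonal, and any repeated singular values); setting the zero differences to inf before pow(-1) produces those exact zeros without NaNs. A small standalone sketch of the trick (not theseus code):

import torch

s = torch.tensor([[3.0, 2.0, 1.0]])        # one batch of singular values
s2 = s.pow(2)
F = s2.view(-1, 1, 3) - s2.view(-1, 3, 1)  # F[b, i, j] = s_j^2 - s_i^2
F = torch.where(F == 0, torch.full_like(F, torch.inf), F)
F = F.pow(-1)                              # 1/inf == 0, so the diagonal is exactly 0
print(F[0])                                # finite off-diagonal, zero diagonal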
@@ -1042,7 +1045,10 @@ def backward(cls, ctx, grad_output, _):
         return u_term + v_term, None


-_normalize_autograd_fn = Normalize.apply
+def _normalize_autograd_fn(matrix: torch.Tensor):
+    return Normalize.apply(matrix)[0]


 _jnormalize_autograd_fn = None
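The wrapper change matters because Normalize.apply returns the (output, svd_info) pair from _forward_impl; the new function discards the bookkeeping dict so callers only see the normalized matrices. A hypothetical usage sketch (the noisy-input setup below is illustrative, not from the PR):

import torch

# Perturb identity rotations slightly, then project back onto SO(3).
noisy = torch.eye(3).expand(2, 3, 3) + 0.1 * torch.randn(2, 3, 3)
noisy.requires_grad_(True)
rotations = _normalize_autograd_fn(noisy)  # (2, 3, 3) rotation matrices
rotations.sum().backward()                 # exercises Normalize.backward above
print(noisy.grad.shape)                    # torch.Size([2, 3, 3])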

