Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-batch support for SE(3) unary operators #512

Merged
merged 14 commits into from
May 4, 2023
Prev Previous commit
Next Next commit
a minor refactor of the code
  • Loading branch information
fantaosha committed May 4, 2023
commit 287360051dc477b81a51dec77ac6db2f6093f4c0
4 changes: 2 additions & 2 deletions theseus/labs/lie/functional/se3_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _jexp_impl_helper(

jac[..., :3, 3:] = rotation.transpose(-1, -2) @ jac_temp_t

return jac
return jac, (None,)


def _jexp_impl(
Expand All @@ -278,7 +278,7 @@ def _jexp_impl(
theta_minus_sine_by_theta3_rot = torch.where(
near_zero, torch.zeros_like(theta), theta_minus_sine_by_theta3_t
luisenp marked this conversation as resolved.
Show resolved Hide resolved
)
jac = _jexp_impl_helper(
jac, _ = _jexp_impl_helper(
tangent_vector,
ret[..., :3],
theta,
Expand Down
16 changes: 2 additions & 14 deletions theseus/labs/lie/functional/so3_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,7 @@ def rand(
) -> torch.Tensor:
# Reference:
# https://web.archive.org/web/20211105205926/http://planning.cs.uiuc.edu/node198.html
u = torch.rand(
3,
*size,
generator=generator,
dtype=dtype,
device=device,
)
u = torch.rand(3, *size, generator=generator, dtype=dtype, device=device)
u1 = u[0]
u2, u3 = u[1:3] * 2 * constants.PI

Expand Down Expand Up @@ -160,13 +154,7 @@ def randn(
) -> torch.Tensor:
ret = _exp_autograd_fn(
constants.PI
* torch.randn(
*size,
3,
generator=generator,
dtype=dtype,
device=device,
)
* torch.randn(*size, 3, generator=generator, dtype=dtype, device=device)
)
ret.requires_grad_(requires_grad)
return ret
Expand Down