Add ManifoldGaussian class for messages in belief propagation #121

Merged
merged 30 commits into main from joe.add_manifold_covariance on Apr 20, 2022
Changes from 23 commits
30 commits
78d991e
Gaussian class to wrap Manifold class and lam matrix for inverse cova…
joeaortiz Mar 16, 2022
978e150
reformatted
joeaortiz Mar 16, 2022
5e31531
restored original manifold file
joeaortiz Apr 4, 2022
10389ec
Merge branch 'main' into joe.add_manifold_covariance
joeaortiz Apr 4, 2022
312bd38
initial attempt at marginal class , need to handle batch dim
joeaortiz Apr 4, 2022
90572c8
Merge branch 'main' into joe.add_manifold_covariance
joeaortiz Apr 5, 2022
19c9d34
added standard fns dtype, to, copy, update
joeaortiz Apr 7, 2022
167bb84
single to call in init
joeaortiz Apr 7, 2022
572486a
renamed ManifoldGaussian
joeaortiz Apr 7, 2022
4cc5f78
setting precision in init with checks
joeaortiz Apr 8, 2022
a8ecb77
update function requires mean and precision
joeaortiz Apr 8, 2022
23cc2f9
fixed naming in init
joeaortiz Apr 11, 2022
00c7414
manifold gaussian tests
joeaortiz Apr 11, 2022
212a5d0
retract and local gaussian fns
joeaortiz Apr 11, 2022
e364a84
check precision is a symmetric matrix
joeaortiz Apr 11, 2022
58245a0
moved retract and local gaussian to manifold_gaussian to avoid circul…
joeaortiz Apr 12, 2022
0e9e93a
added ManifoldGaussian to inits
joeaortiz Apr 13, 2022
b9b7bbf
minor edits
joeaortiz Apr 13, 2022
b3b867e
fixed dtype error in se3 that appeared in unit tests
joeaortiz Apr 13, 2022
bce0e77
add checks for local_gaussian
joeaortiz Apr 13, 2022
d507742
tests for local and retract gaussian
joeaortiz Apr 13, 2022
eecf1a7
import from th.
joeaortiz Apr 13, 2022
6812d9b
added local_gaussian retract_gaussian to init, minor fix
joeaortiz Apr 13, 2022
8a97d10
minor fix on value error
joeaortiz Apr 13, 2022
218d754
Merge branch main into joe.add_manifold_covariance
joeaortiz Apr 19, 2022
8ac7967
fixed copy bug and added comments
joeaortiz Apr 19, 2022
e19a8d7
random precision matrix in unit tests
joeaortiz Apr 19, 2022
4e98aff
fix for random precision
joeaortiz Apr 19, 2022
7528e60
init precision with identity
joeaortiz Apr 19, 2022
6a7b1a4
fixed typo
joeaortiz Apr 20, 2022
9 changes: 8 additions & 1 deletion theseus/__init__.py
@@ -49,7 +49,14 @@
     randn_se2,
     randn_se3,
 )
-from .optimizer import DenseLinearization, SparseLinearization, VariableOrdering
+from .optimizer import (
+    DenseLinearization,
+    SparseLinearization,
+    VariableOrdering,
+    ManifoldGaussian,
+    local_gaussian,
+    retract_gaussian,
+)
 from .optimizer.linear import (
     CholeskyDenseSolver,
     CholmodSparseSolver,
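With this re-export in place, the new class and its two helper functions become reachable from the top-level package. A minimal sanity check (an illustrative snippet, not part of the PR):

import theseus as th

# The three symbols this PR adds to the public API.
print(th.ManifoldGaussian)
print(th.local_gaussian)
print(th.retract_gaussian)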
16 changes: 14 additions & 2 deletions theseus/geometry/se3.py
@@ -285,14 +285,26 @@ def exp_map(

             jac[:, 3:, 3:] = jac[:, :3, :3]

+            minus_one_by_twelve = torch.full(
+                near_zero.shape,
+                -1 / 12.0,
+                dtype=sine_by_theta.dtype,
+                device=sine_by_theta.device,
+            )
             d_one_minus_cosine_by_theta2 = torch.where(
                 near_zero,
-                -1 / 12.0,
+                minus_one_by_twelve,
                 (sine_by_theta - 2 * one_minus_cosine_by_theta2) / theta2_nz,
             )
+            minus_one_by_sixty = torch.full(
+                near_zero.shape,
+                -1 / 60.0,
+                dtype=one_minus_cosine_by_theta2.dtype,
+                device=one_minus_cosine_by_theta2.device,
+            )
             d_theta_minus_sine_by_theta3 = torch.where(
                 near_zero,
-                -1 / 60.0,
+                minus_one_by_sixty,
                 (one_minus_cosine_by_theta2 - 3 * theta_minus_sine_by_theta3_t)
                 / theta2_nz,
             )
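For context on this change: torch.where expects its two value arguments to agree in dtype and device, and on the PyTorch versions this PR targeted a bare Python scalar such as -1 / 12.0 was promoted to the default dtype (float32), which raised an error when the surrounding tensors were float64 (the "dtype error in se3" mentioned in commit b3b867e). A minimal sketch of the pattern the fix uses, with hypothetical names:

import torch

cond = torch.tensor([True, False])
values = torch.randn(2, dtype=torch.float64)

# Build the constant as a tensor matching `values` in dtype and device,
# rather than passing the Python scalar -1 / 12.0 to torch.where directly.
const = torch.full(cond.shape, -1 / 12.0, dtype=values.dtype, device=values.device)
result = torch.where(cond, const, values)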
1 change: 1 addition & 0 deletions theseus/optimizer/__init__.py
@@ -5,6 +5,7 @@

 from .dense_linearization import DenseLinearization
 from .linearization import Linearization
+from .manifold_gaussian import ManifoldGaussian, local_gaussian, retract_gaussian
 from .optimizer import Optimizer, OptimizerInfo
 from .sparse_linearization import SparseLinearization
 from .variable_ordering import VariableOrdering
156 changes: 156 additions & 0 deletions theseus/optimizer/manifold_gaussian.py
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from itertools import count
from typing import List, Optional, Sequence, Tuple

import torch

from theseus.geometry import LieGroup, Manifold


class ManifoldGaussian:
    _ids = count(0)

    def __init__(
        self,
        mean: Sequence[Manifold],
        precision: Optional[torch.Tensor] = None,
        name: Optional[str] = None,
    ):
        self._id = next(ManifoldGaussian._ids)
        if name is None:
            name = f"{self.__class__.__name__}__{self._id}"
        self.name = name

        dof = 0
        for v in mean:
            dof += v.dof()
        self._dof = dof

        self.mean = mean
        if precision is None:
            precision = torch.zeros(mean[0].shape[0], self.dof, self.dof).to(
                dtype=mean[0].dtype, device=mean[0].device
            )
        self.update(mean, precision)

    @property
    def dof(self) -> int:
        return self._dof

    @property
    def device(self) -> torch.device:
        return self.mean[0].device

    @property
    def dtype(self) -> torch.dtype:
        return self.mean[0].dtype

    # calls to() on the internal tensors
    def to(self, *args, **kwargs):
        for var in self.mean:
            var.to(*args, **kwargs)  # Manifold.to() updates the variable in place
        self.precision = self.precision.to(*args, **kwargs)

    def copy(self, new_name: Optional[str] = None) -> "ManifoldGaussian":
        if not new_name:
            new_name = f"{self.name}_copy"
        mean_copy = [var.copy() for var in self.mean]
        # clone the precision as well, so the copy shares no storage with self
        return ManifoldGaussian(mean_copy, precision=self.precision.clone(), name=new_name)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        the_copy = self.copy()
        memo[id(self)] = the_copy
        return the_copy

    def update(
        self,
        mean: Sequence[Manifold],
        precision: torch.Tensor,
    ):
        if len(mean) != len(self.mean):
            raise ValueError(
                f"Tried to update mean with sequence of different "
                f"length to original mean sequence. Given: {len(mean)}. "
                f"Expected: {len(self.mean)}"
            )
        for i in range(len(self.mean)):
            self.mean[i].update(mean[i])

        expected_shape = torch.Size([mean[0].shape[0], self.dof, self.dof])
        if precision.shape != expected_shape:
            raise ValueError(
                f"Tried to update precision with data "
                f"incompatible with original tensor shape. Given: {precision.shape}. "
                f"Expected: {expected_shape}"
            )
        if precision.dtype != self.dtype:
            raise ValueError(
                f"Tried to update using tensor of dtype: {precision.dtype} but precision "
                f"has dtype: {self.dtype}."
            )
        if precision.device != self.device:
            raise ValueError(
                f"Tried to update using tensor on device: {precision.device} but precision "
                f"is on device: {self.device}."
            )
        if not torch.allclose(precision, precision.transpose(1, 2)):
            raise ValueError("Tried to update precision with non-symmetric matrix.")

        self.precision = precision


def local_gaussian(
    variable: LieGroup,
    gaussian: ManifoldGaussian,
    return_mean: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # assumes gaussian is over just one Manifold object
    if len(gaussian.mean) != 1:
        raise ValueError(
            "local on manifold should be over just one Manifold object. "
            f"Passed gaussian {gaussian.name} is over {len(gaussian.mean)} "
            "Manifold objects."
        )
    # check variable and gaussian are of the same LieGroup class
    if gaussian.mean[0].__class__ != variable.__class__:
        raise ValueError(
            "variable and gaussian mean must be instances of the same class. "
            f"variable is of class {variable.__class__} and gaussian mean is "
            f"of class {gaussian.mean[0].__class__}."
        )

    # mean vector in the tangent space at variable
    mean_tp = variable.local(gaussian.mean[0])

    jac: List[torch.Tensor] = []
    variable.exp_map(mean_tp, jacobians=jac)
    # precision matrix in the tangent space at variable
    # Following math in section H https://arxiv.org/pdf/1812.01537.pdf
    lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), gaussian.precision), jac[0])

    if return_mean:
        return mean_tp, lam_tp
    else:
        eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1)
        return eta_tp, lam_tp


def retract_gaussian(
    variable: LieGroup,
    mean_tp: torch.Tensor,
    precision_tp: torch.Tensor,
) -> ManifoldGaussian:
    mean = variable.retract(mean_tp)

    jac: List[torch.Tensor] = []
    variable.exp_map(mean_tp, jacobians=jac)
    inv_jac = torch.inverse(jac[0])
    precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), precision_tp), inv_jac)

    return ManifoldGaussian(mean=[mean], precision=precision)
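Putting the new file together: local_gaussian expresses a ManifoldGaussian in the tangent space at a chosen variable, transforming the precision as Λ_tp = Jᵀ Λ J where J is the Jacobian of the exponential map, and retract_gaussian applies the inverse transform to move a tangent-space Gaussian back onto the manifold. An illustrative round trip, assuming the th.randn_se3 constructor visible in the __init__.py diff above (this snippet is not part of the PR):

import torch
import theseus as th

# A Gaussian belief over a single SE3 variable (batch size 1, dof = 6).
mean = th.randn_se3(1)
precision = torch.eye(6).unsqueeze(0)  # shape (1, 6, 6), symmetric
belief = th.ManifoldGaussian([mean], precision=precision, name="pose_belief")

# Express the belief in the tangent space at a linearization point...
lin_point = th.randn_se3(1)
mean_tp, lam_tp = th.local_gaussian(lin_point, belief, return_mean=True)

# ...and retract it back onto the manifold.
belief_back = th.retract_gaussian(lin_point, mean_tp, lam_tp)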