Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ManifoldGaussian class for messages in belief propagation #121

Merged
merged 30 commits into from
Apr 20, 2022
Merged
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
78d991e
Gaussian class to wrap Manifold class and lam matrix for inverse cova…
joeaortiz Mar 16, 2022
978e150
reformatted
joeaortiz Mar 16, 2022
5e31531
restored original manifold file
joeaortiz Apr 4, 2022
10389ec
Merge branch 'main' into joe.add_manifold_covariance
joeaortiz Apr 4, 2022
312bd38
initial attempt at marginal class , need to handle batch dim
joeaortiz Apr 4, 2022
90572c8
Merge branch 'main' into joe.add_manifold_covariance
joeaortiz Apr 5, 2022
19c9d34
added standard fns dtype, to, copy, update
joeaortiz Apr 7, 2022
167bb84
single to call in init
joeaortiz Apr 7, 2022
572486a
renamed ManifoldGaussian
joeaortiz Apr 7, 2022
4cc5f78
setting precision in init with checks
joeaortiz Apr 8, 2022
a8ecb77
update function requires mean and precision
joeaortiz Apr 8, 2022
23cc2f9
fixed naming in init
joeaortiz Apr 11, 2022
00c7414
manifold gaussian tests
joeaortiz Apr 11, 2022
212a5d0
retract and local gaussian fns
joeaortiz Apr 11, 2022
e364a84
check precision is a symmetric matrix
joeaortiz Apr 11, 2022
58245a0
moved retract and local gaussian to manifold_gaussian to avoid circul…
joeaortiz Apr 12, 2022
0e9e93a
added ManifoldGaussian to inits
joeaortiz Apr 13, 2022
b9b7bbf
minor edits
joeaortiz Apr 13, 2022
b3b867e
fixed dtype error in se3 that appeared in unit tests
joeaortiz Apr 13, 2022
bce0e77
add checks for local_gaussian
joeaortiz Apr 13, 2022
d507742
tests for local and retract gaussian
joeaortiz Apr 13, 2022
eecf1a7
import from th.
joeaortiz Apr 13, 2022
6812d9b
added local_gaussian retract_gaussian to init, minor fix
joeaortiz Apr 13, 2022
8a97d10
minor fix on value error
joeaortiz Apr 13, 2022
218d754
Merge branch main into joe.add_manifold_covariance
joeaortiz Apr 19, 2022
8ac7967
fixed copy bug and added comments
joeaortiz Apr 19, 2022
e19a8d7
random precision matrix in unit tests
joeaortiz Apr 19, 2022
4e98aff
fix for random precision
joeaortiz Apr 19, 2022
7528e60
init precision with identity
joeaortiz Apr 19, 2022
6a7b1a4
fixed typo
joeaortiz Apr 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions theseus/optimizer/manifold_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,12 @@ def update(
for i in range(len(self.mean)):
self.mean[i].update(mean[i])

if precision.shape != torch.Size([mean[0].shape[0], self.dof, self.dof]):
expected_shape = torch.Size([mean[0].shape[0], self.dof, self.dof])
if precision.shape != expected_shape:
raise ValueError(
f"Tried to update precision with data "
f"incompatible with original tensor shape. Given: {precision.shape}. "
f"Expected: {self.precision.shape}"
f"Expected: {expected_shape}"
)
if precision.dtype != self.dtype:
raise ValueError(
Expand All @@ -109,6 +110,21 @@ def local_gaussian(
gaussian: ManifoldGaussian,
return_mean: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# assumes gaussian is over just one Manifold object
if len(gaussian.mean) != 1:
raise ValueError(
"ManifoldGaussian should be over just one Manifold object. "
joeaortiz marked this conversation as resolved.
Show resolved Hide resolved
f"Passed gaussian {gaussian.name} is over {len(gaussian.mean)} "
"Manifold objects."
)
# check variable and gaussian are of the same LieGroup class
if gaussian.mean[0].__class__ != variable.__class__:
raise ValueError(
"variable and gaussian mean must be instances of the same class. "
f"variable is of class {variable.__class__} and gaussian mean is "
f"of class {gaussian.mean[0].__class__}."
)

# mean vector in the tangent space at variable
mean_tp = variable.local(gaussian.mean[0])
joeaortiz marked this conversation as resolved.
Show resolved Hide resolved

Expand Down