Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ManifoldGaussian class for messages in belief propagation #121

Merged
merged 30 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
78d991e
Gaussian class to wrap Manifold class and lam matrix for inverse cova…
joeaortiz Mar 16, 2022
978e150
reformatted
joeaortiz Mar 16, 2022
5e31531
restored original manifold file
joeaortiz Apr 4, 2022
10389ec
Merge branch 'main' into joe.add_manifold_covariance
joeaortiz Apr 4, 2022
312bd38
initial attempt at marginal class , need to handle batch dim
joeaortiz Apr 4, 2022
90572c8
Merge branch 'main' into joe.add_manifold_covariance
joeaortiz Apr 5, 2022
19c9d34
added standard fns dtype, to, copy, update
joeaortiz Apr 7, 2022
167bb84
single to call in init
joeaortiz Apr 7, 2022
572486a
renamed ManifoldGaussian
joeaortiz Apr 7, 2022
4cc5f78
setting precision in init with checks
joeaortiz Apr 8, 2022
a8ecb77
update function requires mean and precision
joeaortiz Apr 8, 2022
23cc2f9
fixed naming in init
joeaortiz Apr 11, 2022
00c7414
manifold gaussian tests
joeaortiz Apr 11, 2022
212a5d0
retract and local gaussian fns
joeaortiz Apr 11, 2022
e364a84
check precision is a symmetric matrix
joeaortiz Apr 11, 2022
58245a0
moved retract and local gaussian to manifold_gaussian to avoid circul…
joeaortiz Apr 12, 2022
0e9e93a
added ManifoldGaussian to inits
joeaortiz Apr 13, 2022
b9b7bbf
minor edits
joeaortiz Apr 13, 2022
b3b867e
fixed dtype error in se3 that appeared in unit tests
joeaortiz Apr 13, 2022
bce0e77
add checks for local_gaussian
joeaortiz Apr 13, 2022
d507742
tests for local and retract gaussian
joeaortiz Apr 13, 2022
eecf1a7
import from th.
joeaortiz Apr 13, 2022
6812d9b
added local_gaussian retract_gaussian to init, minor fix
joeaortiz Apr 13, 2022
8a97d10
minor fix on value error
joeaortiz Apr 13, 2022
218d754
Merge branch main into joe.add_manifold_covariance
joeaortiz Apr 19, 2022
8ac7967
fixed copy bug and added comments
joeaortiz Apr 19, 2022
e19a8d7
random precision matrix in unit tests
joeaortiz Apr 19, 2022
4e98aff
fix for random precision
joeaortiz Apr 19, 2022
7528e60
init precision with identity
joeaortiz Apr 19, 2022
6a7b1a4
fixed typo
joeaortiz Apr 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion theseus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@
randn_se2,
randn_se3,
)
from .optimizer import DenseLinearization, SparseLinearization, VariableOrdering
from .optimizer import (
DenseLinearization,
SparseLinearization,
VariableOrdering,
ManifoldGaussian,
local_gaussian,
retract_gaussian,
)
from .optimizer.linear import (
CholeskyDenseSolver,
CholmodSparseSolver,
Expand Down
1 change: 1 addition & 0 deletions theseus/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .dense_linearization import DenseLinearization
from .linearization import Linearization
from .manifold_gaussian import ManifoldGaussian, local_gaussian, retract_gaussian
from .optimizer import Optimizer, OptimizerInfo
from .sparse_linearization import SparseLinearization
from .variable_ordering import VariableOrdering
170 changes: 170 additions & 0 deletions theseus/optimizer/manifold_gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from itertools import count
from typing import List, Optional, Sequence, Tuple

import torch

from theseus.geometry import LieGroup, Manifold


class ManifoldGaussian:
_ids = count(0)

def __init__(
self,
mean: Sequence[Manifold],
precision: Optional[torch.Tensor] = None,
name: Optional[str] = None,
):
self._id = next(ManifoldGaussian._ids)
if name is None:
name = f"{self.__class__.__name__}__{self._id}"
self.name = name
joeaortiz marked this conversation as resolved.
Show resolved Hide resolved

dof = 0
for v in mean:
dof += v.dof()
self._dof = dof

self.mean = mean
if precision is None:
precision = torch.eye(self.dof).to(
dtype=mean[0].dtype, device=mean[0].device
)
precision = precision[None, ...].repeat(mean[0].shape[0], 1, 1)
self.update(mean, precision)
mhmukadam marked this conversation as resolved.
Show resolved Hide resolved

@property
def dof(self) -> int:
return self._dof

@property
def device(self) -> torch.device:
return self.mean[0].device

@property
def dtype(self) -> torch.dtype:
return self.mean[0].dtype

# calls to() on the internal tensors
def to(self, *args, **kwargs):
for var in self.mean:
var = var.to(*args, **kwargs)
self.precision = self.precision.to(*args, **kwargs)

def copy(self, new_name: Optional[str] = None) -> "ManifoldGaussian":
if not new_name:
new_name = f"{self.name}_copy"
mean_copy = [var.copy() for var in self.mean]
precision_copy = self.precision.clone()
return ManifoldGaussian(mean_copy, precision=precision_copy, name=new_name)

def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
the_copy = self.copy()
memo[id(self)] = the_copy
return the_copy

def update(
self,
mean: Sequence[Manifold],
precision: torch.Tensor,
):
if len(mean) != len(self.mean):
raise ValueError(
f"Tried to update mean with sequence of different"
f"length to original mean sequence. Given: {len(mean)}. "
f"Expected: {len(self.mean)}"
)
for i in range(len(self.mean)):
self.mean[i].update(mean[i])

expected_shape = torch.Size([mean[0].shape[0], self.dof, self.dof])
if precision.shape != expected_shape:
raise ValueError(
f"Tried to update precision with data "
f"incompatible with original tensor shape. Given: {precision.shape}. "
f"Expected: {expected_shape}"
)
if precision.dtype != self.dtype:
raise ValueError(
f"Tried to update using tensor of dtype: {precision.dtype} but precision "
f"has dtype: {self.dtype}."
)
if precision.device != self.device:
raise ValueError(
f"Tried to update using tensor on device: {precision.dtype} but precision "
f"is on device: {self.device}."
)
if not torch.allclose(precision, precision.transpose(1, 2)):
joeaortiz marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Tried to update precision with non-symmetric matrix.")

self.precision = precision


# Projects the gaussian (ManifoldGaussian object) into the tangent plane at
# variable. The gaussian mean is projected using the local function,
# and the precision is approximately transformed using the jacobains of the exp_map.
# Either returns the mean and precision of the new Gaussian in the tangent plane if
# return_mean is True. Otherwise returns the information vector (eta) and precision.
# See section H, eqn 55 in https://arxiv.org/pdf/1812.01537.pdf for a derivation
# of covariance propagation in manifolds.
def local_gaussian(
mhmukadam marked this conversation as resolved.
Show resolved Hide resolved
joeaortiz marked this conversation as resolved.
Show resolved Hide resolved
variable: LieGroup,
gaussian: ManifoldGaussian,
return_mean: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# assumes gaussian is over just one Manifold object
if len(gaussian.mean) != 1:
raise ValueError(
"local on ManifoldGaussian should be over just one Manifold object. "
f"Passed gaussian {gaussian.name} is over {len(gaussian.mean)} "
"Manifold objects."
)
# check variable and gaussian are of the same LieGroup class
if gaussian.mean[0].__class__ != variable.__class__:
raise ValueError(
"variable and gaussian mean must be instances of the same class. "
f"variable is of class {variable.__class__} and gaussian mean is "
f"of class {gaussian.mean[0].__class__}."
)

# mean vector in the tangent space at variable
mean_tp = variable.local(gaussian.mean[0])
joeaortiz marked this conversation as resolved.
Show resolved Hide resolved

jac: List[torch.Tensor] = []
variable.exp_map(mean_tp, jacobians=jac)
# precision matrix in the tangent space at variable
lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), gaussian.precision), jac[0])
joeaortiz marked this conversation as resolved.
Show resolved Hide resolved

if return_mean:
return mean_tp, lam_tp
else:
eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1)
return eta_tp, lam_tp


# Computes the ManifoldGaussian that corresponds to the gaussian in the tangent plane
# at variable, parameterised by hte mean (mean_tp) and precision (precision_tp).
joeaortiz marked this conversation as resolved.
Show resolved Hide resolved
# The mean is transformed to a LieGroup element by retraction.
# The precision is transformed using the inverse of the exp_map jacobians.
# See section H, eqn 55 in https://arxiv.org/pdf/1812.01537.pdf for a derivation
# of covariance propagation in manifolds.
def retract_gaussian(
variable: LieGroup,
mean_tp: torch.Tensor,
precision_tp: torch.Tensor,
) -> ManifoldGaussian:
mean = variable.retract(mean_tp)

jac: List[torch.Tensor] = []
variable.exp_map(mean_tp, jacobians=jac)
inv_jac = torch.inverse(jac[0])
precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), precision_tp), inv_jac)

return ManifoldGaussian(mean=[mean], precision=precision)
199 changes: 199 additions & 0 deletions theseus/optimizer/tests/test_manifold_gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy

import numpy as np
import pytest # noqa: F401
import torch

import theseus as th


def random_manifold_gaussian_params():
manif_types = [th.Point2, th.Point3, th.SE2, th.SE3, th.SO2, th.SO3]

n_vars = np.random.randint(1, 5)
batch_size = np.random.randint(1, 100)
mean = []
dof = 0
for i in range(n_vars):
ix = np.random.randint(len(manif_types))
var = manif_types[ix].rand(batch_size)
mean.append(var)
dof += var.dof()

if np.random.random() < 0.5:
precision_sqrt = torch.rand(mean[0].shape[0], dof, dof)
precision = torch.bmm(precision_sqrt, precision_sqrt.transpose(1, 2))
precision += torch.eye(dof)[None, ...].repeat(mean[0].shape[0], 1, 1)
else:
precision = None

return mean, precision


def test_init():
all_ids = []
for i in range(10):
if np.random.random() < 0.5:
name = f"name_{i}"
else:
name = None
mean, precision = random_manifold_gaussian_params()
dof = sum([v.dof() for v in mean])
n_vars = len(mean)

t = th.ManifoldGaussian(mean, precision=precision, name=name)
all_ids.append(t._id)
if name is not None:
assert name == t.name
else:
assert t.name == f"ManifoldGaussian__{t._id}"
assert t.dof == dof
for j in range(n_vars):
assert t.mean[j] == mean[j]
if precision is not None:
assert torch.isclose(t.precision, precision).all()
else:
precision = torch.zeros(mean[0].shape[0], dof, dof).to(
dtype=mean[0].dtype, device=mean[0].device
)
precision = torch.eye(dof).to(dtype=mean[0].dtype, device=mean[0].device)
precision = precision[None, ...].repeat(mean[0].shape[0], 1, 1)
assert torch.isclose(t.precision, precision).all()

assert len(set(all_ids)) == len(all_ids)


def test_to():
for i in range(10):
mean, precision = random_manifold_gaussian_params()
t = th.ManifoldGaussian(mean, precision=precision)
dtype = torch.float64 if np.random.random() < 0.5 else torch.long
t.to(dtype)

for var in t.mean:
assert var.dtype == dtype
assert t.precision.dtype == dtype


def test_copy():
for i in range(10):
mean, precision = random_manifold_gaussian_params()
n_vars = len(mean)
var = th.ManifoldGaussian(mean, precision, name="var")

var.name = "old"
new_var = var.copy(new_name="new")
assert var is not new_var
for j in range(n_vars):
assert var.mean[j] is not new_var.mean[j]
assert var.precision is not new_var.precision
assert torch.allclose(var.precision, new_var.precision)
assert new_var.name == "new"
new_var_no_name = copy.deepcopy(var)
assert new_var_no_name.name == f"{var.name}_copy"


def test_update():
for i in range(10):
mean, precision = random_manifold_gaussian_params()
n_vars = len(mean)
dof = sum([v.dof() for v in mean])
batch_size = mean[0].shape[0]

var = th.ManifoldGaussian(mean, precision, name="var")

# check update
new_mean_good = []
for j in range(n_vars):
new_var = mean[j].__class__.rand(batch_size)
new_mean_good.append(new_var)
new_precision_good = torch.eye(dof)[None, ...].repeat(batch_size, 1, 1)

var.update(new_mean_good, new_precision_good)

assert var.precision is new_precision_good
for j in range(n_vars):
assert torch.allclose(var.mean[j].data, new_mean_good[j].data)

# check raises error on shape for precision
new_precision_bad = torch.eye(dof + 1)[None, ...].repeat(batch_size, 1, 1)
with pytest.raises(ValueError):
var.update(new_mean_good, new_precision_bad)

# check raises error on dtype for precision
new_precision_bad = torch.eye(dof)[None, ...].repeat(batch_size, 1, 1).double()
with pytest.raises(ValueError):
var.update(new_mean_good, new_precision_bad)

# check raises error for non symmetric precision
new_precision_bad = torch.eye(dof)[None, ...].repeat(batch_size, 1, 1)
if dof > 1:
new_precision_bad[0, 1, 0] += 1.0
with pytest.raises(ValueError):
var.update(new_mean_good, new_precision_bad)

# check raises error on wrong number of mean variables
new_mean_bad = new_mean_good[:-1]
with pytest.raises(ValueError):
var.update(new_mean_bad, new_precision_good)

# check raises error on wrong variable type
new_mean_bad = new_mean_good
new_mean_bad[-1] = th.Vector(10)
with pytest.raises(ValueError):
var.update(new_mean_bad, new_precision_good)


def test_local_gaussian():
manif_types = [th.Point2, th.Point3, th.SE2, th.SE3, th.SO2, th.SO3]

for i in range(50):
batch_size = np.random.randint(1, 100)
ix = np.random.randint(len(manif_types))
mean = [manif_types[ix].rand(batch_size)]
precision = torch.eye(mean[0].dof())[None, ...].repeat(batch_size, 1, 1)
gaussian = th.ManifoldGaussian(mean, precision)
variable = manif_types[ix].rand(batch_size)

mean_tp, lam_tp1 = th.local_gaussian(variable, gaussian, return_mean=True)
eta_tp, lam_tp2 = th.local_gaussian(variable, gaussian, return_mean=False)

assert torch.allclose(lam_tp1, lam_tp2)

# check mean and eta are consistent
mean_tp_calc = torch.matmul(lam_tp1, mean_tp.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(eta_tp, mean_tp_calc)

# check raises error if gaussian over mulitple Manifold objects
bad_mean = mean + [variable]
dof = sum([var.dof() for var in bad_mean])
precision = torch.zeros(batch_size, dof, dof)
bad_gaussian = th.ManifoldGaussian(bad_mean, precision)
with pytest.raises(ValueError):
_, _ = th.local_gaussian(variable, bad_gaussian, return_mean=True)

# check raises error if gaussian over mulitple Manifold objects
bad_ix = np.mod(ix + 1, len(manif_types))
bad_variable = manif_types[bad_ix].rand(batch_size)
with pytest.raises(ValueError):
_, _ = th.local_gaussian(bad_variable, gaussian, return_mean=True)


def test_retract_gaussian():
manif_types = [th.Point2, th.Point3, th.SE2, th.SE3, th.SO2, th.SO3]

for i in range(50):
batch_size = np.random.randint(1, 100)
ix = np.random.randint(len(manif_types))
variable = manif_types[ix].rand(batch_size)

mean_tp = torch.rand(batch_size, variable.dof())
lam_tp = torch.eye(variable.dof())[None, ...].repeat(batch_size, 1, 1)

gaussian = th.retract_gaussian(variable, mean_tp, lam_tp)
assert torch.allclose(gaussian.mean[0].data, variable.retract(mean_tp).data)