forked from facebookresearch/theseus
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ManifoldGaussian class for messages in belief propagation (facebo…
…okresearch#121) * Gaussian class to wrap Manifold class and lam matrix for inverse covariance * reformatted * restored original manifold file * initial attempt at marginal class , need to handle batch dim * added standard fns dtype, to, copy, update * single to call in init * renamed ManifoldGaussian * setting precision in init with checks * update function requires mean and precision * fixed naming in init * manifold gaussian tests * retract and local gaussian fns * check precision is a symmetric matrix * moved retract and local gaussian to manifold_gaussian to avoid circular imports * added ManifoldGaussian to inits * minor edits * fixed dtype error in se3 that appeared in unit tests * add checks for local_gaussian * tests for local and retract gaussian * import from th. * added local_gaussian retract_gaussian to init, minor fix * minor fix on value error * fixed copy bug and added comments * random precision matrix in unit tests * fix for random precision * init precision with identity * fixed typo
- Loading branch information
Showing
4 changed files
with
378 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from itertools import count | ||
from typing import List, Optional, Sequence, Tuple | ||
|
||
import torch | ||
|
||
from theseus.geometry import LieGroup, Manifold | ||
|
||
|
||
class ManifoldGaussian: | ||
_ids = count(0) | ||
|
||
def __init__( | ||
self, | ||
mean: Sequence[Manifold], | ||
precision: Optional[torch.Tensor] = None, | ||
name: Optional[str] = None, | ||
): | ||
self._id = next(ManifoldGaussian._ids) | ||
if name is None: | ||
name = f"{self.__class__.__name__}__{self._id}" | ||
self.name = name | ||
|
||
dof = 0 | ||
for v in mean: | ||
dof += v.dof() | ||
self._dof = dof | ||
|
||
self.mean = mean | ||
if precision is None: | ||
precision = torch.eye(self.dof).to( | ||
dtype=mean[0].dtype, device=mean[0].device | ||
) | ||
precision = precision[None, ...].repeat(mean[0].shape[0], 1, 1) | ||
self.update(mean, precision) | ||
|
||
@property | ||
def dof(self) -> int: | ||
return self._dof | ||
|
||
@property | ||
def device(self) -> torch.device: | ||
return self.mean[0].device | ||
|
||
@property | ||
def dtype(self) -> torch.dtype: | ||
return self.mean[0].dtype | ||
|
||
# calls to() on the internal tensors | ||
def to(self, *args, **kwargs): | ||
for var in self.mean: | ||
var = var.to(*args, **kwargs) | ||
self.precision = self.precision.to(*args, **kwargs) | ||
|
||
def copy(self, new_name: Optional[str] = None) -> "ManifoldGaussian": | ||
if not new_name: | ||
new_name = f"{self.name}_copy" | ||
mean_copy = [var.copy() for var in self.mean] | ||
precision_copy = self.precision.clone() | ||
return ManifoldGaussian(mean_copy, precision=precision_copy, name=new_name) | ||
|
||
def __deepcopy__(self, memo): | ||
if id(self) in memo: | ||
return memo[id(self)] | ||
the_copy = self.copy() | ||
memo[id(self)] = the_copy | ||
return the_copy | ||
|
||
def update( | ||
self, | ||
mean: Sequence[Manifold], | ||
precision: torch.Tensor, | ||
): | ||
if len(mean) != len(self.mean): | ||
raise ValueError( | ||
f"Tried to update mean with sequence of different" | ||
f"length to original mean sequence. Given: {len(mean)}. " | ||
f"Expected: {len(self.mean)}" | ||
) | ||
for i in range(len(self.mean)): | ||
self.mean[i].update(mean[i]) | ||
|
||
expected_shape = torch.Size([mean[0].shape[0], self.dof, self.dof]) | ||
if precision.shape != expected_shape: | ||
raise ValueError( | ||
f"Tried to update precision with data " | ||
f"incompatible with original tensor shape. Given: {precision.shape}. " | ||
f"Expected: {expected_shape}" | ||
) | ||
if precision.dtype != self.dtype: | ||
raise ValueError( | ||
f"Tried to update using tensor of dtype: {precision.dtype} but precision " | ||
f"has dtype: {self.dtype}." | ||
) | ||
if precision.device != self.device: | ||
raise ValueError( | ||
f"Tried to update using tensor on device: {precision.dtype} but precision " | ||
f"is on device: {self.device}." | ||
) | ||
if not torch.allclose(precision, precision.transpose(1, 2)): | ||
raise ValueError("Tried to update precision with non-symmetric matrix.") | ||
|
||
self.precision = precision | ||
|
||
|
||
# Projects the gaussian (ManifoldGaussian object) into the tangent plane at | ||
# variable. The gaussian mean is projected using the local function, | ||
# and the precision is approximately transformed using the jacobains of the exp_map. | ||
# Either returns the mean and precision of the new Gaussian in the tangent plane if | ||
# return_mean is True. Otherwise returns the information vector (eta) and precision. | ||
# See section H, eqn 55 in https://arxiv.org/pdf/1812.01537.pdf for a derivation | ||
# of covariance propagation in manifolds. | ||
def local_gaussian( | ||
variable: LieGroup, | ||
gaussian: ManifoldGaussian, | ||
return_mean: bool = True, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
# assumes gaussian is over just one Manifold object | ||
if len(gaussian.mean) != 1: | ||
raise ValueError( | ||
"local on ManifoldGaussian should be over just one Manifold object. " | ||
f"Passed gaussian {gaussian.name} is over {len(gaussian.mean)} " | ||
"Manifold objects." | ||
) | ||
# check variable and gaussian are of the same LieGroup class | ||
if gaussian.mean[0].__class__ != variable.__class__: | ||
raise ValueError( | ||
"variable and gaussian mean must be instances of the same class. " | ||
f"variable is of class {variable.__class__} and gaussian mean is " | ||
f"of class {gaussian.mean[0].__class__}." | ||
) | ||
|
||
# mean vector in the tangent space at variable | ||
mean_tp = variable.local(gaussian.mean[0]) | ||
|
||
jac: List[torch.Tensor] = [] | ||
variable.exp_map(mean_tp, jacobians=jac) | ||
# precision matrix in the tangent space at variable | ||
lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), gaussian.precision), jac[0]) | ||
|
||
if return_mean: | ||
return mean_tp, lam_tp | ||
else: | ||
eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1) | ||
return eta_tp, lam_tp | ||
|
||
|
||
# Computes the ManifoldGaussian that corresponds to the gaussian in the tangent plane | ||
# at variable, parameterised by the mean (mean_tp) and precision (precision_tp). | ||
# The mean is transformed to a LieGroup element by retraction. | ||
# The precision is transformed using the inverse of the exp_map jacobians. | ||
# See section H, eqn 55 in https://arxiv.org/pdf/1812.01537.pdf for a derivation | ||
# of covariance propagation in manifolds. | ||
def retract_gaussian( | ||
variable: LieGroup, | ||
mean_tp: torch.Tensor, | ||
precision_tp: torch.Tensor, | ||
) -> ManifoldGaussian: | ||
mean = variable.retract(mean_tp) | ||
|
||
jac: List[torch.Tensor] = [] | ||
variable.exp_map(mean_tp, jacobians=jac) | ||
inv_jac = torch.inverse(jac[0]) | ||
precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), precision_tp), inv_jac) | ||
|
||
return ManifoldGaussian(mean=[mean], precision=precision) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import copy | ||
|
||
import numpy as np | ||
import pytest # noqa: F401 | ||
import torch | ||
|
||
import theseus as th | ||
|
||
|
||
def random_manifold_gaussian_params(): | ||
manif_types = [th.Point2, th.Point3, th.SE2, th.SE3, th.SO2, th.SO3] | ||
|
||
n_vars = np.random.randint(1, 5) | ||
batch_size = np.random.randint(1, 100) | ||
mean = [] | ||
dof = 0 | ||
for i in range(n_vars): | ||
ix = np.random.randint(len(manif_types)) | ||
var = manif_types[ix].rand(batch_size) | ||
mean.append(var) | ||
dof += var.dof() | ||
|
||
if np.random.random() < 0.5: | ||
precision_sqrt = torch.rand(mean[0].shape[0], dof, dof) | ||
precision = torch.bmm(precision_sqrt, precision_sqrt.transpose(1, 2)) | ||
precision += torch.eye(dof)[None, ...].repeat(mean[0].shape[0], 1, 1) | ||
else: | ||
precision = None | ||
|
||
return mean, precision | ||
|
||
|
||
def test_init(): | ||
all_ids = [] | ||
for i in range(10): | ||
if np.random.random() < 0.5: | ||
name = f"name_{i}" | ||
else: | ||
name = None | ||
mean, precision = random_manifold_gaussian_params() | ||
dof = sum([v.dof() for v in mean]) | ||
n_vars = len(mean) | ||
|
||
t = th.ManifoldGaussian(mean, precision=precision, name=name) | ||
all_ids.append(t._id) | ||
if name is not None: | ||
assert name == t.name | ||
else: | ||
assert t.name == f"ManifoldGaussian__{t._id}" | ||
assert t.dof == dof | ||
for j in range(n_vars): | ||
assert t.mean[j] == mean[j] | ||
if precision is not None: | ||
assert torch.isclose(t.precision, precision).all() | ||
else: | ||
precision = torch.zeros(mean[0].shape[0], dof, dof).to( | ||
dtype=mean[0].dtype, device=mean[0].device | ||
) | ||
precision = torch.eye(dof).to(dtype=mean[0].dtype, device=mean[0].device) | ||
precision = precision[None, ...].repeat(mean[0].shape[0], 1, 1) | ||
assert torch.isclose(t.precision, precision).all() | ||
|
||
assert len(set(all_ids)) == len(all_ids) | ||
|
||
|
||
def test_to(): | ||
for i in range(10): | ||
mean, precision = random_manifold_gaussian_params() | ||
t = th.ManifoldGaussian(mean, precision=precision) | ||
dtype = torch.float64 if np.random.random() < 0.5 else torch.long | ||
t.to(dtype) | ||
|
||
for var in t.mean: | ||
assert var.dtype == dtype | ||
assert t.precision.dtype == dtype | ||
|
||
|
||
def test_copy(): | ||
for i in range(10): | ||
mean, precision = random_manifold_gaussian_params() | ||
n_vars = len(mean) | ||
var = th.ManifoldGaussian(mean, precision, name="var") | ||
|
||
var.name = "old" | ||
new_var = var.copy(new_name="new") | ||
assert var is not new_var | ||
for j in range(n_vars): | ||
assert var.mean[j] is not new_var.mean[j] | ||
assert var.precision is not new_var.precision | ||
assert torch.allclose(var.precision, new_var.precision) | ||
assert new_var.name == "new" | ||
new_var_no_name = copy.deepcopy(var) | ||
assert new_var_no_name.name == f"{var.name}_copy" | ||
|
||
|
||
def test_update(): | ||
for i in range(10): | ||
mean, precision = random_manifold_gaussian_params() | ||
n_vars = len(mean) | ||
dof = sum([v.dof() for v in mean]) | ||
batch_size = mean[0].shape[0] | ||
|
||
var = th.ManifoldGaussian(mean, precision, name="var") | ||
|
||
# check update | ||
new_mean_good = [] | ||
for j in range(n_vars): | ||
new_var = mean[j].__class__.rand(batch_size) | ||
new_mean_good.append(new_var) | ||
new_precision_good = torch.eye(dof)[None, ...].repeat(batch_size, 1, 1) | ||
|
||
var.update(new_mean_good, new_precision_good) | ||
|
||
assert var.precision is new_precision_good | ||
for j in range(n_vars): | ||
assert torch.allclose(var.mean[j].data, new_mean_good[j].data) | ||
|
||
# check raises error on shape for precision | ||
new_precision_bad = torch.eye(dof + 1)[None, ...].repeat(batch_size, 1, 1) | ||
with pytest.raises(ValueError): | ||
var.update(new_mean_good, new_precision_bad) | ||
|
||
# check raises error on dtype for precision | ||
new_precision_bad = torch.eye(dof)[None, ...].repeat(batch_size, 1, 1).double() | ||
with pytest.raises(ValueError): | ||
var.update(new_mean_good, new_precision_bad) | ||
|
||
# check raises error for non symmetric precision | ||
new_precision_bad = torch.eye(dof)[None, ...].repeat(batch_size, 1, 1) | ||
if dof > 1: | ||
new_precision_bad[0, 1, 0] += 1.0 | ||
with pytest.raises(ValueError): | ||
var.update(new_mean_good, new_precision_bad) | ||
|
||
# check raises error on wrong number of mean variables | ||
new_mean_bad = new_mean_good[:-1] | ||
with pytest.raises(ValueError): | ||
var.update(new_mean_bad, new_precision_good) | ||
|
||
# check raises error on wrong variable type | ||
new_mean_bad = new_mean_good | ||
new_mean_bad[-1] = th.Vector(10) | ||
with pytest.raises(ValueError): | ||
var.update(new_mean_bad, new_precision_good) | ||
|
||
|
||
def test_local_gaussian(): | ||
manif_types = [th.Point2, th.Point3, th.SE2, th.SE3, th.SO2, th.SO3] | ||
|
||
for i in range(50): | ||
batch_size = np.random.randint(1, 100) | ||
ix = np.random.randint(len(manif_types)) | ||
mean = [manif_types[ix].rand(batch_size)] | ||
precision = torch.eye(mean[0].dof())[None, ...].repeat(batch_size, 1, 1) | ||
gaussian = th.ManifoldGaussian(mean, precision) | ||
variable = manif_types[ix].rand(batch_size) | ||
|
||
mean_tp, lam_tp1 = th.local_gaussian(variable, gaussian, return_mean=True) | ||
eta_tp, lam_tp2 = th.local_gaussian(variable, gaussian, return_mean=False) | ||
|
||
assert torch.allclose(lam_tp1, lam_tp2) | ||
|
||
# check mean and eta are consistent | ||
mean_tp_calc = torch.matmul(lam_tp1, mean_tp.unsqueeze(-1)).squeeze(-1) | ||
assert torch.allclose(eta_tp, mean_tp_calc) | ||
|
||
# check raises error if gaussian over mulitple Manifold objects | ||
bad_mean = mean + [variable] | ||
dof = sum([var.dof() for var in bad_mean]) | ||
precision = torch.zeros(batch_size, dof, dof) | ||
bad_gaussian = th.ManifoldGaussian(bad_mean, precision) | ||
with pytest.raises(ValueError): | ||
_, _ = th.local_gaussian(variable, bad_gaussian, return_mean=True) | ||
|
||
# check raises error if gaussian over mulitple Manifold objects | ||
bad_ix = np.mod(ix + 1, len(manif_types)) | ||
bad_variable = manif_types[bad_ix].rand(batch_size) | ||
with pytest.raises(ValueError): | ||
_, _ = th.local_gaussian(bad_variable, gaussian, return_mean=True) | ||
|
||
|
||
def test_retract_gaussian(): | ||
manif_types = [th.Point2, th.Point3, th.SE2, th.SE3, th.SO2, th.SO3] | ||
|
||
for i in range(50): | ||
batch_size = np.random.randint(1, 100) | ||
ix = np.random.randint(len(manif_types)) | ||
variable = manif_types[ix].rand(batch_size) | ||
|
||
mean_tp = torch.rand(batch_size, variable.dof()) | ||
lam_tp = torch.eye(variable.dof())[None, ...].repeat(batch_size, 1, 1) | ||
|
||
gaussian = th.retract_gaussian(variable, mean_tp, lam_tp) | ||
assert torch.allclose(gaussian.mean[0].data, variable.retract(mean_tp).data) |