Add ManifoldGaussian class for messages in belief propagation (facebookresearch#121)

* Gaussian class to wrap Manifold class and lam matrix for inverse covariance

* reformatted

* restored original manifold file

* initial attempt at marginal class, need to handle batch dim

* added standard fns dtype, to, copy, update

* single to() call in init

* renamed ManifoldGaussian

* setting precision in init with checks

* update function requires mean and precision

* fixed naming in init

* manifold gaussian tests

* retract and local gaussian fns

* check precision is a symmetric matrix

* moved retract and local gaussian to manifold_gaussian to avoid circular imports

* added ManifoldGaussian to inits

* minor edits

* fixed dtype error in se3 that appeared in unit tests

* add checks for local_gaussian

* tests for local and retract gaussian

* import from th.

* added local_gaussian retract_gaussian to init, minor fix

* minor fix on value error

* fixed copy bug and added comments

* random precision matrix in unit tests

* fix for random precision

* init precision with identity

* fixed typo
joeaortiz authored Apr 20, 2022
1 parent dc5d931 commit 4b75f36
Showing 4 changed files with 378 additions and 1 deletion.
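For context, a minimal usage sketch of the new class (illustrative, not part of the diff; the calls mirror the tests added below):

import torch
import theseus as th

batch_size = 4
mean = [th.SE3.rand(batch_size)]  # Gaussian over a single SE3 variable
# precision defaults to a batch of identity matrices when omitted
belief = th.ManifoldGaussian(mean, name="pose_belief")

# update() replaces the mean and the (symmetric, batched) precision
new_mean = [th.SE3.rand(batch_size)]
new_precision = 2.0 * torch.eye(belief.dof)[None, ...].repeat(batch_size, 1, 1)
belief.update(new_mean, new_precision)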
9 changes: 8 additions & 1 deletion theseus/__init__.py
@@ -49,7 +49,14 @@
randn_se2,
randn_se3,
)
from .optimizer import DenseLinearization, SparseLinearization, VariableOrdering
from .optimizer import (
DenseLinearization,
SparseLinearization,
VariableOrdering,
ManifoldGaussian,
local_gaussian,
retract_gaussian,
)
from .optimizer.linear import (
CholeskyDenseSolver,
CholmodSparseSolver,
1 change: 1 addition & 0 deletions theseus/optimizer/__init__.py
@@ -5,6 +5,7 @@

from .dense_linearization import DenseLinearization
from .linearization import Linearization
from .manifold_gaussian import ManifoldGaussian, local_gaussian, retract_gaussian
from .optimizer import Optimizer, OptimizerInfo
from .sparse_linearization import SparseLinearization
from .variable_ordering import VariableOrdering
170 changes: 170 additions & 0 deletions theseus/optimizer/manifold_gaussian.py
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from itertools import count
from typing import List, Optional, Sequence, Tuple

import torch

from theseus.geometry import LieGroup, Manifold


class ManifoldGaussian:
_ids = count(0)

def __init__(
self,
mean: Sequence[Manifold],
precision: Optional[torch.Tensor] = None,
name: Optional[str] = None,
):
self._id = next(ManifoldGaussian._ids)
if name is None:
name = f"{self.__class__.__name__}__{self._id}"
self.name = name

dof = 0
for v in mean:
dof += v.dof()
self._dof = dof

self.mean = mean
if precision is None:
precision = torch.eye(self.dof).to(
dtype=mean[0].dtype, device=mean[0].device
)
precision = precision[None, ...].repeat(mean[0].shape[0], 1, 1)
self.update(mean, precision)

@property
def dof(self) -> int:
return self._dof

@property
def device(self) -> torch.device:
return self.mean[0].device

@property
def dtype(self) -> torch.dtype:
return self.mean[0].dtype

    # calls to() on the internal tensors; Manifold.to() modifies the data in place
    def to(self, *args, **kwargs):
        for var in self.mean:
            var.to(*args, **kwargs)
self.precision = self.precision.to(*args, **kwargs)

def copy(self, new_name: Optional[str] = None) -> "ManifoldGaussian":
if not new_name:
new_name = f"{self.name}_copy"
mean_copy = [var.copy() for var in self.mean]
precision_copy = self.precision.clone()
return ManifoldGaussian(mean_copy, precision=precision_copy, name=new_name)

def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
the_copy = self.copy()
memo[id(self)] = the_copy
return the_copy

def update(
self,
mean: Sequence[Manifold],
precision: torch.Tensor,
):
if len(mean) != len(self.mean):
raise ValueError(
f"Tried to update mean with sequence of different"
f"length to original mean sequence. Given: {len(mean)}. "
f"Expected: {len(self.mean)}"
)
for i in range(len(self.mean)):
self.mean[i].update(mean[i])

expected_shape = torch.Size([mean[0].shape[0], self.dof, self.dof])
if precision.shape != expected_shape:
raise ValueError(
f"Tried to update precision with data "
f"incompatible with original tensor shape. Given: {precision.shape}. "
f"Expected: {expected_shape}"
)
if precision.dtype != self.dtype:
raise ValueError(
f"Tried to update using tensor of dtype: {precision.dtype} but precision "
f"has dtype: {self.dtype}."
)
if precision.device != self.device:
raise ValueError(
f"Tried to update using tensor on device: {precision.dtype} but precision "
f"is on device: {self.device}."
)
if not torch.allclose(precision, precision.transpose(1, 2)):
raise ValueError("Tried to update precision with non-symmetric matrix.")

self.precision = precision


# Projects the gaussian (ManifoldGaussian object) into the tangent plane at
# variable. The gaussian mean is projected using the local function,
# and the precision is approximately transformed using the jacobians of the exp_map.
# Returns the mean and precision of the new Gaussian in the tangent plane if
# return_mean is True; otherwise returns the information vector (eta) and precision.
# See section H, eqn 55 in https://arxiv.org/pdf/1812.01537.pdf for a derivation
# of covariance propagation in manifolds.
def local_gaussian(
variable: LieGroup,
gaussian: ManifoldGaussian,
return_mean: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# assumes gaussian is over just one Manifold object
if len(gaussian.mean) != 1:
raise ValueError(
"local on ManifoldGaussian should be over just one Manifold object. "
f"Passed gaussian {gaussian.name} is over {len(gaussian.mean)} "
"Manifold objects."
)
# check variable and gaussian are of the same LieGroup class
if gaussian.mean[0].__class__ != variable.__class__:
raise ValueError(
"variable and gaussian mean must be instances of the same class. "
f"variable is of class {variable.__class__} and gaussian mean is "
f"of class {gaussian.mean[0].__class__}."
)

# mean vector in the tangent space at variable
mean_tp = variable.local(gaussian.mean[0])

jac: List[torch.Tensor] = []
variable.exp_map(mean_tp, jacobians=jac)
# precision matrix in the tangent space at variable
lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), gaussian.precision), jac[0])

if return_mean:
return mean_tp, lam_tp
else:
eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1)
return eta_tp, lam_tp


# Computes the ManifoldGaussian that corresponds to the gaussian in the tangent plane
# at variable, parameterised by the mean (mean_tp) and precision (precision_tp).
# The mean is transformed to a LieGroup element by retraction.
# The precision is transformed using the inverse of the exp_map jacobians.
# See section H, eqn 55 in https://arxiv.org/pdf/1812.01537.pdf for a derivation
# of covariance propagation in manifolds.
def retract_gaussian(
variable: LieGroup,
mean_tp: torch.Tensor,
precision_tp: torch.Tensor,
) -> ManifoldGaussian:
mean = variable.retract(mean_tp)

jac: List[torch.Tensor] = []
variable.exp_map(mean_tp, jacobians=jac)
inv_jac = torch.inverse(jac[0])
precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), precision_tp), inv_jac)

return ManifoldGaussian(mean=[mean], precision=precision)
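A sketch of how the two helpers compose (illustrative, assuming the API above): projecting a Gaussian into the tangent plane at a linearization point and retracting it back recovers the original Gaussian, since both directions use the same exp_map jacobian.

import torch
import theseus as th

batch_size = 2
variable = th.SO3.rand(batch_size)  # linearization point
mean = [th.SO3.rand(batch_size)]
precision = 3.0 * torch.eye(mean[0].dof())[None, ...].repeat(batch_size, 1, 1)
gaussian = th.ManifoldGaussian(mean, precision)

# project into the tangent plane at variable, then retract back
mean_tp, lam_tp = th.local_gaussian(variable, gaussian, return_mean=True)
recovered = th.retract_gaussian(variable, mean_tp, lam_tp)

assert torch.allclose(recovered.mean[0].data, gaussian.mean[0].data, atol=1e-5)
assert torch.allclose(recovered.precision, gaussian.precision, atol=1e-4)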
199 changes: 199 additions & 0 deletions theseus/optimizer/tests/test_manifold_gaussian.py
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy

import numpy as np
import pytest # noqa: F401
import torch

import theseus as th


def random_manifold_gaussian_params():
manif_types = [th.Point2, th.Point3, th.SE2, th.SE3, th.SO2, th.SO3]

n_vars = np.random.randint(1, 5)
batch_size = np.random.randint(1, 100)
mean = []
dof = 0
for i in range(n_vars):
ix = np.random.randint(len(manif_types))
var = manif_types[ix].rand(batch_size)
mean.append(var)
dof += var.dof()

if np.random.random() < 0.5:
precision_sqrt = torch.rand(mean[0].shape[0], dof, dof)
precision = torch.bmm(precision_sqrt, precision_sqrt.transpose(1, 2))
precision += torch.eye(dof)[None, ...].repeat(mean[0].shape[0], 1, 1)
else:
precision = None

return mean, precision


def test_init():
all_ids = []
for i in range(10):
if np.random.random() < 0.5:
name = f"name_{i}"
else:
name = None
mean, precision = random_manifold_gaussian_params()
dof = sum([v.dof() for v in mean])
n_vars = len(mean)

t = th.ManifoldGaussian(mean, precision=precision, name=name)
all_ids.append(t._id)
if name is not None:
assert name == t.name
else:
assert t.name == f"ManifoldGaussian__{t._id}"
assert t.dof == dof
for j in range(n_vars):
assert t.mean[j] == mean[j]
if precision is not None:
assert torch.isclose(t.precision, precision).all()
else:
            precision = torch.eye(dof).to(dtype=mean[0].dtype, device=mean[0].device)
            precision = precision[None, ...].repeat(mean[0].shape[0], 1, 1)
assert torch.isclose(t.precision, precision).all()

assert len(set(all_ids)) == len(all_ids)


def test_to():
for i in range(10):
mean, precision = random_manifold_gaussian_params()
t = th.ManifoldGaussian(mean, precision=precision)
dtype = torch.float64 if np.random.random() < 0.5 else torch.long
t.to(dtype)

for var in t.mean:
assert var.dtype == dtype
assert t.precision.dtype == dtype


def test_copy():
for i in range(10):
mean, precision = random_manifold_gaussian_params()
n_vars = len(mean)
var = th.ManifoldGaussian(mean, precision, name="var")

var.name = "old"
new_var = var.copy(new_name="new")
assert var is not new_var
for j in range(n_vars):
assert var.mean[j] is not new_var.mean[j]
assert var.precision is not new_var.precision
assert torch.allclose(var.precision, new_var.precision)
assert new_var.name == "new"
new_var_no_name = copy.deepcopy(var)
assert new_var_no_name.name == f"{var.name}_copy"


def test_update():
for i in range(10):
mean, precision = random_manifold_gaussian_params()
n_vars = len(mean)
dof = sum([v.dof() for v in mean])
batch_size = mean[0].shape[0]

var = th.ManifoldGaussian(mean, precision, name="var")

# check update
new_mean_good = []
for j in range(n_vars):
new_var = mean[j].__class__.rand(batch_size)
new_mean_good.append(new_var)
new_precision_good = torch.eye(dof)[None, ...].repeat(batch_size, 1, 1)

var.update(new_mean_good, new_precision_good)

assert var.precision is new_precision_good
for j in range(n_vars):
assert torch.allclose(var.mean[j].data, new_mean_good[j].data)

# check raises error on shape for precision
new_precision_bad = torch.eye(dof + 1)[None, ...].repeat(batch_size, 1, 1)
with pytest.raises(ValueError):
var.update(new_mean_good, new_precision_bad)

# check raises error on dtype for precision
new_precision_bad = torch.eye(dof)[None, ...].repeat(batch_size, 1, 1).double()
with pytest.raises(ValueError):
var.update(new_mean_good, new_precision_bad)

# check raises error for non symmetric precision
new_precision_bad = torch.eye(dof)[None, ...].repeat(batch_size, 1, 1)
if dof > 1:
new_precision_bad[0, 1, 0] += 1.0
with pytest.raises(ValueError):
var.update(new_mean_good, new_precision_bad)

# check raises error on wrong number of mean variables
new_mean_bad = new_mean_good[:-1]
with pytest.raises(ValueError):
var.update(new_mean_bad, new_precision_good)

# check raises error on wrong variable type
new_mean_bad = new_mean_good
new_mean_bad[-1] = th.Vector(10)
with pytest.raises(ValueError):
var.update(new_mean_bad, new_precision_good)


def test_local_gaussian():
manif_types = [th.Point2, th.Point3, th.SE2, th.SE3, th.SO2, th.SO3]

for i in range(50):
batch_size = np.random.randint(1, 100)
ix = np.random.randint(len(manif_types))
mean = [manif_types[ix].rand(batch_size)]
precision = torch.eye(mean[0].dof())[None, ...].repeat(batch_size, 1, 1)
gaussian = th.ManifoldGaussian(mean, precision)
variable = manif_types[ix].rand(batch_size)

mean_tp, lam_tp1 = th.local_gaussian(variable, gaussian, return_mean=True)
eta_tp, lam_tp2 = th.local_gaussian(variable, gaussian, return_mean=False)

assert torch.allclose(lam_tp1, lam_tp2)

        # check mean and eta are consistent: eta = lam @ mean
        eta_tp_calc = torch.matmul(lam_tp1, mean_tp.unsqueeze(-1)).squeeze(-1)
        assert torch.allclose(eta_tp, eta_tp_calc)

        # check raises error if gaussian is over multiple Manifold objects
bad_mean = mean + [variable]
dof = sum([var.dof() for var in bad_mean])
precision = torch.zeros(batch_size, dof, dof)
bad_gaussian = th.ManifoldGaussian(bad_mean, precision)
with pytest.raises(ValueError):
_, _ = th.local_gaussian(variable, bad_gaussian, return_mean=True)

        # check raises error if variable and gaussian mean are of different classes
bad_ix = np.mod(ix + 1, len(manif_types))
bad_variable = manif_types[bad_ix].rand(batch_size)
with pytest.raises(ValueError):
_, _ = th.local_gaussian(bad_variable, gaussian, return_mean=True)


def test_retract_gaussian():
manif_types = [th.Point2, th.Point3, th.SE2, th.SE3, th.SO2, th.SO3]

for i in range(50):
batch_size = np.random.randint(1, 100)
ix = np.random.randint(len(manif_types))
variable = manif_types[ix].rand(batch_size)

mean_tp = torch.rand(batch_size, variable.dof())
lam_tp = torch.eye(variable.dof())[None, ...].repeat(batch_size, 1, 1)

gaussian = th.retract_gaussian(variable, mean_tp, lam_tp)
assert torch.allclose(gaussian.mean[0].data, variable.retract(mean_tp).data)

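An illustrative sketch of the belief propagation use case named in the commit title (this usage is not part of the commit; the information-form fusion step is standard Gaussian BP): messages are projected into a common tangent plane, fused by summing information vectors and precisions, and the fused belief is retracted back onto the manifold.

import torch
import theseus as th

batch_size = 1
lin_point = th.SE2.rand(batch_size)  # common linearization point
eye = torch.eye(lin_point.dof())[None, ...].repeat(batch_size, 1, 1)

msg1 = th.ManifoldGaussian([th.SE2.rand(batch_size)], eye)
msg2 = th.ManifoldGaussian([th.SE2.rand(batch_size)], 2.0 * eye)

# information form in the tangent plane: eta = lam @ mean
eta1, lam1 = th.local_gaussian(lin_point, msg1, return_mean=False)
eta2, lam2 = th.local_gaussian(lin_point, msg2, return_mean=False)

# product of Gaussians: information vectors and precisions add
lam = lam1 + lam2
mean_tp = torch.linalg.solve(lam, (eta1 + eta2).unsqueeze(-1)).squeeze(-1)

fused = th.retract_gaussian(lin_point, mean_tp, lam)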