Multi-batch unit tests for Lie groups (facebookresearch#522)
* Removed unnecessary unsqueeze in quaternion to rotation op.

* Added batch sizes with dim != 1 to use in the current unit tests (except logmap).

* Added a check that multi-batch op outputs are consistent with the outputs obtained after flattening to a single batch dim.

* Added broadcasting test for compose.

* Added unit tests for binary op broadcasting.

* Added main multi-batch unit tests for log map. Fixed bug.

* Added one more multi-batch test case to main lie group unit tests.
luisenp authored Jun 20, 2023
1 parent 016c1f7 commit aed9eb7
Showing 5 changed files with 178 additions and 29 deletions.
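
For context, a minimal sketch of what the multi-batch support being tested here looks like (an illustration, not part of the commit; it assumes the internal _exp_autograd_fn entry point that these tests call via getattr):

import torch
import theseus.labs.lie.functional.so3_impl as so3

tangent = torch.rand(3, 4, 5, 3, dtype=torch.float64)  # batch shape (3, 4, 5), SO3 tangent dim 3
rotation = so3._exp_autograd_fn(tangent)                # rotation matrices, expected shape (3, 4, 5, 3, 3)
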
152 changes: 132 additions & 20 deletions tests/labs/lie/functional/common.py
@@ -2,11 +2,12 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import reduce

import torch


BATCH_SIZES_TO_TEST = [1, 20]
BATCH_SIZES_TO_TEST = [1, 20, (1, 2), (3, 4, 5), tuple()]
TEST_EPS = 5e-7


@@ -56,57 +57,81 @@ def get_test_cfg(op_name, dtype, dim, data_shape, module=None):

# Sample inputs with the desired types.
# Input type is one of:
# ("tangent", dim) # torch.rand(batch_size, dim)
# ("group", module) # e.g., module.rand(batch_size)
# ("tangent", dim) # torch.rand(*batch_size, dim)
# ("group", module) # e.g., module.rand(*batch_size)
# ("quat", dim) # sampled like tangent but normalized
# ("matrix", shape) # torch.rand((batch_size,) + shape)
def sample_inputs(input_types, batch_size, dtype, rng, module=None):
# ("matrix", shape) # torch.rand((*batch_size,) + shape)
#
# `batch_size` can be a tuple.
def sample_inputs(input_types, batch_size, dtype, rng):
if isinstance(batch_size, int):
batch_size = (batch_size,)

def _sample(input_type):
type_str, param = input_type

def _quat_sample():
q = torch.rand(batch_size, param, dtype=dtype, generator=rng)
return q / torch.norm(q, dim=1, keepdim=True)
q = torch.rand(*batch_size, param, dtype=dtype, generator=rng)
return q / torch.norm(q, dim=-1, keepdim=True)

sample_fns = {
"tangent": lambda: torch.rand(
batch_size, param, dtype=dtype, generator=rng
*batch_size, param, dtype=dtype, generator=rng
),
"group": lambda: param.rand(batch_size, generator=rng, dtype=dtype),
"group": lambda: param.rand(*batch_size, generator=rng, dtype=dtype),
"quat": lambda: _quat_sample(),
"matrix": lambda: torch.rand(
(batch_size,) + param, generator=rng, dtype=dtype
(*batch_size,) + param, generator=rng, dtype=dtype
),
}
return sample_fns[type_str]()

return tuple(_sample(type_str) for type_str in input_types)
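
# Illustrative usage (an editorial sketch, not part of this diff): with a tuple
# batch size, each sampled tensor gets the full batch shape as its leading dims.
def _example_sample_inputs_multibatch():
    rng = torch.Generator().manual_seed(0)
    (tangent,) = sample_inputs((("tangent", 3),), (2, 5), torch.float64, rng)
    assert tangent.shape == (2, 5, 3)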


# Run the test for a Lie group operator
# Run some unit tests for a Lie group operator:
# checks:
# - jacobian of default torch autograd consistent with custom backward implementation
# - multi-batch output consistent with single-batch output
def run_test_op(op_name, batch_size, dtype, rng, dim, data_shape, module):
is_multi_batch = not isinstance(batch_size, int)
bs = len(batch_size) if is_multi_batch else 1
all_input_types, atol = get_test_cfg(op_name, dtype, dim, data_shape, module=module)
for input_types in all_input_types:
inputs = sample_inputs(input_types, batch_size, dtype, rng)
funcs = (
tuple(left_project_func(module, x) for x in inputs)
tuple(left_project_func(module, x, bs) for x in inputs)
if op_name == "log"
else None
)

# check analytic backward for the operator
check_lie_group_function(module, op_name, atol, inputs, funcs=funcs)
check_lie_group_function(
module,
op_name,
atol,
inputs,
funcs=funcs,
batch_size=batch_size if is_multi_batch else None,
)


# Checks if the jacobian computed by default torch autograd is close to the one
# provided with custom backward
# Checks:
#
# 1) if the jacobian computed by default torch autograd is close to the one
# provided with custom backward
# 2) if the output of op and jop is consistent with flattening all batch dims
# to a single dim.
# funcs is a list of callables that modify the jacobian. If provided, we also
# check that func(jac_autograd) is close to func(jac_custom) for each func in
# the list.
def check_lie_group_function(module, op_name: str, atol: float, inputs, funcs=None):
def check_lie_group_function(
module, op_name: str, atol: float, inputs, funcs=None, batch_size=None
):
op_impl = getattr(module, f"_{op_name}_impl")
op = getattr(module, f"_{op_name}_autograd_fn")
jop = getattr(module, f"_j{op_name}_autograd_fn")

# Check jacobians
jacs_impl = torch.autograd.functional.jacobian(op_impl, inputs, vectorize=True)
jacs = torch.autograd.functional.jacobian(op, inputs, vectorize=True)

@@ -122,17 +147,52 @@ def check_lie_group_function(module, op_name: str, atol: float, inputs, funcs=No
func(jac_impl), func(jac), atol=atol, rtol=atol
)

# Check multi-batch consistency
if batch_size is None:
return
lb = len(batch_size)
flattened_inputs = [x.reshape(-1, *x.shape[lb:]) for x in inputs]
out = op(*inputs)
flattened_out = op(*flattened_inputs)
if jop is None:
return
jout = jop(*inputs)[0]
flattened_jout = jop(*flattened_inputs)[0]
torch.testing.assert_close(out, flattened_out.reshape(*batch_size, *out.shape[lb:]))
for j, jf in zip(jout, flattened_jout):
torch.testing.assert_close(j, jf.reshape(*batch_size, *j.shape[lb:]))
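
# Editorial sketch (not part of this diff): the consistency check above in a
# nutshell, using a plain elementwise op in place of a Lie group op.
def _example_multibatch_consistency():
    x = torch.rand(3, 4, 5, 6)
    out = torch.sin(x)                       # multi-batch call, shape (3, 4, 5, 6)
    out_flat = torch.sin(x.reshape(-1, 6))   # batch dims (3, 4, 5) flattened to 60
    torch.testing.assert_close(out, out_flat.reshape(3, 4, 5, 6))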

def left_project_func(module, group):
sels = range(group.shape[0])

def left_project_func(module, group, batch_dim):
def func(matrix: torch.Tensor):
return module._left_project_autograd_fn(group, matrix[sels, ..., sels, :, :])
assert matrix.ndim == 2 * batch_dim + 3 # shape should be (*BD, f, *BD, g1, g2)
g = group.clone()
# Convert to single-batch-dim sparse gradient format
batch_size = matrix.shape[:batch_dim]
if batch_dim > 0:
d = reduce(lambda x, y: x * y, batch_size)
matrix = matrix.reshape(d, -1, d, *group.shape[-2:])
sels = range(matrix.shape[0])
matrix = matrix[sels, ..., sels, :, :]
g = group.reshape(d, *group.shape[-2:])
# Compute projected gradient matrix
ret = module._left_project_autograd_fn(g, matrix)
# Revert to multi-batch format if necessary
if batch_dim > 0:
ret = ret.reshape(*batch_size, *ret.shape[-2:])
return ret

return func
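
# Editorial note (not part of this diff): torch.autograd.functional.jacobian of an
# op with input shape (*BD, g1, g2) and output shape (*BD, f) has shape
# (*BD, f, *BD, g1, g2). The helper above flattens *BD into a single dim d, keeps
# the block-diagonal slice matrix[range(d), ..., range(d), :, :] of shape
# (d, f, g1, g2), left-projects it, and then restores the original batch shape.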


# This function checks that vmap(jacrev) works for `group_fns.name`, where
# name can be "exp" or "inv".
# Requires torch >= 2.0.
# Compares the output of vmap(jacrev(log(fn(x)))) to jfn(x).
# For "inv" the output of vmap has to be left-projected
# to get a Riemannian jacobian.
def check_jacrev_unary(group_fns, dim, batch_size, name):
assert name in ["exp", "inv"]
if not hasattr(torch, "vmap"):
return

@@ -155,7 +215,14 @@ def f(t):
torch.testing.assert_close(jac_vmap, jac_analytic)


# This function checks that vmap(jacrev) works for `group_fns.name`, where
# name can be "compose" or "transform_from".
# Requires torch >= 2.0.
# Compares the output of vmap(jacrev(log(fn(x)))) to jfn(x).
# For all group inputs, the output of vmap has to be left-projected
# to get a Riemannian jacobian.
def check_jacrev_binary(group_fns, batch_size, name):
assert name in ["compose", "transform_from"]
if not hasattr(torch, "vmap"):
return

@@ -187,3 +254,48 @@ def f(t1, t2):
for i in range(2):
jac_analytic = jlog[0] @ jtest[i] if name == "compose" else jtest[i]
torch.testing.assert_close(jacs_vmap[i], jac_analytic)


def _get_broadcast_size(bs1, bs2):
m = max(len(bs1), len(bs2))

def _full_dim(bs):
return bs if (len(bs) == m) else (1,) * (m - len(bs)) + bs

bs1_full = _full_dim(bs1)
bs2_full = _full_dim(bs2)

return tuple(max(a, b) for a, b in zip(bs1_full, bs2_full))
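
# Illustrative examples (editorial, not part of this diff), following standard
# right-aligned broadcasting of the batch shapes:
#   _get_broadcast_size((2, 1), (4,))     -> (2, 4)
#   _get_broadcast_size((3,), (2, 1))     -> (2, 3)
#   _get_broadcast_size(tuple(), (5, 2))  -> (5, 2)
# Unlike torch broadcasting, this helper does not reject mismatched non-singleton
# sizes; the tests only feed it compatible shape pairs.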


# Expand to the broadcast size, then flatten to a single batch dimension.
def _expand_flat(tensor, broadcast_size, group_size):
return tensor.clone().expand(broadcast_size + group_size).reshape(-1, *group_size)


def check_binary_op_broadcasting(group_fns, op_name, group_size, bs1, bs2, dtype, rng):
assert op_name in ["compose", "transform_from"]
g1 = group_fns.rand(*bs1, generator=rng, dtype=dtype)
if op_name == "compose":
t2 = group_fns.rand(*bs2, generator=rng, dtype=dtype)
t2_size = group_size
else:
t2 = torch.randn(*bs2, 3, generator=rng, dtype=dtype)
t2_size = (3,)

# The following code does the broadcasting manually, then checks that the
# manually broadcast output matches the output of automatic broadcasting.
broadcast_size = _get_broadcast_size(bs1, bs2)
t1_expand_flat = _expand_flat(g1, broadcast_size, group_size)
t2_expand_flat = _expand_flat(t2, broadcast_size, t2_size)

fn = getattr(group_fns, op_name)
jfn = getattr(group_fns, f"j{op_name}")
out = fn(g1, t2)
out_expand_flat = fn(t1_expand_flat, t2_expand_flat)
torch.testing.assert_close(out, out_expand_flat.reshape(broadcast_size + t2_size))

jout = jfn(g1, t2)[0]
jout_expand_flat = jfn(t1_expand_flat, t2_expand_flat)[0]
for j1, j2 in zip(jout, jout_expand_flat):
torch.testing.assert_close(j1, j2.reshape(broadcast_size + j1.shape[-2:]))
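
A minimal usage sketch of the broadcasting behavior exercised above (an illustration, not part of the commit; it assumes the SO3.rand and SO3.compose entry points used by the tests below):

import torch
from theseus.labs.lie.functional import SO3

rng = torch.Generator()
rng.manual_seed(0)
g1 = SO3.rand(2, 1, generator=rng, dtype=torch.float64)  # batch shape (2, 1), 3x3 rotations
g2 = SO3.rand(3, generator=rng, dtype=torch.float64)     # batch shape (3,)
out = SO3.compose(g1, g2)                                # batch dims broadcast to (2, 3)
assert out.shape == (2, 3, 3, 3)
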
25 changes: 22 additions & 3 deletions tests/labs/lie/functional/test_se3.py
@@ -2,15 +2,16 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence, Union

import pytest

import torch

from tests.decorators import run_if_labs
from .common import (
BATCH_SIZES_TO_TEST,
TEST_EPS,
check_binary_op_broadcasting,
check_lie_group_function,
check_jacrev_binary,
check_jacrev_unary,
@@ -49,12 +50,15 @@ def test_op(op_name, batch_size, dtype):
@run_if_labs()
@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_vee(batch_size: int, dtype: torch.dtype):
def test_vee(batch_size: Union[int, Sequence[int]], dtype: torch.dtype):
import theseus.labs.lie.functional.se3_impl as SE3

if isinstance(batch_size, int):
batch_size = (batch_size,)

rng = torch.Generator()
rng.manual_seed(0)
tangent_vector = torch.rand(batch_size, 6, dtype=dtype, generator=rng)
tangent_vector = torch.rand(*batch_size, 6, dtype=dtype, generator=rng)
matrix = SE3._hat_autograd_fn(tangent_vector)

# check analytic backward for the operator
@@ -86,3 +90,18 @@ def test_jacrev_binary(batch_size, name):
import theseus.labs.lie.functional as lieF

check_jacrev_binary(lieF.SE3, batch_size, name)


@run_if_labs()
@pytest.mark.parametrize("name", ["compose", "transform_from"])
def test_binary_op_broadcasting(name):
from theseus.labs.lie.functional import SE3

rng = torch.Generator()
rng.manual_seed(0)
batch_sizes = [(1,), (2,), (1, 2), (2, 1), (2, 2), (2, 2, 2), tuple()]
for bs1 in batch_sizes:
for bs2 in batch_sizes:
check_binary_op_broadcasting(
SE3, name, (3, 4), bs1, bs2, torch.float64, rng
)
24 changes: 21 additions & 3 deletions tests/labs/lie/functional/test_so3.py
@@ -2,15 +2,16 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence, Union

import pytest

import torch

from tests.decorators import run_if_labs
from .common import (
BATCH_SIZES_TO_TEST,
TEST_EPS,
check_binary_op_broadcasting,
check_lie_group_function,
check_jacrev_binary,
check_jacrev_unary,
@@ -50,12 +51,14 @@ def test_op(op_name, batch_size, dtype):
@run_if_labs()
@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_vee(batch_size: int, dtype: torch.dtype):
def test_vee(batch_size: Union[int, Sequence[int]], dtype: torch.dtype):
import theseus.labs.lie.functional.so3_impl as so3

if isinstance(batch_size, int):
batch_size = (batch_size,)
rng = torch.Generator()
rng.manual_seed(0)
tangent_vector = torch.rand(batch_size, 3, dtype=dtype, generator=rng)
tangent_vector = torch.rand(*batch_size, 3, dtype=dtype, generator=rng)
matrix = so3._hat_autograd_fn(tangent_vector)

# check analytic backward for the operator
@@ -87,3 +90,18 @@ def test_jacrev_binary(batch_size, name):
import theseus.labs.lie.functional as lieF

check_jacrev_binary(lieF.SO3, batch_size, name)


@run_if_labs()
@pytest.mark.parametrize("name", ["compose", "transform_from"])
def test_binary_op_broadcasting(name):
from theseus.labs.lie.functional import SO3

rng = torch.Generator()
rng.manual_seed(0)
batch_sizes = [(1,), (2,), (1, 2), (2, 1), (2, 2), (2, 2, 2), tuple()]
for bs1 in batch_sizes:
for bs2 in batch_sizes:
check_binary_op_broadcasting(
SO3, name, (3, 3), bs1, bs2, torch.float64, rng
)
4 changes: 3 additions & 1 deletion theseus/labs/lie/functional/se3_impl.py
@@ -432,7 +432,9 @@ def _jlog_impl_helper(
-1 / 6.0 - theta2 / 180.0,
(theta - sine) / (theta_nz * two_cosine_minus_two_nz),
)
e = (ret_ang.view(*size, 1, 3) @ ret_lin.view(*size, 3, 1)).view(*size)

e = ret_ang.view(*size, 1, 3) @ ret_lin.view(*size, 3, 1)
e = e.view(*size) if len(size) > 0 else e.squeeze()
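# Editorial note (not part of this diff): the product above has shape (*size, 1, 1);
# when size is empty (an unbatched input), e.view(*size) would expand to e.view()
# with no shape arguments, hence the explicit squeeze() for that case.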

ce_ret_ang = (c * e).view(*size, 1) * ret_ang
jac[..., :3, 3:] = ce_ret_ang.view(*size, 3, 1) * ret_ang.view(*size, 1, 3)
2 changes: 0 additions & 2 deletions theseus/labs/lie/functional/so3_impl.py
@@ -726,8 +726,6 @@ def backward(cls, ctx, grad_output):
# Unit Quaternion to Rotation Matrix
# -----------------------------------------------------------------------------
def _quaternion_to_rotation_impl(quaternion: torch.Tensor) -> torch.Tensor:
if quaternion.ndim == 1:
quaternion = quaternion.unsqueeze(0)
check_unit_quaternion(quaternion)

quaternion = quaternion / torch.norm(quaternion, dim=-1, keepdim=True)
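# Editorial note (not part of this diff): per the commit message, the removed
# unsqueeze was unnecessary; the normalization above reduces over dim=-1, so
# unbatched and multi-batch quaternions no longer need a forced leading batch dim.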