From 146e0bf7428716975d95787a947342529ab706c9 Mon Sep 17 00:00:00 2001
From: Luis Pineda <4759586+luisenp@users.noreply.github.com>
Date: Thu, 1 Dec 2022 13:27:05 -0800
Subject: [PATCH] Add wrapper for sparse_mv in SparseLinearization.

---
 tests/optimizer/nonlinear/common.py          |  3 +++
 tests/optimizer/test_sparse_linearization.py |  8 ++++++++
 tests/utils/test_utils.py                    |  4 ++--
 theseus/optimizer/dense_linearization.py     |  3 +++
 theseus/optimizer/linearization.py           |  5 +++++
 theseus/optimizer/sparse_linearization.py    | 15 ++++++++++++++-
 theseus/utils/sparse_matrix_utils.py         | 16 +++++++---------
 7 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/tests/optimizer/nonlinear/common.py b/tests/optimizer/nonlinear/common.py
index 25b4902c0..d08bff11b 100644
--- a/tests/optimizer/nonlinear/common.py
+++ b/tests/optimizer/nonlinear/common.py
@@ -234,6 +234,9 @@ def _ata_impl(self) -> torch.Tensor:
     def _atb_impl(self) -> torch.Tensor:
         return self._Atb
 
+    def Av(self, v):
+        pass
+
 class MockCostFunction(th.CostFunction):
     def __init__(self, optim_vars, cost_weight):
         super().__init__(cost_weight)
diff --git a/tests/optimizer/test_sparse_linearization.py b/tests/optimizer/test_sparse_linearization.py
index e9583b4c3..e9451d24f 100644
--- a/tests/optimizer/test_sparse_linearization.py
+++ b/tests/optimizer/test_sparse_linearization.py
@@ -29,3 +29,11 @@ def test_sparse_linearization():
 
     for i in range(batch_size):
         assert b[i].isclose(linearization.b[i]).all()
+
+    rng = torch.Generator()
+    rng.manual_seed(1009)
+    for _ in range(20):
+        v = torch.randn(A.shape[0], A.shape[2], 1, generator=rng)
+        av_expected = A.bmm(v).squeeze(2)
+        av_out = linearization.Av(v.squeeze(2))
+        torch.testing.assert_close(av_expected, av_out)
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index 772cfab74..a5b8524bb 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -113,11 +113,11 @@ def _check_sparse_mv(batch_size, num_rows, num_cols, fill, device):
     A_val.requires_grad = True
     b.requires_grad = True
     torch.autograd.gradcheck(
-        thutils.sparse_mv, (num_cols, A_val, A_row_ptr, A_col_ind, b)
+        thutils.sparse_mv, (num_cols, A_row_ptr, A_col_ind, A_val, b)
     )
 
     # Check forward pass
-    out = thutils.sparse_mv(num_cols, A_val, A_row_ptr, A_col_ind, b)
+    out = thutils.sparse_mv(num_cols, A_row_ptr, A_col_ind, A_val, b)
     for i in range(batch_size):
         A_csr = scipy.sparse.csr_matrix(
             (
diff --git a/theseus/optimizer/dense_linearization.py b/theseus/optimizer/dense_linearization.py
index 512853b57..1b4b3c253 100644
--- a/theseus/optimizer/dense_linearization.py
+++ b/theseus/optimizer/dense_linearization.py
@@ -69,3 +69,6 @@ def _ata_impl(self) -> torch.Tensor:
 
     def _atb_impl(self) -> torch.Tensor:
         return self._Atb
+
+    def Av(self, v: torch.Tensor) -> torch.Tensor:
+        return self.A.bmm(v.unsqueeze(2)).squeeze(2)
diff --git a/theseus/optimizer/linearization.py b/theseus/optimizer/linearization.py
index 9d84bd710..223d251a4 100644
--- a/theseus/optimizer/linearization.py
+++ b/theseus/optimizer/linearization.py
@@ -75,3 +75,8 @@ def AtA(self) -> torch.Tensor:
     @property
     def Atb(self) -> torch.Tensor:
         return self._atb_impl()
+
+    # Returns self.A @ v
+    @abc.abstractmethod
+    def Av(self, v: torch.Tensor) -> torch.Tensor:
+        pass
diff --git a/theseus/optimizer/sparse_linearization.py b/theseus/optimizer/sparse_linearization.py
index ef9e23626..4b86e68fb 100644
--- a/theseus/optimizer/sparse_linearization.py
+++ b/theseus/optimizer/sparse_linearization.py
@@ -9,7 +9,7 @@
 import torch
 
 from theseus.core import Objective
-from theseus.utils.sparse_matrix_utils import tmat_vec
+from theseus.utils.sparse_matrix_utils import sparse_mv, tmat_vec
 
 from .linear_system import SparseStructure
 from .linearization import Linearization
@@ -145,6 +145,10 @@ def _ata_impl(self) -> torch.Tensor:
         raise NotImplementedError("AtA is not yet implemented for SparseLinearization.")
 
     def _atb_impl(self) -> torch.Tensor:
+        if torch.is_grad_enabled():
+            raise NotImplementedError(
+                "Atb is not differentiable for SparseLinearization."
+            )
         if self._Atb is None:
             A_row_ptr = torch.tensor(self.A_row_ptr, dtype=torch.int32).to(
                 self.objective.device
@@ -161,3 +165,12 @@ def _atb_impl(self) -> torch.Tensor:
                 self.b.double(),
             ).unsqueeze(2)
         return self._Atb.to(dtype=self.A_val.dtype)
+
+    def Av(self, v: torch.Tensor) -> torch.Tensor:
+        A_row_ptr = torch.tensor(self.A_row_ptr, dtype=torch.int32).to(
+            self.objective.device
+        )
+        A_col_ind = A_row_ptr.new_tensor(self.A_col_ind)
+        return sparse_mv(
+            self.num_cols, A_row_ptr, A_col_ind, self.A_val.double(), v.double()
+        ).to(v.dtype)
diff --git a/theseus/utils/sparse_matrix_utils.py b/theseus/utils/sparse_matrix_utils.py
index 49774ffb5..528b35e39 100644
--- a/theseus/utils/sparse_matrix_utils.py
+++ b/theseus/utils/sparse_matrix_utils.py
@@ -104,17 +104,15 @@ class _SparseMvPAutograd(torch.autograd.Function):
     def forward(  # type: ignore
         ctx: Any,
         num_cols: int,
-        A_val: torch.Tensor,
         A_row_ptr: torch.Tensor,
         A_col_ind: torch.Tensor,
+        A_val: torch.Tensor,
         v: torch.Tensor,
     ) -> torch.Tensor:
-        assert (
-            A_row_ptr.ndim == 1
-            and A_col_ind.ndim == 1
-            and A_val.ndim == 2
-            and v.ndim == 2
-        )
+        assert A_row_ptr.ndim == 1
+        assert A_col_ind.ndim == 1
+        assert A_val.ndim == 2
+        assert v.ndim == 2
         ctx.save_for_backward(A_val, A_row_ptr, A_col_ind, v)
         ctx.num_cols = num_cols
         return mat_vec(A_val.shape[0], num_cols, A_row_ptr, A_col_ind, A_val, v)
@@ -123,7 +121,7 @@ def forward(  # type: ignore
     @torch.autograd.function.once_differentiable
     def backward(  # type: ignore
         ctx: Any, grad_output: torch.Tensor
-    ) -> Tuple[None, torch.Tensor, None, None, torch.Tensor]:
+    ) -> Tuple[None, None, None, torch.Tensor, torch.Tensor]:
         A_val, A_row_ptr, A_col_ind, v = ctx.saved_tensors
         num_rows = len(A_row_ptr) - 1
         A_grad = torch.zeros_like(A_val)  # (batch_size, nnz)
@@ -135,7 +133,7 @@ def backward(  # type: ignore
             A_grad[:, start:end] = v[:, columns] * grad_output[:, row].view(-1, 1)
             v_grad[:, columns] += grad_output[:, row].view(-1, 1) * A_val[:, start:end]
 
-        return None, A_grad, None, None, v_grad
+        return None, None, None, A_grad, v_grad
 
 
 sparse_mv = _SparseMvPAutograd.apply
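
For reference, below is a minimal self-contained sketch of the batched CSR
matrix-vector product that sparse_mv computes and that the new Av wrapper
exposes: a single sparsity pattern (A_row_ptr, A_col_ind) is shared across
the batch, with per-batch values A_val of shape (batch_size, nnz) and vectors
v of shape (batch_size, num_cols). The helper batched_csr_mv is hypothetical,
written only to illustrate the semantics; it is not part of the patch or of
the theseus API.

import torch

def batched_csr_mv(num_cols, A_row_ptr, A_col_ind, A_val, v):
    # Hypothetical illustration of the batched CSR product A @ v.
    # One sparsity pattern shared across the batch; per-batch values
    # A_val: (batch_size, nnz), vectors v: (batch_size, num_cols).
    num_rows = len(A_row_ptr) - 1
    out = v.new_zeros(v.shape[0], num_rows)
    for row in range(num_rows):
        start, end = int(A_row_ptr[row]), int(A_row_ptr[row + 1])
        columns = A_col_ind[start:end].long()
        # Row `row` of A @ v for every batch element at once.
        out[:, row] = (A_val[:, start:end] * v[:, columns]).sum(dim=1)
    return out

# CSR encoding of the 2x3 pattern [[x, 0, x], [0, x, 0]], batch of 2.
A_row_ptr = torch.tensor([0, 2, 3], dtype=torch.int32)
A_col_ind = torch.tensor([0, 2, 1], dtype=torch.int32)
A_val = torch.randn(2, 3)  # (batch_size, nnz)
v = torch.randn(2, 3)      # (batch_size, num_cols)
out = batched_csr_mv(3, A_row_ptr, A_col_ind, A_val, v)

# Cross-check against a dense bmm, mirroring the new unit test for Av.
A_dense = torch.zeros(2, 2, 3)
A_dense[:, 0, 0] = A_val[:, 0]
A_dense[:, 0, 2] = A_val[:, 1]
A_dense[:, 1, 1] = A_val[:, 2]
torch.testing.assert_close(out, A_dense.bmm(v.unsqueeze(2)).squeeze(2))

The loop over rows mirrors the structure of _SparseMvPAutograd.backward above;
the production forward path instead delegates to the compiled mat_vec kernel.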