Skip to content

Commit

Permalink
Add wrapper for sparse_mv in SparseLinearization.
Browse files Browse the repository at this point in the history
  • Loading branch information
luisenp committed Dec 1, 2022
1 parent 5a7a5f3 commit 8df7ef4
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 12 deletions.
3 changes: 3 additions & 0 deletions tests/optimizer/nonlinear/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ def _ata_impl(self) -> torch.Tensor:
def _atb_impl(self) -> torch.Tensor:
return self._Atb

def Av(self, v):
    """Stub for the mock linearization: ignores ``v`` and returns ``None``."""
    return None

class MockCostFunction(th.CostFunction):
def __init__(self, optim_vars, cost_weight):
super().__init__(cost_weight)
Expand Down
8 changes: 8 additions & 0 deletions tests/optimizer/test_sparse_linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,11 @@ def test_sparse_linearization():

for i in range(batch_size):
assert b[i].isclose(linearization.b[i]).all()

# Compare linearization.Av against a dense batched product on random vectors.
rng = torch.Generator()
rng.manual_seed(1009)
for _ in range(20):
    # Draw from the seeded generator; without `generator=rng` the seed above
    # has no effect and the test vectors are not reproducible.
    v = torch.randn(A.shape[0], A.shape[2], 1, generator=rng)
    av_expected = A.bmm(v).squeeze(2)
    av_out = linearization.Av(v.squeeze(2))
    torch.testing.assert_close(av_expected, av_out)
4 changes: 2 additions & 2 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,11 @@ def _check_sparse_mv(batch_size, num_rows, num_cols, fill, device):
A_val.requires_grad = True
b.requires_grad = True
torch.autograd.gradcheck(
thutils.sparse_mv, (num_cols, A_val, A_row_ptr, A_col_ind, b)
thutils.sparse_mv, (num_cols, A_row_ptr, A_col_ind, A_val, b)
)

# Check forward pass
out = thutils.sparse_mv(num_cols, A_val, A_row_ptr, A_col_ind, b)
out = thutils.sparse_mv(num_cols, A_row_ptr, A_col_ind, A_val, b)
for i in range(batch_size):
A_csr = scipy.sparse.csr_matrix(
(
Expand Down
3 changes: 3 additions & 0 deletions theseus/optimizer/dense_linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,6 @@ def _ata_impl(self) -> torch.Tensor:

def _atb_impl(self) -> torch.Tensor:
return self._Atb

def Av(self, v: torch.Tensor) -> torch.Tensor:
    """Return the batched matrix-vector product of ``self.A`` with ``v``.

    ``v`` is given a trailing singleton dimension so it can be passed to
    ``bmm``; that dimension is squeezed out of the result again.
    """
    v_column = v.unsqueeze(2)
    product = torch.bmm(self.A, v_column)
    return product.squeeze(2)
5 changes: 5 additions & 0 deletions theseus/optimizer/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ def AtA(self) -> torch.Tensor:
@property
def Atb(self) -> torch.Tensor:
return self._atb_impl()

@abc.abstractmethod
def Av(self, v: torch.Tensor) -> torch.Tensor:
    """Return ``self.A @ v``; abstract, so subclasses provide the body."""
15 changes: 14 additions & 1 deletion theseus/optimizer/sparse_linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch

from theseus.core import Objective
from theseus.utils.sparse_matrix_utils import tmat_vec
from theseus.utils.sparse_matrix_utils import sparse_mv, tmat_vec

from .linear_system import SparseStructure
from .linearization import Linearization
Expand Down Expand Up @@ -145,6 +145,10 @@ def _ata_impl(self) -> torch.Tensor:
raise NotImplementedError("AtA is not yet implemented for SparseLinearization.")

def _atb_impl(self) -> torch.Tensor:
if torch.is_grad_enabled():
raise NotImplementedError(
"Atb is not differentiable for SparseLinearization."
)
if self._Atb is None:
A_row_ptr = torch.tensor(self.A_row_ptr, dtype=torch.int32).to(
self.objective.device
Expand All @@ -161,3 +165,12 @@ def _atb_impl(self) -> torch.Tensor:
self.b.double(),
).unsqueeze(2)
return self._Atb.to(dtype=self.A_val.dtype)

def Av(self, v: torch.Tensor) -> torch.Tensor:
    """Return the product of the sparse matrix with ``v`` via ``sparse_mv``.

    The row-pointer and column-index arrays stored on this object are
    converted to int32 tensors on the objective's device for each call.
    """
    device = self.objective.device
    row_ptr = torch.tensor(self.A_row_ptr, dtype=torch.int32).to(device)
    # new_tensor gives the column indices the same dtype and device as row_ptr.
    col_ind = row_ptr.new_tensor(self.A_col_ind)
    # Values and vector are promoted to double for the product; the result is
    # cast back to the dtype of the incoming vector.
    product = sparse_mv(
        self.num_cols, row_ptr, col_ind, self.A_val.double(), v.double()
    )
    return product.to(v.dtype)
16 changes: 7 additions & 9 deletions theseus/utils/sparse_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,15 @@ class _SparseMvPAutograd(torch.autograd.Function):
def forward( # type: ignore
ctx: Any,
num_cols: int,
A_val: torch.Tensor,
A_row_ptr: torch.Tensor,
A_col_ind: torch.Tensor,
A_val: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
assert (
A_row_ptr.ndim == 1
and A_col_ind.ndim == 1
and A_val.ndim == 2
and v.ndim == 2
)
assert A_row_ptr.ndim == 1
assert A_col_ind.ndim == 1
assert A_val.ndim == 2
assert v.ndim == 2
ctx.save_for_backward(A_val, A_row_ptr, A_col_ind, v)
ctx.num_cols = num_cols
return mat_vec(A_val.shape[0], num_cols, A_row_ptr, A_col_ind, A_val, v)
Expand All @@ -123,7 +121,7 @@ def forward( # type: ignore
@torch.autograd.function.once_differentiable
def backward( # type: ignore
ctx: Any, grad_output: torch.Tensor
) -> Tuple[None, torch.Tensor, None, None, torch.Tensor]:
) -> Tuple[None, None, None, torch.Tensor, torch.Tensor]:
A_val, A_row_ptr, A_col_ind, v = ctx.saved_tensors
num_rows = len(A_row_ptr) - 1
A_grad = torch.zeros_like(A_val) # (batch_size, nnz)
Expand All @@ -135,7 +133,7 @@ def backward( # type: ignore
A_grad[:, start:end] = v[:, columns] * grad_output[:, row].view(-1, 1)
v_grad[:, columns] += grad_output[:, row].view(-1, 1) * A_val[:, start:end]

return None, A_grad, None, None, v_grad
return None, None, None, A_grad, v_grad


sparse_mv = _SparseMvPAutograd.apply
Expand Down

0 comments on commit 8df7ef4

Please sign in to comment.