Skip to content

Commit

Permalink
Add a differentiable sparse matrix vector product on top of our ops (#…
Browse files Browse the repository at this point in the history
…392)

* Add autograd function for sparse matrix vector product.

* Add wrapper for sparse_mv in SparseLinearization.

* Added autograd function for sparse matrix transpose vector product.

* Add wrapper for sparse_mtv in SparseLinearization to make differentiable Atb.

* Fix dtype index bug.
  • Loading branch information
luisenp authored Dec 8, 2022
1 parent f2748fe commit 7dca714
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 6 deletions.
3 changes: 3 additions & 0 deletions tests/optimizer/nonlinear/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ def _ata_impl(self) -> torch.Tensor:
def _atb_impl(self) -> torch.Tensor:
return self._Atb

def Av(self, v):
pass

class MockCostFunction(th.CostFunction):
def __init__(self, optim_vars, cost_weight):
super().__init__(cost_weight)
Expand Down
14 changes: 14 additions & 0 deletions tests/optimizer/test_sparse_linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,17 @@ def test_sparse_linearization():

for i in range(batch_size):
assert b[i].isclose(linearization.b[i]).all()

# Test Atb result
atb_expected = A.transpose(1, 2).bmm(b.unsqueeze(2))
atb_out = linearization.Atb
torch.testing.assert_close(atb_expected, atb_out)

# Test Av() with a random v
rng = torch.Generator()
rng.manual_seed(1009)
for _ in range(20):
v = torch.randn(A.shape[0], A.shape[2], 1)
av_expected = A.bmm(v).squeeze(2)
av_out = linearization.Av(v.squeeze(2))
torch.testing.assert_close(av_expected, av_out)
61 changes: 61 additions & 0 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pytest # noqa: F401
import scipy.sparse
import torch
import torch.nn as nn

Expand Down Expand Up @@ -93,3 +94,63 @@ def test_gather_from_rows_cols():
for i in range(batch_size):
for j in range(num_points):
assert torch.allclose(res[i, j], matrix[i, rows[i, j], cols[i, j]])


def _check_sparse_mv_and_mtv(batch_size, num_rows, num_cols, fill, device):
A_col_ind, A_row_ptr, A_val, _ = thutils.random_sparse_matrix(
batch_size,
num_rows,
num_cols,
fill,
min(num_cols, 2),
torch.Generator(device),
device,
)
b = torch.randn(batch_size, num_cols, device=device).double()
b2 = torch.randn(batch_size, num_rows, device=device).double()

# Check backward pass
if batch_size < 16:
A_val.requires_grad = True
b.requires_grad = True
torch.autograd.gradcheck(
thutils.sparse_mv, (num_cols, A_row_ptr, A_col_ind, A_val, b)
)
torch.autograd.gradcheck(
thutils.sparse_mtv, (num_cols, A_row_ptr, A_col_ind, A_val, b2)
)

# Check forward pass
Ab_bundle = [scipy.sparse.csr_matrix, thutils.sparse_mv, (num_rows, num_cols), b]
Atb_bundle = [scipy.sparse.csc_matrix, thutils.sparse_mtv, (num_cols, num_rows), b2]
for i in range(batch_size):
for sparse_constructor, sparse_op, shape, b_tensor in [Ab_bundle, Atb_bundle]:
out = sparse_op(num_cols, A_row_ptr, A_col_ind, A_val, b_tensor)
A_sparse = sparse_constructor(
(
A_val[i].detach().cpu().numpy(),
A_col_ind.cpu().numpy(),
A_row_ptr.cpu().numpy(),
),
shape,
)
expected_out = A_sparse * b_tensor[i].detach().cpu().numpy()
diff = expected_out - out[i].detach().cpu().numpy()
assert np.linalg.norm(diff) < 1e-8


@pytest.mark.parametrize("batch_size", [1, 4, 16])
@pytest.mark.parametrize("num_rows", [1, 32])
@pytest.mark.parametrize("num_cols", [1, 4, 32])
@pytest.mark.parametrize("fill", [0.1, 0.9])
def test_sparse_mv_cpu(batch_size, num_rows, num_cols, fill):
_check_sparse_mv_and_mtv(batch_size, num_rows, num_cols, fill, "cpu")


@pytest.mark.cudaext
@pytest.mark.parametrize("batch_size", [1, 4, 16])
@pytest.mark.parametrize("num_rows", [1, 32])
@pytest.mark.parametrize("num_cols", [1, 4, 32])
@pytest.mark.parametrize("fill", [0.1, 0.9])
def test_sparse_mv_cuda(batch_size, num_rows, num_cols, fill):
_check_sparse_mv_and_mtv(batch_size, num_rows, num_cols, fill, "cuda:0")
3 changes: 3 additions & 0 deletions theseus/optimizer/dense_linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,6 @@ def _ata_impl(self) -> torch.Tensor:

def _atb_impl(self) -> torch.Tensor:
return self._Atb

def Av(self, v: torch.Tensor) -> torch.Tensor:
return self.A.bmm(v.unsqueeze(2)).squeeze(2)
5 changes: 5 additions & 0 deletions theseus/optimizer/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ def AtA(self) -> torch.Tensor:
@property
def Atb(self) -> torch.Tensor:
return self._atb_impl()

# Returns self.A @ v
@abc.abstractmethod
def Av(self, v: torch.Tensor) -> torch.Tensor:
pass
14 changes: 11 additions & 3 deletions theseus/optimizer/sparse_linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch

from theseus.core import Objective
from theseus.utils.sparse_matrix_utils import tmat_vec
from theseus.utils.sparse_matrix_utils import sparse_mv, sparse_mtv

from .linear_system import SparseStructure
from .linearization import Linearization
Expand Down Expand Up @@ -152,12 +152,20 @@ def _atb_impl(self) -> torch.Tensor:
A_col_ind = A_row_ptr.new_tensor(self.A_col_ind)

# unsqueeze at the end for consistency with DenseLinearization
self._Atb = tmat_vec(
self.objective.batch_size,
self._Atb = sparse_mtv(
self.num_cols,
A_row_ptr,
A_col_ind,
self.A_val.double(),
self.b.double(),
).unsqueeze(2)
return self._Atb.to(dtype=self.A_val.dtype)

def Av(self, v: torch.Tensor) -> torch.Tensor:
A_row_ptr = torch.tensor(self.A_row_ptr, dtype=torch.int32).to(
self.objective.device
)
A_col_ind = A_row_ptr.new_tensor(self.A_col_ind)
return sparse_mv(
self.num_cols, A_row_ptr, A_col_ind, self.A_val.double(), v.double()
).to(v.dtype)
2 changes: 2 additions & 0 deletions theseus/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
mat_vec,
random_sparse_matrix,
random_sparse_binary_matrix,
sparse_mv,
sparse_mtv,
split_into_param_sizes,
tmat_vec,
)
Expand Down
99 changes: 96 additions & 3 deletions theseus/utils/sparse_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Tuple
from typing import Any, Callable, List, Tuple

import numpy as np
import torch
from scipy.sparse import csc_matrix, csr_matrix, lil_matrix

from theseus.constants import DeviceType


def _mat_vec_cpu(
batch_size: int,
Expand Down Expand Up @@ -97,6 +99,97 @@ def tmat_vec(
return _tmat_vec_cpu(batch_size, num_cols, A_row_ptr, A_col_ind, A_val, v)


def _sparse_mat_vec_fwd_backend(
ctx: Any,
num_cols: int,
A_row_ptr: torch.Tensor,
A_col_ind: torch.Tensor,
A_val: torch.Tensor,
v: torch.Tensor,
op: Callable[
[int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
torch.Tensor,
],
) -> torch.Tensor:
assert A_row_ptr.ndim == 1
assert A_col_ind.ndim == 1
assert A_val.ndim == 2
assert v.ndim == 2
ctx.save_for_backward(A_val, A_row_ptr, A_col_ind, v)
ctx.num_cols = num_cols
return op(A_val.shape[0], num_cols, A_row_ptr, A_col_ind, A_val, v)


def _sparse_mat_vec_bwd_backend(
ctx: Any, grad_output: torch.Tensor, is_tmat: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
A_val, A_row_ptr, A_col_ind, v = ctx.saved_tensors
num_rows = len(A_row_ptr) - 1
A_grad = torch.zeros_like(A_val) # (batch_size, nnz)
v_grad = torch.zeros_like(v) # (batch_size, num_cols)
for row in range(num_rows):
start = A_row_ptr[row]
end = A_row_ptr[row + 1]
columns = A_col_ind[start:end].long()
if is_tmat:
A_grad[:, start:end] = v[:, row].view(-1, 1) * grad_output[:, columns]
v_grad[:, row] = (grad_output[:, columns] * A_val[:, start:end]).sum(dim=1)
else:
A_grad[:, start:end] = v[:, columns] * grad_output[:, row].view(-1, 1)
v_grad[:, columns] += grad_output[:, row].view(-1, 1) * A_val[:, start:end]
return A_grad, v_grad


class _SparseMvPAutograd(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
ctx: Any,
num_cols: int,
A_row_ptr: torch.Tensor,
A_col_ind: torch.Tensor,
A_val: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
return _sparse_mat_vec_fwd_backend(
ctx, num_cols, A_row_ptr, A_col_ind, A_val, v, mat_vec
)

@staticmethod
@torch.autograd.function.once_differentiable
def backward( # type: ignore
ctx: Any, grad_output: torch.Tensor
) -> Tuple[None, None, None, torch.Tensor, torch.Tensor]:
A_grad, v_grad = _sparse_mat_vec_bwd_backend(ctx, grad_output, False)
return None, None, None, A_grad, v_grad


class _SparseMtvPAutograd(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
ctx: Any,
num_cols: int,
A_row_ptr: torch.Tensor,
A_col_ind: torch.Tensor,
A_val: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
return _sparse_mat_vec_fwd_backend(
ctx, num_cols, A_row_ptr, A_col_ind, A_val, v, tmat_vec
)

@staticmethod
@torch.autograd.function.once_differentiable
def backward( # type: ignore
ctx: Any, grad_output: torch.Tensor
) -> Tuple[None, None, None, torch.Tensor, torch.Tensor]:
A_grad, v_grad = _sparse_mat_vec_bwd_backend(ctx, grad_output, True)
return None, None, None, A_grad, v_grad


sparse_mv = _SparseMvPAutograd.apply
sparse_mtv = _SparseMtvPAutograd.apply


def random_sparse_binary_matrix(
num_rows: int,
num_cols: int,
Expand All @@ -106,7 +199,7 @@ def random_sparse_binary_matrix(
) -> csr_matrix:
retv = lil_matrix((num_rows, num_cols))

if min_entries_per_col > 0:
if num_rows > 1 and min_entries_per_col > 0:
min_entries_per_col = min(num_rows, min_entries_per_col)
rows_array = torch.arange(num_rows, device=rng.device)
rows_array_f = rows_array.to(dtype=torch.float)
Expand Down Expand Up @@ -138,7 +231,7 @@ def random_sparse_matrix(
fill: float,
min_entries_per_col: int,
rng: torch.Generator,
device: torch.device,
device: DeviceType,
int_dtype: torch.dtype = torch.int64,
float_dtype: torch.dtype = torch.double,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down

0 comments on commit 7dca714

Please sign in to comment.