Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a differentiable sparse matrix vector product on top of our ops #392

Merged
merged 5 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add autograd function for sparse matrix vector product.
  • Loading branch information
luisenp committed Dec 8, 2022
commit 5f22784c112f42cd71ded1e8c53139155b1ee8c0
54 changes: 54 additions & 0 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pytest # noqa: F401
import scipy.sparse
import torch
import torch.nn as nn

Expand Down Expand Up @@ -93,3 +94,56 @@ def test_gather_from_rows_cols():
for i in range(batch_size):
for j in range(num_points):
assert torch.allclose(res[i, j], matrix[i, rows[i, j], cols[i, j]])


def _check_sparse_mv(batch_size, num_rows, num_cols, fill, device):
A_col_ind, A_row_ptr, A_val, _ = thutils.random_sparse_matrix(
batch_size,
num_rows,
num_cols,
fill,
min(num_cols, 2),
torch.Generator(device),
device,
)
b = torch.randn(batch_size, num_cols, device=device).double()

# Check backward pass
if batch_size < 16:
A_val.requires_grad = True
b.requires_grad = True
torch.autograd.gradcheck(
thutils.sparse_mv, (num_cols, A_val, A_row_ptr, A_col_ind, b)
)

# Check forward pass
out = thutils.sparse_mv(num_cols, A_val, A_row_ptr, A_col_ind, b)
for i in range(batch_size):
A_csr = scipy.sparse.csr_matrix(
(
A_val[i].detach().cpu().numpy(),
A_col_ind.cpu().numpy(),
A_row_ptr.cpu().numpy(),
),
(num_rows, num_cols),
)
expected_out = A_csr * b[i].detach().cpu().numpy()
diff = expected_out - out[i].detach().cpu().numpy()
assert np.linalg.norm(diff) < 1e-8


@pytest.mark.parametrize("batch_size", [1, 4, 16])
@pytest.mark.parametrize("num_rows", [1, 32])
@pytest.mark.parametrize("num_cols", [1, 4, 32])
@pytest.mark.parametrize("fill", [0.1, 0.9])
def test_sparse_mv_cpu(batch_size, num_rows, num_cols, fill):
_check_sparse_mv(batch_size, num_rows, num_cols, fill, "cpu")


@pytest.mark.cudaext
@pytest.mark.parametrize("batch_size", [1, 4, 16])
@pytest.mark.parametrize("num_rows", [1, 32])
@pytest.mark.parametrize("num_cols", [1, 4, 32])
@pytest.mark.parametrize("fill", [0.1, 0.9])
def test_sparse_mv_cuda(batch_size, num_rows, num_cols, fill):
_check_sparse_mv(batch_size, num_rows, num_cols, fill, "cuda:0")
1 change: 1 addition & 0 deletions theseus/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
mat_vec,
random_sparse_matrix,
random_sparse_binary_matrix,
sparse_mv,
split_into_param_sizes,
tmat_vec,
)
Expand Down
50 changes: 47 additions & 3 deletions theseus/utils/sparse_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Tuple
from typing import Any, List, Tuple

import numpy as np
import torch
from scipy.sparse import csc_matrix, csr_matrix, lil_matrix

from theseus.constants import DeviceType


def _mat_vec_cpu(
batch_size: int,
Expand Down Expand Up @@ -97,6 +99,48 @@ def tmat_vec(
return _tmat_vec_cpu(batch_size, num_cols, A_row_ptr, A_col_ind, A_val, v)


class _SparseMvPAutograd(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
ctx: Any,
num_cols: int,
A_val: torch.Tensor,
A_row_ptr: torch.Tensor,
A_col_ind: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
assert (
A_row_ptr.ndim == 1
and A_col_ind.ndim == 1
and A_val.ndim == 2
and v.ndim == 2
)
ctx.save_for_backward(A_val, A_row_ptr, A_col_ind, v)
ctx.num_cols = num_cols
return mat_vec(A_val.shape[0], num_cols, A_row_ptr, A_col_ind, A_val, v)

@staticmethod
@torch.autograd.function.once_differentiable
def backward( # type: ignore
ctx: Any, grad_output: torch.Tensor
) -> Tuple[None, torch.Tensor, None, None, torch.Tensor]:
A_val, A_row_ptr, A_col_ind, v = ctx.saved_tensors
num_rows = len(A_row_ptr) - 1
A_grad = torch.zeros_like(A_val) # (batch_size, nnz)
v_grad = torch.zeros_like(v) # (batch_size, num_cols)
for row in range(num_rows):
start = A_row_ptr[row]
end = A_row_ptr[row + 1]
columns = A_col_ind[start:end]
A_grad[:, start:end] = v[:, columns] * grad_output[:, row].view(-1, 1)
v_grad[:, columns] += grad_output[:, row].view(-1, 1) * A_val[:, start:end]

return None, A_grad, None, None, v_grad


sparse_mv = _SparseMvPAutograd.apply


def random_sparse_binary_matrix(
num_rows: int,
num_cols: int,
Expand All @@ -106,7 +150,7 @@ def random_sparse_binary_matrix(
) -> csr_matrix:
retv = lil_matrix((num_rows, num_cols))

if min_entries_per_col > 0:
if num_rows > 1 and min_entries_per_col > 0:
min_entries_per_col = min(num_rows, min_entries_per_col)
rows_array = torch.arange(num_rows, device=rng.device)
rows_array_f = rows_array.to(dtype=torch.float)
Expand Down Expand Up @@ -138,7 +182,7 @@ def random_sparse_matrix(
fill: float,
min_entries_per_col: int,
rng: torch.Generator,
device: torch.device,
device: DeviceType,
int_dtype: torch.dtype = torch.int64,
float_dtype: torch.dtype = torch.double,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down