Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a differentiable sparse matrix vector product on top of our ops #392

Merged
merged 5 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added autograd function for sparse matrix transpose vector product.
  • Loading branch information
luisenp committed Dec 8, 2022
commit ab9fcc77d5efe061945821a9b942b6dc2e68c34b
37 changes: 22 additions & 15 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_gather_from_rows_cols():
assert torch.allclose(res[i, j], matrix[i, rows[i, j], cols[i, j]])


def _check_sparse_mv(batch_size, num_rows, num_cols, fill, device):
def _check_sparse_mv_and_mtv(batch_size, num_rows, num_cols, fill, device):
A_col_ind, A_row_ptr, A_val, _ = thutils.random_sparse_matrix(
batch_size,
num_rows,
Expand All @@ -107,6 +107,7 @@ def _check_sparse_mv(batch_size, num_rows, num_cols, fill, device):
device,
)
b = torch.randn(batch_size, num_cols, device=device).double()
b2 = torch.randn(batch_size, num_rows, device=device).double()

# Check backward pass
if batch_size < 16:
Expand All @@ -115,29 +116,35 @@ def _check_sparse_mv(batch_size, num_rows, num_cols, fill, device):
torch.autograd.gradcheck(
thutils.sparse_mv, (num_cols, A_row_ptr, A_col_ind, A_val, b)
)
torch.autograd.gradcheck(
thutils.sparse_mtv, (num_cols, A_row_ptr, A_col_ind, A_val, b2)
)

# Check forward pass
out = thutils.sparse_mv(num_cols, A_row_ptr, A_col_ind, A_val, b)
Ab_bundle = [scipy.sparse.csr_matrix, thutils.sparse_mv, (num_rows, num_cols), b]
Atb_bundle = [scipy.sparse.csc_matrix, thutils.sparse_mtv, (num_cols, num_rows), b2]
for i in range(batch_size):
A_csr = scipy.sparse.csr_matrix(
(
A_val[i].detach().cpu().numpy(),
A_col_ind.cpu().numpy(),
A_row_ptr.cpu().numpy(),
),
(num_rows, num_cols),
)
expected_out = A_csr * b[i].detach().cpu().numpy()
diff = expected_out - out[i].detach().cpu().numpy()
assert np.linalg.norm(diff) < 1e-8
for sparse_constructor, sparse_op, shape, b_tensor in [Ab_bundle, Atb_bundle]:
out = sparse_op(num_cols, A_row_ptr, A_col_ind, A_val, b_tensor)
A_sparse = sparse_constructor(
(
A_val[i].detach().cpu().numpy(),
A_col_ind.cpu().numpy(),
A_row_ptr.cpu().numpy(),
),
shape,
)
expected_out = A_sparse * b_tensor[i].detach().cpu().numpy()
diff = expected_out - out[i].detach().cpu().numpy()
assert np.linalg.norm(diff) < 1e-8


@pytest.mark.parametrize("batch_size", [1, 4, 16])
@pytest.mark.parametrize("num_rows", [1, 32])
@pytest.mark.parametrize("num_cols", [1, 4, 32])
@pytest.mark.parametrize("fill", [0.1, 0.9])
def test_sparse_mv_cpu(batch_size, num_rows, num_cols, fill):
_check_sparse_mv(batch_size, num_rows, num_cols, fill, "cpu")
_check_sparse_mv_and_mtv(batch_size, num_rows, num_cols, fill, "cpu")


@pytest.mark.cudaext
Expand All @@ -146,4 +153,4 @@ def test_sparse_mv_cpu(batch_size, num_rows, num_cols, fill):
@pytest.mark.parametrize("num_cols", [1, 4, 32])
@pytest.mark.parametrize("fill", [0.1, 0.9])
def test_sparse_mv_cuda(batch_size, num_rows, num_cols, fill):
_check_sparse_mv(batch_size, num_rows, num_cols, fill, "cuda:0")
_check_sparse_mv_and_mtv(batch_size, num_rows, num_cols, fill, "cuda:0")
1 change: 1 addition & 0 deletions theseus/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
random_sparse_matrix,
random_sparse_binary_matrix,
sparse_mv,
sparse_mtv,
split_into_param_sizes,
tmat_vec,
)
Expand Down
87 changes: 69 additions & 18 deletions theseus/utils/sparse_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Tuple
from typing import Any, Callable, List, Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -99,6 +99,47 @@ def tmat_vec(
return _tmat_vec_cpu(batch_size, num_cols, A_row_ptr, A_col_ind, A_val, v)


def _sparse_mat_vec_fwd_backend(
ctx: Any,
num_cols: int,
A_row_ptr: torch.Tensor,
A_col_ind: torch.Tensor,
A_val: torch.Tensor,
v: torch.Tensor,
op: Callable[
[int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
torch.Tensor,
],
) -> torch.Tensor:
assert A_row_ptr.ndim == 1
assert A_col_ind.ndim == 1
assert A_val.ndim == 2
assert v.ndim == 2
ctx.save_for_backward(A_val, A_row_ptr, A_col_ind, v)
ctx.num_cols = num_cols
return op(A_val.shape[0], num_cols, A_row_ptr, A_col_ind, A_val, v)


def _sparse_mat_vec_bwd_backend(
ctx: Any, grad_output: torch.Tensor, is_tmat: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
A_val, A_row_ptr, A_col_ind, v = ctx.saved_tensors
num_rows = len(A_row_ptr) - 1
A_grad = torch.zeros_like(A_val) # (batch_size, nnz)
v_grad = torch.zeros_like(v) # (batch_size, num_cols)
for row in range(num_rows):
start = A_row_ptr[row]
end = A_row_ptr[row + 1]
columns = A_col_ind[start:end]
if is_tmat:
A_grad[:, start:end] = v[:, row].view(-1, 1) * grad_output[:, columns]
v_grad[:, row] = (grad_output[:, columns] * A_val[:, start:end]).sum(dim=1)
else:
A_grad[:, start:end] = v[:, columns] * grad_output[:, row].view(-1, 1)
v_grad[:, columns] += grad_output[:, row].view(-1, 1) * A_val[:, start:end]
return A_grad, v_grad


class _SparseMvPAutograd(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
Expand All @@ -109,34 +150,44 @@ def forward( # type: ignore
A_val: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
assert A_row_ptr.ndim == 1
assert A_col_ind.ndim == 1
assert A_val.ndim == 2
assert v.ndim == 2
ctx.save_for_backward(A_val, A_row_ptr, A_col_ind, v)
ctx.num_cols = num_cols
return mat_vec(A_val.shape[0], num_cols, A_row_ptr, A_col_ind, A_val, v)
return _sparse_mat_vec_fwd_backend(
ctx, num_cols, A_row_ptr, A_col_ind, A_val, v, mat_vec
)

@staticmethod
@torch.autograd.function.once_differentiable
def backward( # type: ignore
ctx: Any, grad_output: torch.Tensor
) -> Tuple[None, None, None, torch.Tensor, torch.Tensor]:
A_val, A_row_ptr, A_col_ind, v = ctx.saved_tensors
num_rows = len(A_row_ptr) - 1
A_grad = torch.zeros_like(A_val) # (batch_size, nnz)
v_grad = torch.zeros_like(v) # (batch_size, num_cols)
for row in range(num_rows):
start = A_row_ptr[row]
end = A_row_ptr[row + 1]
columns = A_col_ind[start:end]
A_grad[:, start:end] = v[:, columns] * grad_output[:, row].view(-1, 1)
v_grad[:, columns] += grad_output[:, row].view(-1, 1) * A_val[:, start:end]
A_grad, v_grad = _sparse_mat_vec_bwd_backend(ctx, grad_output, False)
return None, None, None, A_grad, v_grad


class _SparseMtvPAutograd(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
ctx: Any,
num_cols: int,
A_row_ptr: torch.Tensor,
A_col_ind: torch.Tensor,
A_val: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
return _sparse_mat_vec_fwd_backend(
ctx, num_cols, A_row_ptr, A_col_ind, A_val, v, tmat_vec
)

@staticmethod
@torch.autograd.function.once_differentiable
def backward( # type: ignore
ctx: Any, grad_output: torch.Tensor
) -> Tuple[None, None, None, torch.Tensor, torch.Tensor]:
A_grad, v_grad = _sparse_mat_vec_bwd_backend(ctx, grad_output, True)
return None, None, None, A_grad, v_grad


sparse_mv = _SparseMvPAutograd.apply
sparse_mtv = _SparseMtvPAutograd.apply


def random_sparse_binary_matrix(
Expand Down