Cleaned up sparse solvers code #386

Merged: 5 commits, Nov 29, 2022
18 changes: 6 additions & 12 deletions tests/extlib/test_baspacho.py
@@ -10,7 +10,7 @@


from tests.extlib.common import run_if_baspacho
-from theseus.utils import random_sparse_binary_matrix, split_into_param_sizes
+from theseus.utils import random_sparse_matrix, split_into_param_sizes


def check_baspacho(
@@ -35,17 +35,11 @@ def check_baspacho(

from theseus.extlib.baspacho_solver import SymbolicDecomposition

-    A_skel = random_sparse_binary_matrix(
-        num_rows, num_cols, fill, min_entries_per_col=1, rng=rng
+    A_col_ind, A_row_ptr, A_val, A_skel = random_sparse_matrix(
+        batch_size, num_rows, num_cols, fill, 1, rng, dev
     )
     A_num_cols = num_cols
-    A_rowPtr = torch.tensor(A_skel.indptr, dtype=torch.int64).to(dev)
-    A_colInd = torch.tensor(A_skel.indices, dtype=torch.int64).to(dev)
-    A_num_rows = A_rowPtr.size(0) - 1
-    A_nnz = A_colInd.size(0)
-    A_val = torch.rand(
-        (batch_size, A_nnz), device=dev, dtype=torch.double, generator=rng
-    )
+    A_num_rows = A_row_ptr.size(0) - 1
b = torch.rand(
(batch_size, A_num_rows), device=dev, dtype=torch.double, generator=rng
)
@@ -65,7 +59,7 @@ def check_baspacho(

A_csr = [
csr_matrix(
-            (A_val[i].cpu(), A_colInd.cpu(), A_rowPtr.cpu()), (A_num_rows, A_num_cols)
+            (A_val[i].cpu(), A_col_ind.cpu(), A_row_ptr.cpu()), (A_num_rows, A_num_cols)
)
for i in range(batch_size)
]
@@ -86,7 +80,7 @@ def check_baspacho(
)
f = s.create_numeric_decomposition(batch_size)

-    f.add_MtM(A_val, A_rowPtr, A_colInd)
+    f.add_MtM(A_val, A_row_ptr, A_col_ind)
beta = 0.01 * torch.rand(batch_size, device=dev, dtype=torch.double, generator=rng)
alpha = torch.rand(batch_size, device=dev, dtype=torch.double, generator=rng)
f.damp(alpha, beta)
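Note: the diff above only shows the call sites of the new helper. From them, `random_sparse_matrix` appears to take `(batch_size, num_rows, num_cols, fill, min_entries_per_col, rng, device)` and return `(col_ind, row_ptr, val, skeleton)`. The sketch below is a rough stand-in for what such a helper might do; the actual `theseus.utils.random_sparse_matrix` implementation is not part of this diff and may differ.

```python
# Hypothetical sketch of a batched random CSR generator, inferred from the call
# sites in these tests; not the actual theseus.utils implementation.
import torch
from scipy.sparse import csr_matrix


def random_sparse_matrix_sketch(
    batch_size, num_rows, num_cols, fill, min_entries_per_col, rng, device
):
    # Random binary sparsity pattern with roughly `fill` density.
    dense = (torch.rand(num_rows, num_cols, generator=rng) < fill).double()
    for j in range(num_cols):
        # guarantee a minimum number of nonzeros in every column
        rows = torch.randperm(num_rows, generator=rng)[:min_entries_per_col]
        dense[rows, j] = 1.0
    skel = csr_matrix(dense.numpy())

    row_ptr = torch.tensor(skel.indptr, dtype=torch.int64).to(device)
    col_ind = torch.tensor(skel.indices, dtype=torch.int64).to(device)
    # one set of values per batch element, all sharing the same sparsity pattern
    val = torch.rand(
        (batch_size, col_ind.size(0)), dtype=torch.double, generator=rng
    ).to(device)
    return col_ind, row_ptr, val, skel
```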
24 changes: 10 additions & 14 deletions tests/extlib/test_cusolver_lu_solver.py
@@ -8,7 +8,7 @@
import torch # needed for import of Torch C++ extensions to work
from scipy.sparse import csr_matrix

-from theseus.utils import random_sparse_binary_matrix
+from theseus.utils import random_sparse_matrix


# ideally we would like to support batch_size <= init_batch_size, but
@@ -25,20 +25,17 @@ def check_lu_solver(

rng = torch.Generator()
rng.manual_seed(0)
-    A_skel = random_sparse_binary_matrix(
-        num_rows, num_cols, fill, min_entries_per_col=3, rng=rng
+    A_col_ind, A_row_ptr, A_val, _ = random_sparse_matrix(
+        batch_size, num_rows, num_cols, fill, 3, rng, "cuda:0"
     )
     A_num_cols = num_cols
-    A_rowPtr = torch.tensor(A_skel.indptr, dtype=torch.int).cuda()
-    A_colInd = torch.tensor(A_skel.indices, dtype=torch.int).cuda()
-    A_num_rows = A_rowPtr.size(0) - 1
-    A_nnz = A_colInd.size(0)
-    A_val = torch.rand((batch_size, A_nnz), dtype=torch.double).cuda()
+    A_num_rows = A_row_ptr.size(0) - 1

b = torch.rand((batch_size, A_num_rows), dtype=torch.double).cuda()

A_csr = [
csr_matrix(
-            (A_val[i].cpu(), A_colInd.cpu(), A_rowPtr.cpu()), (A_num_rows, A_num_cols)
+            (A_val[i].cpu(), A_col_ind.cpu(), A_row_ptr.cpu()), (A_num_rows, A_num_cols)
)
for i in range(batch_size)
]
@@ -47,17 +44,16 @@ def check_lu_solver(
print("b[0]:\n", b[0])

AtA_csr = [(a.T @ a).tocsr() for a in A_csr]
-    AtA_rowPtr = torch.tensor(AtA_csr[0].indptr).cuda()
-    AtA_colInd = torch.tensor(AtA_csr[0].indices).cuda()
+    AtA_row_ptr = torch.tensor(AtA_csr[0].indptr).cuda()
+    AtA_col_ind = torch.tensor(AtA_csr[0].indices).cuda()
AtA_val = torch.tensor(np.array([m.data for m in AtA_csr])).cuda()
-    AtA_num_rows = AtA_rowPtr.size(0) - 1
+    AtA_num_rows = AtA_row_ptr.size(0) - 1
AtA_num_cols = AtA_num_rows
-    AtA_nnz = AtA_colInd.size(0)  # noqa: F841

if verbose:
print("AtA[0]:\n", AtA_csr[0].todense())

-    slv = CusolverLUSolver(init_batch_size, AtA_num_cols, AtA_rowPtr, AtA_colInd)
+    slv = CusolverLUSolver(init_batch_size, AtA_num_cols, AtA_row_ptr, AtA_col_ind)
singularities = slv.factor(AtA_val)

if verbose:
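The remainder of this test is elided from the diff, but a CPU reference for the batched normal equations can be built from the SciPy matrices already constructed above. A hedged sketch (reusing `A_csr`, `AtA_csr`, `b`, and `batch_size` from this test, with standard SciPy calls only) could be:

```python
# Sketch only: a SciPy reference solution for (A^T A) x = A^T b, per batch
# element, that the CUDA LU results could be compared against. Names A_csr,
# AtA_csr, b, and batch_size are the ones defined in the test above.
import numpy as np
from scipy.sparse.linalg import spsolve

x_ref = np.stack(
    [
        spsolve(AtA_csr[i].tocsc(), A_csr[i].T @ b[i].cpu().numpy())
        for i in range(batch_size)
    ]
)
```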
34 changes: 15 additions & 19 deletions tests/extlib/test_mat_mult.py
@@ -8,7 +8,7 @@
import torch # needed for import of Torch C++ extensions to work
from scipy.sparse import csr_matrix

-from theseus.utils import random_sparse_binary_matrix
+from theseus.utils import random_sparse_matrix


def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
@@ -18,19 +18,15 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):

rng = torch.Generator()
rng.manual_seed(0)
-    A_skel = random_sparse_binary_matrix(
-        num_rows, num_cols, fill, min_entries_per_col=3, rng=rng
+    A_col_ind, A_row_ptr, A_val, _ = random_sparse_matrix(
+        batch_size, num_rows, num_cols, fill, 3, rng, "cuda:0", int_dtype=torch.int
     )
     A_num_cols = num_cols
-    A_rowPtr = torch.tensor(A_skel.indptr, dtype=torch.int).cuda()
-    A_colInd = torch.tensor(A_skel.indices, dtype=torch.int).cuda()
-    A_num_rows = A_rowPtr.size(0) - 1
-    A_nnz = A_colInd.size(0)
-    A_val = torch.rand((batch_size, A_nnz), dtype=torch.double).cuda()
+    A_num_rows = A_row_ptr.size(0) - 1

A_csr = [
csr_matrix(
-            (A_val[i].cpu(), A_colInd.cpu(), A_rowPtr.cpu()), (A_num_rows, A_num_cols)
+            (A_val[i].cpu(), A_col_ind.cpu(), A_row_ptr.cpu()), (A_num_rows, A_num_cols)
)
for i in range(batch_size)
]
@@ -39,21 +35,21 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):

# test At * A
AtA_csr = [(a.T @ a).tocsr() for a in A_csr]
-    AtA_rowPtr = torch.tensor(AtA_csr[0].indptr).cuda()
-    AtA_colInd = torch.tensor(AtA_csr[0].indices).cuda()
+    AtA_row_ptr = torch.tensor(AtA_csr[0].indptr).cuda()
+    AtA_col_ind = torch.tensor(AtA_csr[0].indices).cuda()
AtA_val = torch.tensor(np.array([m.data for m in AtA_csr])).cuda()
-    AtA_num_rows = AtA_rowPtr.size(0) - 1
+    AtA_num_rows = AtA_row_ptr.size(0) - 1
AtA_num_cols = AtA_num_rows

if verbose:
print("\nAtA[0]:\n", AtA_csr[0].todense())

-    res = mult_MtM(batch_size, A_rowPtr, A_colInd, A_val, AtA_rowPtr, AtA_colInd)
+    res = mult_MtM(batch_size, A_row_ptr, A_col_ind, A_val, AtA_row_ptr, AtA_col_ind)
if verbose:
print(
"res[0]:\n",
csr_matrix(
-                (res[0].cpu(), AtA_colInd.cpu(), AtA_rowPtr.cpu()),
+                (res[0].cpu(), AtA_col_ind.cpu(), AtA_row_ptr.cpu()),
(AtA_num_rows, AtA_num_cols),
).todense(),
)
@@ -65,7 +61,7 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
np.array(
[
csr_matrix(
-                    (res[x].cpu(), AtA_colInd.cpu(), AtA_rowPtr.cpu()),
+                    (res[x].cpu(), AtA_col_ind.cpu(), AtA_row_ptr.cpu()),
(AtA_num_rows, AtA_num_cols),
).diagonal()
for x in range(batch_size)
@@ -74,12 +70,12 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
)
alpha = 0.3 * torch.rand(batch_size, dtype=torch.double).cuda()
beta = 0.7 * torch.rand(batch_size, dtype=torch.double).cuda()
-    apply_damping(batch_size, AtA_num_cols, AtA_rowPtr, AtA_colInd, res, alpha, beta)
+    apply_damping(batch_size, AtA_num_cols, AtA_row_ptr, AtA_col_ind, res, alpha, beta)
new_diagonals = torch.tensor(
np.array(
[
csr_matrix(
-                    (res[x].cpu(), AtA_colInd.cpu(), AtA_rowPtr.cpu()),
+                    (res[x].cpu(), AtA_col_ind.cpu(), AtA_row_ptr.cpu()),
(AtA_num_rows, AtA_num_cols),
).diagonal()
for x in range(batch_size)
@@ -97,7 +93,7 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
np.array([A_csr[i] @ v[i].cpu() for i in range(batch_size)])
).cuda()

-    A_v_test = mat_vec(batch_size, A_num_cols, A_rowPtr, A_colInd, A_val, v)
+    A_v_test = mat_vec(batch_size, A_num_cols, A_row_ptr, A_col_ind, A_val, v)

if verbose:
print("A_v:", A_v)
@@ -111,7 +107,7 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
np.array([A_csr[i].T @ w[i].cpu() for i in range(batch_size)])
).cuda()

-    At_w_test = tmat_vec(batch_size, A_num_cols, A_rowPtr, A_colInd, A_val, w)
+    At_w_test = tmat_vec(batch_size, A_num_cols, A_row_ptr, A_col_ind, A_val, w)

if verbose:
print("A_w:", At_w)
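For reference, `mult_MtM`, `mat_vec`, and `tmat_vec` above are the CUDA extension kernels for batched CSR products; their argument order is visible from the call sites. A plain PyTorch sketch of the two matrix-vector products (a slow reference under those assumptions, not the extension's actual implementation) could look like this:

```python
import torch


def mat_vec_ref(batch_size, num_cols, row_ptr, col_ind, val, v):
    # y = A @ v for each batch element, with A stored in CSR form:
    # y[b, r] = sum_k val[b, k] * v[b, col_ind[k]] over nonzeros k of row r.
    num_rows = row_ptr.numel() - 1
    y = val.new_zeros(batch_size, num_rows)
    for r in range(num_rows):
        start, end = row_ptr[r], row_ptr[r + 1]
        cols = col_ind[start:end].long()
        y[:, r] = (val[:, start:end] * v[:, cols]).sum(dim=1)
    return y


def tmat_vec_ref(batch_size, num_cols, row_ptr, col_ind, val, w):
    # y = A^T @ w for each batch element; row r of A scatters into the
    # columns it touches (column indices within a CSR row are unique).
    num_rows = row_ptr.numel() - 1
    y = val.new_zeros(batch_size, num_cols)
    for r in range(num_rows):
        start, end = row_ptr[r], row_ptr[r + 1]
        cols = col_ind[start:end].long()
        y[:, cols] += val[:, start:end] * w[:, r].unsqueeze(1)
    return y
```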
4 changes: 2 additions & 2 deletions tests/optimizer/autograd/test_baspacho_sparse_backward.py
@@ -81,8 +81,8 @@ def check_sparse_backward_step(
linearization.A_val,
linearization.b,
linearization.structure(),
-        solver.A_rowPtr,
-        solver.A_colInd,
+        solver.A_row_ptr,
+        solver.A_col_ind,
solver.symbolic_decomposition,
damping_alpha_beta,
)
4 changes: 2 additions & 2 deletions tests/optimizer/autograd/test_lu_cuda_sparse_backward.py
@@ -61,8 +61,8 @@ def test_sparse_backward_step():
linearization.A_val,
linearization.b,
linearization.structure(),
-        solver.A_rowPtr,
-        solver.A_colInd,
+        solver.A_row_ptr,
+        solver.A_col_ind,
solver._solver_contexts[solver._last_solver_context],
damping_alpha_beta,
False, # it's the same matrix, so no overwrite problems
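Both backward tests now pass the renamed `solver.A_row_ptr` / `solver.A_col_ind` buffers into the autograd functions. The surrounding test setup is elided from this diff; a minimal sketch of how a gradient check over `BaspachoSolveFunction` could be wired up (assuming a `solver`, `linearization`, and `damping_alpha_beta` prepared as in the tests, with illustrative tolerances) is:

```python
import torch

# Hedged sketch: only A_val and b receive gradients, so the remaining
# arguments are captured in a closure rather than passed to gradcheck.
def solve_fn(A_val, b):
    return BaspachoSolveFunction.apply(
        A_val,
        b,
        linearization.structure(),
        solver.A_row_ptr,
        solver.A_col_ind,
        solver.symbolic_decomposition,
        damping_alpha_beta,
    )

A_val = linearization.A_val.double().clone().requires_grad_()
b = linearization.b.double().clone().requires_grad_()
torch.autograd.gradcheck(solve_fn, (A_val, b), eps=1e-6, atol=1e-4)
```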
56 changes: 33 additions & 23 deletions theseus/optimizer/autograd/baspacho_sparse_autograd.py
@@ -2,37 +2,44 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from typing import Tuple, Optional
+from typing import Any, Tuple, Optional
import torch

from ..linear_system import SparseStructure
from theseus.utils.sparse_matrix_utils import mat_vec, tmat_vec

+_BaspachoSolveFunctionBwdReturnType = Tuple[
+    torch.Tensor, torch.Tensor, None, None, None, None, None, None
+]


class BaspachoSolveFunction(torch.autograd.Function):
@staticmethod
-    def forward(ctx, *args, **kwargs):
+    def forward(  # type: ignore
+        ctx: Any,
+        A_val: torch.Tensor,
+        b: torch.Tensor,
+        sparse_structure: SparseStructure,
+        A_row_ptr: torch.Tensor,
+        A_col_ind: torch.Tensor,
+        symbolic_decomposition: Any,  # actually SymbolicDecomposition
+        damping_alpha_beta: Optional[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> torch.Tensor:
from theseus.extlib.baspacho_solver import SymbolicDecomposition

-        A_val: torch.Tensor = args[0]
-        b: torch.Tensor = args[1]
-        sparse_structure: SparseStructure = args[2]
-        A_rowPtr: torch.Tensor = args[3]
-        A_colInd: torch.Tensor = args[4]
-        symbolic_decomposition: SymbolicDecomposition = args[5]
-        damping_alpha_beta: Optional[Tuple[torch.Tensor, torch.Tensor]] = args[6]
+        assert isinstance(symbolic_decomposition, SymbolicDecomposition)

batch_size = A_val.shape[0]

numeric_decomposition = symbolic_decomposition.create_numeric_decomposition(
batch_size
)
-        numeric_decomposition.add_MtM(A_val, A_rowPtr, A_colInd)
+        numeric_decomposition.add_MtM(A_val, A_row_ptr, A_col_ind)
if damping_alpha_beta is not None:
numeric_decomposition.damp(*damping_alpha_beta)
numeric_decomposition.factor()

-        A_args = sparse_structure.num_cols, A_rowPtr, A_colInd, A_val
+        A_args = sparse_structure.num_cols, A_row_ptr, A_col_ind, A_val
Atb = tmat_vec(batch_size, *A_args, b)

x = Atb.clone()
@@ -41,8 +48,8 @@ def forward(ctx, *args, **kwargs):
ctx.b = b
ctx.x = x
ctx.A_val = A_val
-        ctx.A_rowPtr = A_rowPtr
-        ctx.A_colInd = A_colInd
+        ctx.A_row_ptr = A_row_ptr
+        ctx.A_col_ind = A_col_ind
ctx.sparse_structure = sparse_structure
ctx.numeric_decomposition = numeric_decomposition
ctx.damping_alpha_beta = damping_alpha_beta
@@ -99,30 +106,33 @@ def forward(ctx, *args, **kwargs):
    # 2 times the scalar product of A's and (A')'s j-th column. Therefore
    # (A')'s j-th column is multiplying A's j-th column by 2*H[j]*alpha*x[j]
@staticmethod
-    def backward(ctx, grad_output):
+    def backward(  # type: ignore
+        ctx: Any, grad_output: torch.Tensor
+    ) -> _BaspachoSolveFunctionBwdReturnType:
batch_size = grad_output.shape[0]
-        targs = {"dtype": grad_output.dtype, "device": grad_output.device}

H = grad_output.clone()
ctx.numeric_decomposition.solve(H) # solve in place

-        A_args = ctx.sparse_structure.num_cols, ctx.A_rowPtr, ctx.A_colInd, ctx.A_val
+        A_args = ctx.sparse_structure.num_cols, ctx.A_row_ptr, ctx.A_col_ind, ctx.A_val
AH = mat_vec(batch_size, *A_args, H)
b_Ax = ctx.b - mat_vec(batch_size, *A_args, ctx.x)

# now we fill values of a matrix with structure identical to A with
# selected entries from the difference of tensor products:
# b_Ax (X) H - AH (X) x
# NOTE: this row-wise manipulation can be much faster in C++ or Cython
-        A_colInd = ctx.sparse_structure.col_ind
-        A_rowPtr = ctx.sparse_structure.row_ptr
+        A_col_ind = ctx.sparse_structure.col_ind
+        A_row_ptr = ctx.sparse_structure.row_ptr
batch_size = grad_output.shape[0]
A_grad = torch.empty(
-            size=(batch_size, len(A_colInd)), **targs
+            size=(batch_size, len(A_col_ind)),
+            dtype=grad_output.dtype,
+            device=grad_output.device,
) # return value, A's grad
-        for r in range(len(A_rowPtr) - 1):
-            start, end = A_rowPtr[r], A_rowPtr[r + 1]
-            columns = A_colInd[start:end]  # col indices, for this row
+        for r in range(len(A_row_ptr) - 1):
+            start, end = A_row_ptr[r], A_row_ptr[r + 1]
+            columns = A_col_ind[start:end]  # col indices, for this row
A_grad[:, start:end] = (
b_Ax[:, r].unsqueeze(1) * H[:, columns]
- AH[:, r].unsqueeze(1) * ctx.x[:, columns]
@@ -135,6 +145,6 @@ def backward(ctx, grad_output):
):
alpha = ctx.damping_alpha_beta[0].view(-1, 1)
alpha2Hx = (alpha * 2.0) * H * ctx.x # componentwise product
-        A_grad -= ctx.A_val * alpha2Hx[:, ctx.A_colInd.type(torch.long)]
+        A_grad -= ctx.A_val * alpha2Hx[:, ctx.A_col_ind.type(torch.long)]

return A_grad, AH, None, None, None, None, None, None
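The final damping correction matches the comment earlier in this function. Assuming `damp(alpha, beta)` builds the system matrix M = AᵀA + α·diag(AᵀA) plus a β term that does not depend on A (e.g. βI, so it contributes nothing to this gradient), the extra gradient for a stored entry A_kj follows, with H = M⁻¹·∂L/∂x:

```latex
% sketch of the damping term's contribution to A_grad
\frac{\partial M_{jj}}{\partial A_{kj}} = 2\alpha A_{kj}
\quad\Longrightarrow\quad
\left.\frac{\partial L}{\partial A_{kj}}\right|_{\text{damping}}
  = -\,2\alpha\, A_{kj}\, H_j\, x_j
```

which is exactly what `A_grad -= ctx.A_val * alpha2Hx[:, ctx.A_col_ind.type(torch.long)]` accumulates over the stored entries of each column j.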