diff --git a/tests/extlib/test_baspacho.py b/tests/extlib/test_baspacho.py
index 7bbad76c3..82b546a18 100644
--- a/tests/extlib/test_baspacho.py
+++ b/tests/extlib/test_baspacho.py
@@ -10,7 +10,7 @@
 
 from tests.extlib.common import run_if_baspacho
 
-from theseus.utils import random_sparse_binary_matrix, split_into_param_sizes
+from theseus.utils import random_sparse_matrix, split_into_param_sizes
 
 
 def check_baspacho(
@@ -35,17 +35,11 @@ def check_baspacho(
     from theseus.extlib.baspacho_solver import SymbolicDecomposition
 
-    A_skel = random_sparse_binary_matrix(
-        num_rows, num_cols, fill, min_entries_per_col=1, rng=rng
+    A_col_ind, A_row_ptr, A_val, A_skel = random_sparse_matrix(
+        batch_size, num_rows, num_cols, fill, 1, rng, dev
     )
     A_num_cols = num_cols
-    A_rowPtr = torch.tensor(A_skel.indptr, dtype=torch.int64).to(dev)
-    A_colInd = torch.tensor(A_skel.indices, dtype=torch.int64).to(dev)
-    A_num_rows = A_rowPtr.size(0) - 1
-    A_nnz = A_colInd.size(0)
-    A_val = torch.rand(
-        (batch_size, A_nnz), device=dev, dtype=torch.double, generator=rng
-    )
+    A_num_rows = A_row_ptr.size(0) - 1
     b = torch.rand(
         (batch_size, A_num_rows), device=dev, dtype=torch.double, generator=rng
     )
@@ -65,7 +59,7 @@ def check_baspacho(
     A_csr = [
         csr_matrix(
-            (A_val[i].cpu(), A_colInd.cpu(), A_rowPtr.cpu()), (A_num_rows, A_num_cols)
+            (A_val[i].cpu(), A_col_ind.cpu(), A_row_ptr.cpu()), (A_num_rows, A_num_cols)
         )
         for i in range(batch_size)
     ]
@@ -86,7 +80,7 @@ def check_baspacho(
     )
     f = s.create_numeric_decomposition(batch_size)
-    f.add_MtM(A_val, A_rowPtr, A_colInd)
+    f.add_MtM(A_val, A_row_ptr, A_col_ind)
     beta = 0.01 * torch.rand(batch_size, device=dev, dtype=torch.double, generator=rng)
     alpha = torch.rand(batch_size, device=dev, dtype=torch.double, generator=rng)
     f.damp(alpha, beta)
diff --git a/tests/extlib/test_cusolver_lu_solver.py b/tests/extlib/test_cusolver_lu_solver.py
index 5a9f063ca..2e45440f7 100644
--- a/tests/extlib/test_cusolver_lu_solver.py
+++ b/tests/extlib/test_cusolver_lu_solver.py
@@ -8,7 +8,7 @@
 import torch  # needed for import of Torch C++ extensions to work
 from scipy.sparse import csr_matrix
 
-from theseus.utils import random_sparse_binary_matrix
+from theseus.utils import random_sparse_matrix
 
 
 # ideally we would like to support batch_size <= init_batch_size, but
@@ -25,20 +25,17 @@ def check_lu_solver(
     rng = torch.Generator()
     rng.manual_seed(0)
 
-    A_skel = random_sparse_binary_matrix(
-        num_rows, num_cols, fill, min_entries_per_col=3, rng=rng
+    A_col_ind, A_row_ptr, A_val, _ = random_sparse_matrix(
+        batch_size, num_rows, num_cols, fill, 3, rng, "cuda:0"
     )
     A_num_cols = num_cols
-    A_rowPtr = torch.tensor(A_skel.indptr, dtype=torch.int).cuda()
-    A_colInd = torch.tensor(A_skel.indices, dtype=torch.int).cuda()
-    A_num_rows = A_rowPtr.size(0) - 1
-    A_nnz = A_colInd.size(0)
-    A_val = torch.rand((batch_size, A_nnz), dtype=torch.double).cuda()
+    A_num_rows = A_row_ptr.size(0) - 1
+
     b = torch.rand((batch_size, A_num_rows), dtype=torch.double).cuda()
 
     A_csr = [
         csr_matrix(
-            (A_val[i].cpu(), A_colInd.cpu(), A_rowPtr.cpu()), (A_num_rows, A_num_cols)
+            (A_val[i].cpu(), A_col_ind.cpu(), A_row_ptr.cpu()), (A_num_rows, A_num_cols)
         )
         for i in range(batch_size)
     ]
@@ -47,17 +44,16 @@ def check_lu_solver(
     print("b[0]:\n", b[0])
 
     AtA_csr = [(a.T @ a).tocsr() for a in A_csr]
-    AtA_rowPtr = torch.tensor(AtA_csr[0].indptr).cuda()
-    AtA_colInd = torch.tensor(AtA_csr[0].indices).cuda()
+    AtA_row_ptr = torch.tensor(AtA_csr[0].indptr).cuda()
+    AtA_col_ind = torch.tensor(AtA_csr[0].indices).cuda()
     AtA_val = torch.tensor(np.array([m.data for m in AtA_csr])).cuda()
-    AtA_num_rows = AtA_rowPtr.size(0) - 1
+    AtA_num_rows = AtA_row_ptr.size(0) - 1
     AtA_num_cols = AtA_num_rows
-    AtA_nnz = AtA_colInd.size(0)  # noqa: F841
 
     if verbose:
         print("AtA[0]:\n", AtA_csr[0].todense())
 
-    slv = CusolverLUSolver(init_batch_size, AtA_num_cols, AtA_rowPtr, AtA_colInd)
+    slv = CusolverLUSolver(init_batch_size, AtA_num_cols, AtA_row_ptr, AtA_col_ind)
     singularities = slv.factor(AtA_val)
 
     if verbose:
diff --git a/tests/extlib/test_mat_mult.py b/tests/extlib/test_mat_mult.py
index 9be25afb7..e86e6c282 100644
--- a/tests/extlib/test_mat_mult.py
+++ b/tests/extlib/test_mat_mult.py
@@ -8,7 +8,7 @@
 import torch  # needed for import of Torch C++ extensions to work
 from scipy.sparse import csr_matrix
 
-from theseus.utils import random_sparse_binary_matrix
+from theseus.utils import random_sparse_matrix
 
 
 def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
@@ -18,19 +18,15 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
     rng = torch.Generator()
     rng.manual_seed(0)
 
-    A_skel = random_sparse_binary_matrix(
-        num_rows, num_cols, fill, min_entries_per_col=3, rng=rng
+    A_col_ind, A_row_ptr, A_val, _ = random_sparse_matrix(
+        batch_size, num_rows, num_cols, fill, 3, rng, "cuda:0", int_dtype=torch.int
     )
     A_num_cols = num_cols
-    A_rowPtr = torch.tensor(A_skel.indptr, dtype=torch.int).cuda()
-    A_colInd = torch.tensor(A_skel.indices, dtype=torch.int).cuda()
-    A_num_rows = A_rowPtr.size(0) - 1
-    A_nnz = A_colInd.size(0)
-    A_val = torch.rand((batch_size, A_nnz), dtype=torch.double).cuda()
+    A_num_rows = A_row_ptr.size(0) - 1
 
     A_csr = [
         csr_matrix(
-            (A_val[i].cpu(), A_colInd.cpu(), A_rowPtr.cpu()), (A_num_rows, A_num_cols)
+            (A_val[i].cpu(), A_col_ind.cpu(), A_row_ptr.cpu()), (A_num_rows, A_num_cols)
         )
         for i in range(batch_size)
     ]
@@ -39,21 +35,21 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
 
     # test At * A
     AtA_csr = [(a.T @ a).tocsr() for a in A_csr]
-    AtA_rowPtr = torch.tensor(AtA_csr[0].indptr).cuda()
-    AtA_colInd = torch.tensor(AtA_csr[0].indices).cuda()
+    AtA_row_ptr = torch.tensor(AtA_csr[0].indptr).cuda()
+    AtA_col_ind = torch.tensor(AtA_csr[0].indices).cuda()
     AtA_val = torch.tensor(np.array([m.data for m in AtA_csr])).cuda()
-    AtA_num_rows = AtA_rowPtr.size(0) - 1
+    AtA_num_rows = AtA_row_ptr.size(0) - 1
     AtA_num_cols = AtA_num_rows
 
     if verbose:
         print("\nAtA[0]:\n", AtA_csr[0].todense())
 
-    res = mult_MtM(batch_size, A_rowPtr, A_colInd, A_val, AtA_rowPtr, AtA_colInd)
+    res = mult_MtM(batch_size, A_row_ptr, A_col_ind, A_val, AtA_row_ptr, AtA_col_ind)
     if verbose:
         print(
             "res[0]:\n",
             csr_matrix(
-                (res[0].cpu(), AtA_colInd.cpu(), AtA_rowPtr.cpu()),
+                (res[0].cpu(), AtA_col_ind.cpu(), AtA_row_ptr.cpu()),
                 (AtA_num_rows, AtA_num_cols),
             ).todense(),
         )
@@ -65,7 +61,7 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
         np.array(
             [
                 csr_matrix(
-                    (res[x].cpu(), AtA_colInd.cpu(), AtA_rowPtr.cpu()),
+                    (res[x].cpu(), AtA_col_ind.cpu(), AtA_row_ptr.cpu()),
                     (AtA_num_rows, AtA_num_cols),
                 ).diagonal()
                 for x in range(batch_size)
             ]
         )
     )
     alpha = 0.3 * torch.rand(batch_size, dtype=torch.double).cuda()
     beta = 0.7 * torch.rand(batch_size, dtype=torch.double).cuda()
-    apply_damping(batch_size, AtA_num_cols, AtA_rowPtr, AtA_colInd, res, alpha, beta)
+    apply_damping(batch_size, AtA_num_cols, AtA_row_ptr, AtA_col_ind, res, alpha, beta)
 
     new_diagonals = torch.tensor(
         np.array(
             [
                csr_matrix(
-                    (res[x].cpu(), AtA_colInd.cpu(), AtA_rowPtr.cpu()),
+                    (res[x].cpu(), AtA_col_ind.cpu(), AtA_row_ptr.cpu()),
                     (AtA_num_rows, AtA_num_cols),
                 ).diagonal()
                 for x in range(batch_size)
             ]
@@ -97,7 +93,7 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
         np.array([A_csr[i] @ v[i].cpu() for i in range(batch_size)])
     ).cuda()
 
-    A_v_test = mat_vec(batch_size, A_num_cols, A_rowPtr, A_colInd, A_val, v)
+    A_v_test = mat_vec(batch_size, A_num_cols, A_row_ptr, A_col_ind, A_val, v)
 
     if verbose:
         print("A_v:", A_v)
@@ -111,7 +107,7 @@ def check_mat_mult(batch_size, num_rows, num_cols, fill, verbose=False):
         np.array([A_csr[i].T @ w[i].cpu() for i in range(batch_size)])
     ).cuda()
 
-    At_w_test = tmat_vec(batch_size, A_num_cols, A_rowPtr, A_colInd, A_val, w)
+    At_w_test = tmat_vec(batch_size, A_num_cols, A_row_ptr, A_col_ind, A_val, w)
 
     if verbose:
         print("At_w:", At_w)
diff --git a/tests/optimizer/autograd/test_baspacho_sparse_backward.py b/tests/optimizer/autograd/test_baspacho_sparse_backward.py
index 6b6034ba0..88071678d 100644
--- a/tests/optimizer/autograd/test_baspacho_sparse_backward.py
+++ b/tests/optimizer/autograd/test_baspacho_sparse_backward.py
@@ -81,8 +81,8 @@ def check_sparse_backward_step(
         linearization.A_val,
         linearization.b,
         linearization.structure(),
-        solver.A_rowPtr,
-        solver.A_colInd,
+        solver.A_row_ptr,
+        solver.A_col_ind,
         solver.symbolic_decomposition,
         damping_alpha_beta,
     )
diff --git a/tests/optimizer/autograd/test_lu_cuda_sparse_backward.py b/tests/optimizer/autograd/test_lu_cuda_sparse_backward.py
index 8103114b1..f44ff034b 100644
--- a/tests/optimizer/autograd/test_lu_cuda_sparse_backward.py
+++ b/tests/optimizer/autograd/test_lu_cuda_sparse_backward.py
@@ -61,8 +61,8 @@ def test_sparse_backward_step():
         linearization.A_val,
         linearization.b,
         linearization.structure(),
-        solver.A_rowPtr,
-        solver.A_colInd,
+        solver.A_row_ptr,
+        solver.A_col_ind,
         solver._solver_contexts[solver._last_solver_context],
         damping_alpha_beta,
         False,  # it's the same matrix, so no overwrite problems
diff --git a/theseus/optimizer/autograd/baspacho_sparse_autograd.py b/theseus/optimizer/autograd/baspacho_sparse_autograd.py
index b5315ada5..80cc12ed3 100644
--- a/theseus/optimizer/autograd/baspacho_sparse_autograd.py
+++ b/theseus/optimizer/autograd/baspacho_sparse_autograd.py
@@ -2,37 +2,44 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Tuple, Optional
+from typing import Any, Tuple, Optional
 
 import torch
 
 from ..linear_system import SparseStructure
 from theseus.utils.sparse_matrix_utils import mat_vec, tmat_vec
 
+_BaspachoSolveFunctionBwdReturnType = Tuple[
+    torch.Tensor, torch.Tensor, None, None, None, None, None
+]
+
 
 class BaspachoSolveFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, *args, **kwargs):
+    def forward(  # type: ignore
+        ctx: Any,
+        A_val: torch.Tensor,
+        b: torch.Tensor,
+        sparse_structure: SparseStructure,
+        A_row_ptr: torch.Tensor,
+        A_col_ind: torch.Tensor,
+        symbolic_decomposition: Any,  # actually SymbolicDecomposition
+        damping_alpha_beta: Optional[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> torch.Tensor:
         from theseus.extlib.baspacho_solver import SymbolicDecomposition
 
-        A_val: torch.Tensor = args[0]
-        b: torch.Tensor = args[1]
-        sparse_structure: SparseStructure = args[2]
-        A_rowPtr: torch.Tensor = args[3]
-        A_colInd: torch.Tensor = args[4]
-        symbolic_decomposition: SymbolicDecomposition = args[5]
-        damping_alpha_beta: Optional[Tuple[torch.Tensor, torch.Tensor]] = args[6]
+        assert isinstance(symbolic_decomposition, SymbolicDecomposition)
 
         batch_size = A_val.shape[0]
 
         numeric_decomposition = symbolic_decomposition.create_numeric_decomposition(
             batch_size
         )
-        numeric_decomposition.add_MtM(A_val, A_rowPtr, A_colInd)
+        numeric_decomposition.add_MtM(A_val, A_row_ptr, A_col_ind)
         if damping_alpha_beta is not None:
             numeric_decomposition.damp(*damping_alpha_beta)
         numeric_decomposition.factor()
 
-        A_args = sparse_structure.num_cols, A_rowPtr, A_colInd, A_val
+        A_args = sparse_structure.num_cols, A_row_ptr, A_col_ind, A_val
         Atb = tmat_vec(batch_size, *A_args, b)
 
         x = Atb.clone()
@@ -41,8 +48,8 @@ def forward(ctx, *args, **kwargs):
         ctx.b = b
         ctx.x = x
         ctx.A_val = A_val
-        ctx.A_rowPtr = A_rowPtr
-        ctx.A_colInd = A_colInd
+        ctx.A_row_ptr = A_row_ptr
+        ctx.A_col_ind = A_col_ind
         ctx.sparse_structure = sparse_structure
         ctx.numeric_decomposition = numeric_decomposition
         ctx.damping_alpha_beta = damping_alpha_beta
@@ -99,14 +106,15 @@ def forward(ctx, *args, **kwargs):
     # 2 times the scalar product of A's and (A')'s j-th column. Therefore
     # (A')'s j-th column is multiplying A's j-th column by 2*H[j]*alpha*x[j]
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(  # type: ignore
+        ctx: Any, grad_output: torch.Tensor
+    ) -> _BaspachoSolveFunctionBwdReturnType:
         batch_size = grad_output.shape[0]
 
-        targs = {"dtype": grad_output.dtype, "device": grad_output.device}
-
         H = grad_output.clone()
         ctx.numeric_decomposition.solve(H)  # solve in place
 
-        A_args = ctx.sparse_structure.num_cols, ctx.A_rowPtr, ctx.A_colInd, ctx.A_val
+        A_args = ctx.sparse_structure.num_cols, ctx.A_row_ptr, ctx.A_col_ind, ctx.A_val
         AH = mat_vec(batch_size, *A_args, H)
         b_Ax = ctx.b - mat_vec(batch_size, *A_args, ctx.x)
 
@@ -114,15 +122,17 @@ def backward(ctx, grad_output):
         # selected entries from the difference of tensor products:
         #   b_Ax (X) H - AH (X) x
         # NOTE: this row-wise manipulation can be much faster in C++ or Cython
-        A_colInd = ctx.sparse_structure.col_ind
-        A_rowPtr = ctx.sparse_structure.row_ptr
+        A_col_ind = ctx.sparse_structure.col_ind
+        A_row_ptr = ctx.sparse_structure.row_ptr
         batch_size = grad_output.shape[0]
         A_grad = torch.empty(
-            size=(batch_size, len(A_colInd)), **targs
+            size=(batch_size, len(A_col_ind)),
+            dtype=grad_output.dtype,
+            device=grad_output.device,
         )  # return value, A's grad
-        for r in range(len(A_rowPtr) - 1):
-            start, end = A_rowPtr[r], A_rowPtr[r + 1]
-            columns = A_colInd[start:end]  # col indices, for this row
+        for r in range(len(A_row_ptr) - 1):
+            start, end = A_row_ptr[r], A_row_ptr[r + 1]
+            columns = A_col_ind[start:end]  # col indices, for this row
             A_grad[:, start:end] = (
                 b_Ax[:, r].unsqueeze(1) * H[:, columns]
                 - AH[:, r].unsqueeze(1) * ctx.x[:, columns]
             )
@@ -135,6 +145,6 @@ def backward(ctx, grad_output):
         ):
             alpha = ctx.damping_alpha_beta[0].view(-1, 1)
             alpha2Hx = (alpha * 2.0) * H * ctx.x  # componentwise product
-            A_grad -= ctx.A_val * alpha2Hx[:, ctx.A_colInd.type(torch.long)]
+            A_grad -= ctx.A_val * alpha2Hx[:, ctx.A_col_ind.type(torch.long)]
 
         return A_grad, AH, None, None, None, None, None
diff --git a/theseus/optimizer/autograd/cholmod_sparse_autograd.py b/theseus/optimizer/autograd/cholmod_sparse_autograd.py
index 22bfc9958..f5089fddf 100644
--- a/theseus/optimizer/autograd/cholmod_sparse_autograd.py
+++ b/theseus/optimizer/autograd/cholmod_sparse_autograd.py
@@ -2,27 +2,34 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from typing import Any, Tuple
 
 import torch
 from sksparse.cholmod import Factor as CholeskyDecomposition
 
 from ..linear_system import SparseStructure
 
+_CholmodSolveFunctionBwdReturnType = Tuple[torch.Tensor, torch.Tensor, None, None, None]
+
 
 class CholmodSolveFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, *args, **kwargs):
-        At_val: torch.Tensor = args[0]
-        b: torch.Tensor = args[1]
-        sparse_structure: SparseStructure = args[2]
-        symbolic_decomposition: CholeskyDecomposition = args[3]
-        damping: float = args[4]
-
+    def forward(  # type: ignore
+        ctx: Any,
+        At_val: torch.Tensor,
+        b: torch.Tensor,
+        sparse_structure: SparseStructure,
+        symbolic_decomposition: CholeskyDecomposition,
+        damping: float,
+    ) -> torch.Tensor:
         At_val_cpu = At_val.cpu().double()
         b_cpu = b.cpu().double()
         batch_size = At_val.shape[0]
-        targs = {"dtype": At_val.dtype, "device": "cpu"}
-        x_cpu = torch.empty(size=(batch_size, sparse_structure.num_cols), **targs)
+        x_cpu = torch.empty(
+            size=(batch_size, sparse_structure.num_cols),
+            dtype=At_val.dtype,
+            device="cpu",
+        )
 
         cholesky_decompositions = []
 
         for i in range(batch_size):
@@ -88,12 +95,17 @@ def forward(ctx, *args, **kwargs):
 
     # NOTE: in the torch docs, backward is also marked as "staticmethod", which makes sense
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(  # type: ignore
+        ctx: Any, grad_output: torch.Tensor
+    ) -> _CholmodSolveFunctionBwdReturnType:
         batch_size = grad_output.shape[0]
 
-        targs = {"dtype": grad_output.dtype, "device": "cpu"}  # grad_output.device}
-        H = torch.empty(size=(batch_size, ctx.sparse_structure.num_cols), **targs)
-        AH = torch.empty(size=(batch_size, ctx.sparse_structure.num_rows), **targs)
+        H = torch.empty(
+            size=(batch_size, ctx.sparse_structure.num_cols), dtype=grad_output.dtype
+        )
+        AH = torch.empty(
+            size=(batch_size, ctx.sparse_structure.num_rows), dtype=grad_output.dtype
+        )
         b_Ax = ctx.b_cpu.clone()
         grad_output_cpu = grad_output.cpu()
 
@@ -115,8 +127,7 @@ def backward(ctx, grad_output):
         A_row_ptr = ctx.sparse_structure.row_ptr
         batch_size = grad_output.shape[0]
         A_grad = torch.empty(
-            size=(batch_size, len(A_col_ind)),
-            device="cpu",
+            size=(batch_size, len(A_col_ind)), dtype=grad_output.dtype
         )  # return value, A's grad
         for r in range(len(A_row_ptr) - 1):
             start, end = A_row_ptr[r], A_row_ptr[r + 1]
diff --git a/theseus/optimizer/autograd/lu_cuda_sparse_autograd.py b/theseus/optimizer/autograd/lu_cuda_sparse_autograd.py
index 11341dbef..8ceff2afd 100644
--- a/theseus/optimizer/autograd/lu_cuda_sparse_autograd.py
+++ b/theseus/optimizer/autograd/lu_cuda_sparse_autograd.py
@@ -2,15 +2,29 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 
 from ..linear_system import SparseStructure
 
+_LUCudaSolveFunctionBwdReturnType = Tuple[
+    torch.Tensor, torch.Tensor, None, None, None, None, None, None
+]
+
 
 class LUCudaSolveFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, *args, **kwargs):
+    def forward(  # type: ignore
+        ctx: Any,
+        A_val: torch.Tensor,
+        b: torch.Tensor,
+        sparse_structure: SparseStructure,
+        A_row_ptr: torch.Tensor,
+        A_col_ind: torch.Tensor,
+        solver_context: Any,  # actually CusolverLUSolver
+        damping_alpha_beta: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        check_factor_id: bool,
+    ) -> torch.Tensor:
         if not torch.cuda.is_available():
             raise RuntimeError("Cuda not available, LUCudaSolveFunction cannot be used")
 
@@ -25,27 +39,22 @@ def forward(ctx, *args, **kwargs):
                 f"{type(e).__name__}: {e}"
             )
 
-        A_val: torch.Tensor = args[0]
-        b: torch.Tensor = args[1]
-        sparse_structure: SparseStructure = args[2]
-        A_rowPtr: torch.Tensor = args[3]
-        A_colInd: torch.Tensor = args[4]
-        solver_context: CusolverLUSolver = args[5]
-        damping_alpha_beta: Optional[Tuple[torch.Tensor, torch.Tensor]] = args[6]
-        check_factor_id: bool = args[7]
+        assert isinstance(solver_context, CusolverLUSolver)
 
-        AtA_rowPtr = solver_context.A_rowPtr
-        AtA_colInd = solver_context.A_colInd
+        AtA_row_ptr = solver_context.A_rowPtr
+        AtA_col_ind = solver_context.A_colInd
 
         batch_size = A_val.shape[0]
 
-        AtA = mult_MtM(batch_size, A_rowPtr, A_colInd, A_val, AtA_rowPtr, AtA_colInd)
+        AtA = mult_MtM(
+            batch_size, A_row_ptr, A_col_ind, A_val, AtA_row_ptr, AtA_col_ind
+        )
         if damping_alpha_beta is not None:
-            AtA_args = sparse_structure.num_cols, AtA_rowPtr, AtA_colInd, AtA
+            AtA_args = sparse_structure.num_cols, AtA_row_ptr, AtA_col_ind, AtA
             apply_damping(batch_size, *AtA_args, *damping_alpha_beta)
         solver_context.factor(AtA)
 
-        A_args = sparse_structure.num_cols, A_rowPtr, A_colInd, A_val
+        A_args = sparse_structure.num_cols, A_row_ptr, A_col_ind, A_val
         Atb = tmat_vec(batch_size, *A_args, b)
 
         x = Atb.clone()
         solver_context.solve(x)  # solve in place
@@ -53,8 +62,8 @@ def forward(ctx, *args, **kwargs):
         ctx.b = b
         ctx.x = x
         ctx.A_val = A_val
-        ctx.A_rowPtr = A_rowPtr
-        ctx.A_colInd = A_colInd
+        ctx.A_row_ptr = A_row_ptr
+        ctx.A_col_ind = A_col_ind
         ctx.sparse_structure = sparse_structure
         ctx.solver_context = solver_context
         ctx.damping_alpha_beta = damping_alpha_beta
@@ -114,7 +123,9 @@ def forward(ctx, *args, **kwargs):
     # 2 times the scalar product of A's and (A')'s j-th column. Therefore
     # (A')'s j-th column is multiplying A's j-th column by 2*H[j]*alpha*x[j]
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(  # type: ignore
+        ctx, grad_output: torch.Tensor
+    ) -> _LUCudaSolveFunctionBwdReturnType:
         if not torch.cuda.is_available():
             raise RuntimeError("Cuda not available, LUCudaSolveFunction cannot be used")
 
@@ -136,12 +147,11 @@ def backward(ctx, grad_output):
         )
 
         batch_size = grad_output.shape[0]
 
-        targs = {"dtype": grad_output.dtype, "device": "cuda"}  # grad_output.device}
-
         H = grad_output.clone()
         ctx.solver_context.solve(H)  # solve in place
 
-        A_args = ctx.sparse_structure.num_cols, ctx.A_rowPtr, ctx.A_colInd, ctx.A_val
+        A_args = ctx.sparse_structure.num_cols, ctx.A_row_ptr, ctx.A_col_ind, ctx.A_val
         AH = mat_vec(batch_size, *A_args, H)
         b_Ax = ctx.b - mat_vec(batch_size, *A_args, ctx.x)
 
@@ -149,15 +159,15 @@ def backward(ctx, grad_output):
         # selected entries from the difference of tensor products:
         #   b_Ax (X) H - AH (X) x
         # NOTE: this row-wise manipulation can be much faster in C++ or Cython
-        A_colInd = ctx.sparse_structure.col_ind
-        A_rowPtr = ctx.sparse_structure.row_ptr
+        A_col_ind = ctx.sparse_structure.col_ind
+        A_row_ptr = ctx.sparse_structure.row_ptr
         batch_size = grad_output.shape[0]
         A_grad = torch.empty(
-            size=(batch_size, len(A_colInd)), **targs
+            size=(batch_size, len(A_col_ind)), dtype=grad_output.dtype, device="cuda"
         )  # return value, A's grad
-        for r in range(len(A_rowPtr) - 1):
-            start, end = A_rowPtr[r], A_rowPtr[r + 1]
-            columns = A_colInd[start:end]  # col indices, for this row
+        for r in range(len(A_row_ptr) - 1):
+            start, end = A_row_ptr[r], A_row_ptr[r + 1]
+            columns = A_col_ind[start:end]  # col indices, for this row
             A_grad[:, start:end] = (
                 b_Ax[:, r].unsqueeze(1) * H[:, columns]
                 - AH[:, r].unsqueeze(1) * ctx.x[:, columns]
             )
@@ -170,6 +180,6 @@ def backward(ctx, grad_output):
         ):
             alpha = ctx.damping_alpha_beta[0].view(-1, 1)
             alpha2Hx = (alpha * 2.0) * H * ctx.x  # componentwise product
-            A_grad -= ctx.A_val * alpha2Hx[:, ctx.A_colInd.type(torch.long)]
+            A_grad -= ctx.A_val * alpha2Hx[:, ctx.A_col_ind.type(torch.long)]
 
         return A_grad, AH, None, None, None, None, None, None
diff --git a/theseus/optimizer/linear/baspacho_sparse_solver.py b/theseus/optimizer/linear/baspacho_sparse_solver.py
index 57ed3f4d7..d854b8603 100644
--- a/theseus/optimizer/linear/baspacho_sparse_solver.py
+++ b/theseus/optimizer/linear/baspacho_sparse_solver.py
@@ -7,6 +7,7 @@
 
 import torch
 
+from theseus.constants import DeviceType
 from theseus.core import Objective
 from theseus.optimizer import Linearization, SparseLinearization
 from theseus.optimizer.autograd import BaspachoSolveFunction
@@ -26,8 +27,8 @@ def __init__(
         objective: Objective,
         linearization_cls: Optional[Type[Linearization]] = None,
         linearization_kwargs: Optional[Dict[str, Any]] = None,
-        num_solver_contexts=1,
-        dev=DEFAULT_DEVICE,
+        num_solver_contexts: int = 1,
+        dev: DeviceType = DEFAULT_DEVICE,
         **kwargs,
     ):
         linearization_cls = linearization_cls or SparseLinearization
@@ -45,7 +46,7 @@ def __init__(
         if self.linearization.structure().num_rows:
             self.reset(dev)
 
-    def reset(self, dev=DEFAULT_DEVICE):
+    def reset(self, dev: DeviceType = DEFAULT_DEVICE):
         if dev == "cuda" and not torch.cuda.is_available():
             raise RuntimeError(
                 "BaspachoSparseSolver: Cuda requested (dev='cuda') but not\n"
@@ -63,10 +64,10 @@ def reset(self, dev=DEFAULT_DEVICE):
         )
 
         # convert to tensors for accelerated Mt x M operation
-        self.A_rowPtr = torch.tensor(
+        self.A_row_ptr = torch.tensor(
             self.linearization.structure().row_ptr, dtype=torch.int64
         ).to(dev)
-        self.A_colInd = torch.tensor(
+        self.A_col_ind = torch.tensor(
             self.linearization.structure().col_ind, dtype=torch.int64
         ).to(dev)
 
@@ -124,8 +125,8 @@ def solve(
             self.linearization.A_val,
             self.linearization.b,
             self.linearization.structure(),
-            self.A_rowPtr,
-            self.A_colInd,
+            self.A_row_ptr,
+            self.A_col_ind,
             self.symbolic_decomposition,
             damping_alpha_beta,
         )
diff --git a/theseus/optimizer/linear/lu_cuda_sparse_solver.py b/theseus/optimizer/linear/lu_cuda_sparse_solver.py
index 20381d806..5c6871967 100644
--- a/theseus/optimizer/linear/lu_cuda_sparse_solver.py
+++ b/theseus/optimizer/linear/lu_cuda_sparse_solver.py
@@ -21,7 +21,7 @@ def __init__(
         objective: Objective,
         linearization_cls: Optional[Type[Linearization]] = None,
         linearization_kwargs: Optional[Dict[str, Any]] = None,
-        num_solver_contexts=1,
+        num_solver_contexts: int = 1,
         batch_size: Optional[int] = None,
         auto_reset: bool = True,
         **kwargs,
@@ -64,10 +64,10 @@ def reset(self, batch_size: int = 16):
             f"{type(e).__name__}: {e}"
         )
 
-        self.A_rowPtr = torch.tensor(
+        self.A_row_ptr = torch.tensor(
             self.linearization.structure().row_ptr, dtype=torch.int32
         ).cuda()
-        self.A_colInd = torch.tensor(
+        self.A_col_ind = torch.tensor(
             self.linearization.structure().col_ind, dtype=torch.int32
         ).cuda()
         At_mock = self.linearization.structure().mock_csc_transpose()
@@ -76,14 +76,14 @@ def reset(self, batch_size: int = 16):
         # symbolic decomposition depending on the sparse structure, done with mock data
         # HACK: we generate several contexts, since with cublas the symbolic_decomposition
         # is also a context for factorization, and the two cannot be separated
-        AtA_rowPtr = torch.tensor(AtA_mock.indptr, dtype=torch.int32).cuda()
-        AtA_colInd = torch.tensor(AtA_mock.indices, dtype=torch.int32).cuda()
+        AtA_row_ptr = torch.tensor(AtA_mock.indptr, dtype=torch.int32).cuda()
+        AtA_col_ind = torch.tensor(AtA_mock.indices, dtype=torch.int32).cuda()
         self._solver_contexts: List[CusolverLUSolver] = [
             CusolverLUSolver(
                 batch_size,
                 AtA_mock.shape[1],
-                AtA_rowPtr,
-                AtA_colInd,
+                AtA_row_ptr,
+                AtA_col_ind,
             )
             for _ in range(self._num_solver_contexts)
         ]
@@ -124,8 +124,8 @@ def solve(
             self.linearization.A_val,
             self.linearization.b,
             self.linearization.structure(),
-            self.A_rowPtr,
-            self.A_colInd,
+            self.A_row_ptr,
+            self.A_col_ind,
             self._solver_contexts[self._last_solver_context],
             damping_alpha_beta,
             True,
         )
diff --git a/theseus/optimizer/linear_system.py b/theseus/optimizer/linear_system.py
index fd8d1d09c..ff738ed1c 100644
--- a/theseus/optimizer/linear_system.py
+++ b/theseus/optimizer/linear_system.py
@@ -6,6 +6,7 @@
 import abc
 
 import numpy as np
+import torch
 from scipy.sparse import csc_matrix, csr_matrix
 
 
@@ -24,21 +25,21 @@ def __init__(
         self.num_cols = num_cols
         self.dtype = dtype
 
-    def csr_straight(self, val):
+    def csr_straight(self, val: torch.Tensor) -> csr_matrix:
         return csr_matrix(
             (val, self.col_ind, self.row_ptr),
             (self.num_rows, self.num_cols),
             dtype=self.dtype,
         )
 
-    def csc_transpose(self, val):
+    def csc_transpose(self, val: torch.Tensor) -> csc_matrix:
         return csc_matrix(
             (val, self.col_ind, self.row_ptr),
             (self.num_cols, self.num_rows),
             dtype=self.dtype,
         )
 
-    def mock_csc_transpose(self):
+    def mock_csc_transpose(self) -> csc_matrix:
         return csc_matrix(
             (np.ones(len(self.col_ind), dtype=self.dtype), self.col_ind, self.row_ptr),
             (self.num_cols, self.num_rows),
diff --git a/theseus/optimizer/sparse_linearization.py b/theseus/optimizer/sparse_linearization.py
index 8b0ba1fab..a634f12f9 100644
--- a/theseus/optimizer/sparse_linearization.py
+++ b/theseus/optimizer/sparse_linearization.py
@@ -146,17 +146,17 @@ def _ata_impl(self) -> torch.Tensor:
 
     def _atb_impl(self) -> torch.Tensor:
         if self._Atb is None:
-            A_rowPtr = torch.tensor(self.A_row_ptr, dtype=torch.int32).to(
+            A_row_ptr = torch.tensor(self.A_row_ptr, dtype=torch.int32).to(
                 self.objective.device
             )
-            A_colInd = A_rowPtr.new_tensor(self.A_col_ind)
+            A_col_ind = A_row_ptr.new_tensor(self.A_col_ind)
 
             # unsqueeze at the end for consistency with DenseLinearization
             self._Atb = tmat_vec(
                 self.objective.batch_size,
                 self.num_cols,
-                A_rowPtr,
-                A_colInd,
+                A_row_ptr,
+                A_col_ind,
                 self.A_val,
                 self.b,
             ).unsqueeze(2)
diff --git a/theseus/utils/__init__.py b/theseus/utils/__init__.py
index 1529090ae..7180b8616 100644
--- a/theseus/utils/__init__.py
+++ b/theseus/utils/__init__.py
@@ -4,9 +4,10 @@
 # LICENSE file in the root directory of this source tree.
 
 from .sparse_matrix_utils import (
-    tmat_vec,
     mat_vec,
+    random_sparse_matrix,
     random_sparse_binary_matrix,
     split_into_param_sizes,
+    tmat_vec,
 )
 from .utils import build_mlp, gather_from_rows_cols, numeric_jacobian
diff --git a/theseus/utils/sparse_matrix_utils.py b/theseus/utils/sparse_matrix_utils.py
index 6bd58046c..1ad510193 100644
--- a/theseus/utils/sparse_matrix_utils.py
+++ b/theseus/utils/sparse_matrix_utils.py
@@ -2,19 +2,26 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import List
+from typing import List, Tuple
 
 import numpy as np
 import torch
 from scipy.sparse import csc_matrix, csr_matrix, lil_matrix
 
 
-def _mat_vec_cpu(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v):
+def _mat_vec_cpu(
+    batch_size: int,
+    num_cols: int,
+    A_row_ptr: torch.Tensor,
+    A_col_ind: torch.Tensor,
+    A_val: torch.Tensor,
+    v: torch.Tensor,
+) -> torch.Tensor:
     assert batch_size == A_val.shape[0]
-    num_rows = len(A_rowPtr) - 1
+    num_rows = len(A_row_ptr) - 1
     retv_data = np.array(
         [
-            csr_matrix((A_val[i].numpy(), A_colInd, A_rowPtr), (num_rows, num_cols))
+            csr_matrix((A_val[i].numpy(), A_col_ind, A_row_ptr), (num_rows, num_cols))
             * v[i]
             for i in range(batch_size)
         ],
@@ -23,8 +30,15 @@ def _mat_vec_cpu(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v):
     return torch.tensor(retv_data, dtype=torch.float64)
 
 
-def mat_vec(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v):
-    if A_rowPtr.device.type == "cuda":
+def mat_vec(
+    batch_size: int,
+    num_cols: int,
+    A_row_ptr: torch.Tensor,
+    A_col_ind: torch.Tensor,
+    A_val: torch.Tensor,
+    v: torch.Tensor,
+) -> torch.Tensor:
+    if A_row_ptr.device.type == "cuda":
         try:
             from theseus.extlib.mat_mult import mat_vec as mat_vec_cuda
         except Exception as e:
@@ -34,17 +48,24 @@ def mat_vec(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v):
             "is installed with Cuda support (export CUDA_HOME=...)\n"
             f"{type(e).__name__}: {e}"
         )
-        return mat_vec_cuda(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v)
+        return mat_vec_cuda(batch_size, num_cols, A_row_ptr, A_col_ind, A_val, v)
     else:
-        return _mat_vec_cpu(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v)
+        return _mat_vec_cpu(batch_size, num_cols, A_row_ptr, A_col_ind, A_val, v)
 
 
-def _tmat_vec_cpu(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v):
+def _tmat_vec_cpu(
+    batch_size: int,
+    num_cols: int,
+    A_row_ptr: torch.Tensor,
+    A_col_ind: torch.Tensor,
+    A_val: torch.Tensor,
+    v: torch.Tensor,
+) -> torch.Tensor:
     assert batch_size == A_val.shape[0]
-    num_rows = len(A_rowPtr) - 1
+    num_rows = len(A_row_ptr) - 1
     retv_data = np.array(
         [
-            csc_matrix((A_val[i].numpy(), A_colInd, A_rowPtr), (num_cols, num_rows))
+            csc_matrix((A_val[i].numpy(), A_col_ind, A_row_ptr), (num_cols, num_rows))
             * v[i]
             for i in range(batch_size)
         ],
@@ -53,8 +74,15 @@ def _tmat_vec_cpu(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v):
     return torch.tensor(retv_data, dtype=torch.float64)
 
 
-def tmat_vec(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v):
-    if A_rowPtr.device.type == "cuda":
+def tmat_vec(
+    batch_size: int,
+    num_cols: int,
+    A_row_ptr: torch.Tensor,
+    A_col_ind: torch.Tensor,
+    A_val: torch.Tensor,
+    v: torch.Tensor,
+) -> torch.Tensor:
+    if A_row_ptr.device.type == "cuda":
         try:
             from theseus.extlib.mat_mult import tmat_vec as tmat_vec_cuda
         except Exception as e:
@@ -64,21 +92,25 @@ def tmat_vec(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v):
             "is installed with Cuda support (export CUDA_HOME=...)\n"
             f"{type(e).__name__}: {e}"
         )
-        return tmat_vec_cuda(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v)
+        return tmat_vec_cuda(batch_size, num_cols, A_row_ptr, A_col_ind, A_val, v)
     else:
-        return _tmat_vec_cpu(batch_size, num_cols, A_rowPtr, A_colInd, A_val, v)
+        return _tmat_vec_cpu(batch_size, num_cols, A_row_ptr, A_col_ind, A_val, v)
 
 
 def random_sparse_binary_matrix(
-    rows: int, cols: int, fill: float, min_entries_per_col: int, rng: torch.Generator
+    num_rows: int,
+    num_cols: int,
+    fill: float,
+    min_entries_per_col: int,
+    rng: torch.Generator,
 ) -> csr_matrix:
-    retv = lil_matrix((rows, cols))
+    retv = lil_matrix((num_rows, num_cols))
     if min_entries_per_col > 0:
-        min_entries_per_col = min(rows, min_entries_per_col)
-        rows_array = torch.arange(rows, device=rng.device)
+        min_entries_per_col = min(num_rows, min_entries_per_col)
+        rows_array = torch.arange(num_rows, device=rng.device)
         rows_array_f = rows_array.to(dtype=torch.float)
-        for c in range(cols):
+        for c in range(num_cols):
             row_selection = rows_array[
                 rows_array_f.multinomial(min_entries_per_col, generator=rng)
             ].cpu()
@@ -86,17 +118,45 @@ def random_sparse_binary_matrix(
                 retv[r, c] = 1.0
 
     # make sure last row is non-empty, so: len(indptr) = rows+1
-    retv[rows - 1, int(torch.randint(cols, (), device=rng.device, generator=rng))] = 1.0
+    retv[
+        num_rows - 1, int(torch.randint(num_cols, (), device=rng.device, generator=rng))
+    ] = 1.0
 
-    num_entries = int(fill * rows * cols)
+    num_entries = int(fill * num_rows * num_cols)
     while retv.getnnz() < num_entries:
-        col = int(torch.randint(cols, (), device=rng.device, generator=rng))
-        row = int(torch.randint(rows, (), device=rng.device, generator=rng))
+        col = int(torch.randint(num_cols, (), device=rng.device, generator=rng))
+        row = int(torch.randint(num_rows, (), device=rng.device, generator=rng))
         retv[row, col] = 1.0
 
     return retv.tocsr()
 
 
+def random_sparse_matrix(
+    batch_size: int,
+    num_rows: int,
+    num_cols: int,
+    fill: float,
+    min_entries_per_col: int,
+    rng: torch.Generator,
+    device: torch.device,
+    int_dtype: torch.dtype = torch.int64,
+    float_dtype: torch.dtype = torch.double,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    A_skel = random_sparse_binary_matrix(
+        num_rows, num_cols, fill, min_entries_per_col=min_entries_per_col, rng=rng
+    )
+    A_row_ptr = torch.tensor(A_skel.indptr, dtype=int_dtype).to(device)
+    A_col_ind = torch.tensor(A_skel.indices, dtype=int_dtype).to(device)
+    A_val = torch.rand(
+        batch_size,
+        A_col_ind.size(0),
+        device=rng.device,
+        dtype=float_dtype,
+        generator=rng,
+    ).to(device)
+    return A_col_ind, A_row_ptr, A_val, A_skel
+
+
 def split_into_param_sizes(
     n: int, param_size_range_min: int, param_size_range_max: int, rng: torch.Generator
 ) -> List[int]:
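
Usage note (editorial, not part of the patch): the sketch below illustrates the
new random_sparse_matrix helper together with the renamed snake_case
mat_vec/tmat_vec signatures, mirroring the test setup in the patch. It assumes
a CPU-only install, so the _cpu kernels are exercised; on CUDA the same calls
dispatch to the compiled extension. Sizes and the seed are arbitrary.

import torch
from scipy.sparse import csr_matrix

from theseus.utils import mat_vec, random_sparse_matrix, tmat_vec

batch_size, num_rows, num_cols, fill = 4, 20, 10, 0.3
rng = torch.Generator()
rng.manual_seed(0)

# Returns batched CSR data: column indices, row pointers, per-batch values,
# plus the scipy binary skeleton the batch was sampled from.
A_col_ind, A_row_ptr, A_val, A_skel = random_sparse_matrix(
    batch_size, num_rows, num_cols, fill, 1, rng, "cpu"
)

# Batched A @ v and A.T @ w.
v = torch.rand(batch_size, num_cols, dtype=torch.double, generator=rng)
w = torch.rand(batch_size, num_rows, dtype=torch.double, generator=rng)
A_v = mat_vec(batch_size, num_cols, A_row_ptr, A_col_ind, A_val, v)
At_w = tmat_vec(batch_size, num_cols, A_row_ptr, A_col_ind, A_val, w)

# Cross-check batch element 0 against scipy, as the extlib tests do.
A0 = csr_matrix(
    (A_val[0].numpy(), A_col_ind.numpy(), A_row_ptr.numpy()), (num_rows, num_cols)
)
assert torch.allclose(A_v[0], torch.from_numpy(A0 @ v[0].numpy()))
assert torch.allclose(At_w[0], torch.from_numpy(A0.T @ w[0].numpy()))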