-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cusolver based batched LU solver (#22)
* cublas-based sparse LU solver class * update cuda installs in ci * add test to ci * add C++ extensions to gitignore Co-authored-by: Maurizio Monge <[email protected]>
- Loading branch information
Showing
10 changed files
with
605 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
|
||
#include "cusolver_sp_defs.h" | ||
#include <ATen/cuda/Exceptions.h> | ||
#include <ATen/cuda/detail/DeviceThreadHandles.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
|
||
// functions are defined in this headers are inline so this can be included multiple times | ||
// in units compiled independently (such as Torch extensions formed by one .cu/.cpp file) | ||
namespace theseus::cusolver_sp { | ||
|
||
const char* cusolverGetErrorMessage(cusolverStatus_t status) { | ||
switch (status) { | ||
case CUSOLVER_STATUS_SUCCESS: return "CUSOLVER_STATUS_SUCCES"; | ||
case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUSOLVER_STATUS_NOT_INITIALIZED"; | ||
case CUSOLVER_STATUS_ALLOC_FAILED: return "CUSOLVER_STATUS_ALLOC_FAILED"; | ||
case CUSOLVER_STATUS_INVALID_VALUE: return "CUSOLVER_STATUS_INVALID_VALUE"; | ||
case CUSOLVER_STATUS_ARCH_MISMATCH: return "CUSOLVER_STATUS_ARCH_MISMATCH"; | ||
case CUSOLVER_STATUS_EXECUTION_FAILED: return "CUSOLVER_STATUS_EXECUTION_FAILED"; | ||
case CUSOLVER_STATUS_INTERNAL_ERROR: return "CUSOLVER_STATUS_INTERNAL_ERROR"; | ||
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; | ||
default: return "Unknown cusolver error number"; | ||
} | ||
} | ||
|
||
void createCusolverSpHandle(cusolverSpHandle_t *handle) { | ||
CUSOLVER_CHECK(cusolverSpCreate(handle)); | ||
} | ||
|
||
// The switch below look weird, but we will be adopting the same policy as for CusolverDn handle in Torch source | ||
void destroyCusolverSpHandle(cusolverSpHandle_t handle) { | ||
// this is because of something dumb in the ordering of | ||
// destruction. Sometimes atexit, the cuda context (or something) | ||
// would already be destroyed by the time this gets destroyed. It | ||
// happens in fbcode setting. @colesbury and @soumith decided to not destroy | ||
// the handle as a workaround. | ||
// - Comments of @soumith copied from cuDNN handle pool implementation | ||
#ifdef NO_CUDNN_DESTROY_HANDLE | ||
#else | ||
cusolverSpDestroy(handle); | ||
#endif | ||
} | ||
|
||
using CuSolverSpPoolType = at::cuda::DeviceThreadHandlePool<cusolverSpHandle_t, createCusolverSpHandle, destroyCusolverSpHandle>; | ||
|
||
cusolverSpHandle_t getCurrentCUDASolverSpHandle() { | ||
int device; | ||
AT_CUDA_CHECK(cudaGetDevice(&device)); | ||
|
||
// Thread local PoolWindows are lazily-initialized | ||
// to avoid initialization issues that caused hangs on Windows. | ||
// See: https://github.com/pytorch/pytorch/pull/22405 | ||
// This thread local unique_ptrs will be destroyed when the thread terminates, | ||
// releasing its reserved handles back to the pool. | ||
static auto pool = std::make_shared<CuSolverSpPoolType>(); | ||
thread_local std::unique_ptr<CuSolverSpPoolType::PoolWindow> myPoolWindow(pool->newPoolWindow()); | ||
|
||
auto handle = myPoolWindow->reserve(device); | ||
auto stream = c10::cuda::getCurrentCUDAStream(); | ||
CUSOLVER_CHECK(cusolverSpSetStream(handle, stream)); | ||
return handle; | ||
} | ||
|
||
} // namespace theseus::cusolver_sp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
|
||
#include <cusolverSp.h> | ||
|
||
#define CUSOLVER_CHECK(EXPR) \ | ||
do { \ | ||
cusolverStatus_t __err = EXPR; \ | ||
TORCH_CHECK(__err == CUSOLVER_STATUS_SUCCESS, \ | ||
"cusolver error: ", \ | ||
theseus::cusolver_sp::cusolverGetErrorMessage(__err), \ | ||
", when calling `" #EXPR "`"); \ | ||
} while (0) | ||
|
||
namespace theseus::cusolver_sp { | ||
|
||
const char* cusolverGetErrorMessage(cusolverStatus_t status); | ||
|
||
cusolverSpHandle_t getCurrentCUDASolverSpHandle(); | ||
|
||
} // namespace theseus::cusolver_sp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import numpy as np | ||
import pytest # noqa: F401 | ||
import torch # needed for import of Torch C++ extensions to work | ||
from scipy.sparse import csr_matrix | ||
|
||
from theseus.utils import random_sparse_binary_matrix | ||
|
||
|
||
# ideally we would like to support batch_size <= init_batch_size, but | ||
# because of limitations of cublas those have to be always identical | ||
def check_lu_solver( | ||
init_batch_size, batch_size, num_rows, num_cols, fill, verbose=False | ||
): | ||
# this is necessary assumption, so that the hessian is full rank | ||
assert num_rows >= num_cols | ||
|
||
if not torch.cuda.is_available(): | ||
return | ||
from theseus.extlib.cusolver_lu_solver import CusolverLUSolver | ||
|
||
A_skel = random_sparse_binary_matrix( | ||
num_rows, num_cols, fill, min_entries_per_col=3 | ||
) | ||
A_num_cols = num_cols | ||
A_rowPtr = torch.tensor(A_skel.indptr, dtype=torch.int).cuda() | ||
A_colInd = torch.tensor(A_skel.indices, dtype=torch.int).cuda() | ||
A_num_rows = A_rowPtr.size(0) - 1 | ||
A_nnz = A_colInd.size(0) | ||
A_val = torch.rand((batch_size, A_nnz), dtype=torch.double).cuda() | ||
b = torch.rand((batch_size, A_num_rows), dtype=torch.double).cuda() | ||
|
||
A_csr = [ | ||
csr_matrix( | ||
(A_val[i].cpu(), A_colInd.cpu(), A_rowPtr.cpu()), (A_num_rows, A_num_cols) | ||
) | ||
for i in range(batch_size) | ||
] | ||
if verbose: | ||
print("A[0]:\n", A_csr[0].todense()) | ||
print("b[0]:\n", b[0]) | ||
|
||
AtA_csr = [(a.T @ a).tocsr() for a in A_csr] | ||
AtA_rowPtr = torch.tensor(AtA_csr[0].indptr).cuda() | ||
AtA_colInd = torch.tensor(AtA_csr[0].indices).cuda() | ||
AtA_val = torch.tensor(np.array([m.data for m in AtA_csr])).cuda() | ||
AtA_num_rows = AtA_rowPtr.size(0) - 1 | ||
AtA_num_cols = AtA_num_rows | ||
AtA_nnz = AtA_colInd.size(0) # noqa: F841 | ||
|
||
if verbose: | ||
print("AtA[0]:\n", AtA_csr[0].todense()) | ||
|
||
slv = CusolverLUSolver(init_batch_size, AtA_num_cols, AtA_rowPtr, AtA_colInd) | ||
singularities = slv.factor(AtA_val) | ||
|
||
if verbose: | ||
print("singularities:", singularities) | ||
|
||
b = torch.rand((batch_size, A_num_rows), dtype=torch.double).cuda() | ||
Atb = torch.tensor( | ||
np.array([A_csr[i].T @ b[i].cpu().numpy() for i in range(batch_size)]) | ||
).cuda() | ||
if verbose: | ||
print("Atb[0]:", Atb[0]) | ||
|
||
sol = Atb.clone() | ||
slv.solve(sol) | ||
if verbose: | ||
print("x[0]:", sol[0]) | ||
|
||
residuals = [ | ||
AtA_csr[i] @ sol[i].cpu().numpy() - Atb[i].cpu().numpy() | ||
for i in range(batch_size) | ||
] | ||
if verbose: | ||
print("residual[0]:", residuals[0]) | ||
|
||
assert all(np.linalg.norm(res) < 1e-10 for res in residuals) | ||
|
||
|
||
def test_lu_solver_1(): | ||
check_lu_solver(init_batch_size=5, batch_size=5, num_rows=50, num_cols=30, fill=0.2) | ||
|
||
|
||
def test_lu_solver_2(): | ||
check_lu_solver( | ||
init_batch_size=5, batch_size=5, num_rows=150, num_cols=60, fill=0.2 | ||
) | ||
|
||
|
||
def test_lu_solver_3(): | ||
check_lu_solver( | ||
init_batch_size=10, batch_size=10, num_rows=300, num_cols=90, fill=0.2 | ||
) | ||
|
||
|
||
def test_lu_solver_4(): | ||
check_lu_solver(init_batch_size=5, batch_size=5, num_rows=50, num_cols=30, fill=0.1) | ||
|
||
|
||
def test_lu_solver_5(): | ||
check_lu_solver( | ||
init_batch_size=5, batch_size=5, num_rows=150, num_cols=60, fill=0.1 | ||
) | ||
|
||
|
||
def test_lu_solver_6(): | ||
check_lu_solver( | ||
init_batch_size=10, batch_size=10, num_rows=300, num_cols=90, fill=0.1 | ||
) | ||
|
||
|
||
# would like to test when irregular batch_size < init_batch_size, | ||
# but this is currently not supported by cublas, maybe in the future | ||
# def test_lu_solver_7(): | ||
# check_lu_solver(init_batch_size=10, batch_size=5, num_rows=150, num_cols=60, fill=0.2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import numpy as np | ||
from scipy.sparse import csr_matrix, lil_matrix | ||
|
||
|
||
def random_sparse_binary_matrix(rows, cols, fill, min_entries_per_col) -> csr_matrix: | ||
retv = lil_matrix((rows, cols)) | ||
|
||
if min_entries_per_col > 0: | ||
min_entries_per_col = min(rows, min_entries_per_col) | ||
rows_array = np.arange(rows) | ||
for c in range(cols): | ||
for r in np.random.choice(rows_array, min_entries_per_col): | ||
retv[r, c] = 1.0 | ||
|
||
num_entries = int(fill * rows * cols) | ||
while retv.getnnz() < num_entries: | ||
col = np.random.randint(cols) | ||
row = np.random.randint(rows) | ||
retv[row, col] = 1.0 | ||
|
||
return retv.tocsr() |