Skip to content

Commit

Permalink
Show file tree
Hide file tree
Showing 14 changed files with 429 additions and 338 deletions.
503 changes: 299 additions & 204 deletions aten/src/ATen/native/TensorConversions.cpp

Large diffs are not rendered by default.

36 changes: 30 additions & 6 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6997,51 +6997,75 @@

- func: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
variants: method

# Special case of to_sparse.sparse_dim with custom derivative
- func: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
variants: method
dispatch:
CPU, CUDA: dense_to_sparse
SparseCPU, SparseCUDA: sparse_coo_to_sparse
SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse
autogen: to_sparse.sparse_dim_out
autogen: _to_sparse.sparse_dim_out

- func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
variants: method

# Special case of to_sparse with custom derivative
- func: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
variants: method
dispatch:
CPU, CUDA: dense_to_sparse
SparseCPU, SparseCUDA: sparse_coo_to_sparse
SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse
autogen: to_sparse.out
autogen: _to_sparse.out

- func: to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
variants: method

# Special case of to_sparse_csr with custom derivative
- func: _to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
variants: method
dispatch:
CPU, CUDA: dense_to_sparse_csr
SparseCPU, SparseCUDA: coo_to_sparse_csr
SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse_csr
autogen: to_sparse_csr.out
autogen: _to_sparse_csr.out

- func: to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
variants: method

# Special case of to_sparse_csc with custom derivative
- func: _to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
variants: method
dispatch:
CPU, CUDA: dense_to_sparse_csc
SparseCPU, SparseCUDA: coo_to_sparse_csc
SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse_csc
autogen: to_sparse_csc.out
autogen: _to_sparse_csc.out

- func: to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
variants: method

# Special case of to_sparse_bsr with custom derivative
- func: _to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
variants: method
dispatch:
CPU, CUDA: dense_to_sparse_bsr
SparseCPU, SparseCUDA: coo_to_sparse_bsr
SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse_bsr
autogen: to_sparse_bsr.out
autogen: _to_sparse_bsr.out

- func: to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
variants: method

# Special case of to_sparse_bsc with custom derivative
- func: _to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
variants: method
dispatch:
CPU, CUDA: dense_to_sparse_bsc
SparseCPU, SparseCUDA: coo_to_sparse_bsc
SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse_bsc
autogen: to_sparse_bsc.out
autogen: _to_sparse_bsc.out

- func: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor
variants: method
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/sparse/SparseCsrTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ Tensor select_sparse_csr_worker(const Tensor& self, int64_t dim, int64_t index)
// Selecting sparse dimension
TORCH_CHECK(
n_batch == 0,
select_name, ": selecting sparse dimensions is not implemented for batched sparse compressed tensors.")
select_name, ": selecting sparse dimensions is not supported for batched sparse compressed tensors.")
TORCH_INTERNAL_ASSERT(dim == 0 || dim == 1);

DimVector blocksize{1, 1};
Expand Down
88 changes: 0 additions & 88 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,94 +522,6 @@ const SparseTensor& resize_as_sparse_(const SparseTensor& self, const SparseTens
return self;
}

SparseTensor dense_to_sparse(const Tensor& self, c10::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim_opt) {
if (layout.has_value()) {
if (blocksize.has_value() && !(*layout == kSparseBsr || *layout == kSparseBsc)) {
AT_ERROR("to_sparse for ", self.layout(), " to ", *layout,
" conversion does not use the specified blocksize ", blocksize.value(), ".");
}
if (self.layout() == *layout) {
return self;
}
switch (*layout) {
case kStrided:
return self;
case kSparse:
return dense_to_sparse(self, self.dim() - dense_dim_opt.value_or(0));
case kSparseCsr:
return self.to_sparse_csr(dense_dim_opt);
case kSparseCsc:
return self.to_sparse_csc(dense_dim_opt);
case kSparseBsr:
if (blocksize.has_value()) {
return self.to_sparse_bsr(*blocksize, dense_dim_opt);
}
AT_ERROR("to_sparse for ", self.layout(), " to ", *layout, " conversion requires blocksize");
break;
case kSparseBsc:
if (blocksize.has_value()) {
return self.to_sparse_bsc(*blocksize, dense_dim_opt);
}
break;
AT_ERROR("to_sparse for ", self.layout(), " to ", *layout, " conversion requires blocksize");
default:
break;
}
AT_ERROR("to_sparse not implemented for ", self.layout(), " to ", *layout, " conversion");
}
return dense_to_sparse(self, self.dim() - dense_dim_opt.value_or(0));
}

SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim) {
int64_t dims = self.dim();
// TODO: it seems like sparse_dim == 0 could be supported even if self.dim() >
// 0, but this would take some work and doesn't seem particularly useful.
TORCH_CHECK(
sparse_dim > 0 || self.dim() == 0,
"sparse_dim must be >0 if dimensionality > 0");
TORCH_CHECK(
sparse_dim <= dims,
"sparse_dim must be less than or equal to self.dim()");
at::TensorOptions sparse_options = self.options().layout(kSparse);
std::vector<int64_t> sizes = self.sizes().vec();

Tensor nz = self.nonzero().transpose(0, 1);
if (nz.size(1) == 0) {
auto sparse = new_with_dims_sparse(
sparse_dim,
dims - sparse_dim,
sizes,
optTypeMetaToScalarType(sparse_options.dtype_opt()),
sparse_options.layout_opt(),
sparse_options.device_opt(),
sparse_options.pinned_memory_opt());
return sparse._coalesced_(true);
}
Tensor indices;
if (sparse_dim == dims) {
indices = nz.clone();
} else {
Tensor i = nz.narrow(0, 0, sparse_dim);
std::tie(indices, std::ignore, std::ignore) = unique_dim(i, 1);
indices = indices.contiguous(); // many sparse CUDA kernels require
// contiguity, see issue #12633
}

Tensor values;
if (self.dim() > 0) {
auto ix = toListOfOptionalTensors(indices.chunk(indices.size(0), 0));
values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve);
} else {
AT_ASSERT(nz.sizes().equals({0, 1}));
// In this cases, indices is a clone of nz, which is a tensor of shape (0,
// 1). Given sparse tensor invariants, values should be shape (1,)
values = self.unsqueeze(0).clone(at::MemoryFormat::Preserve);
}

Tensor sparse = at::sparse_coo_tensor(indices, values, sizes, sparse_options);
return sparse._coalesced_(true);
}

// NB: Dropped the resizeNd variants

SparseTensor& copy_sparse_wrapper_(
Expand Down
24 changes: 12 additions & 12 deletions test/expect/HasDecompTest.test_has_decomposition.expect
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,18 @@ aten::_thnn_fused_lstm_cell_backward_impl
aten::_thnn_fused_lstm_cell_backward_impl.out
aten::_to_dense
aten::_to_dense.out
aten::_to_sparse
aten::_to_sparse.out
aten::_to_sparse.sparse_dim
aten::_to_sparse.sparse_dim_out
aten::_to_sparse_bsc
aten::_to_sparse_bsc.out
aten::_to_sparse_bsr
aten::_to_sparse_bsr.out
aten::_to_sparse_csc
aten::_to_sparse_csc.out
aten::_to_sparse_csr
aten::_to_sparse_csr.out
aten::_transform_bias_rescale_qkv
aten::_transform_bias_rescale_qkv.out
aten::_transformer_encoder_layer_fwd
Expand Down Expand Up @@ -1253,18 +1265,6 @@ aten::to_mkldnn
aten::to_mkldnn.out
aten::to_padded_tensor
aten::to_padded_tensor.out
aten::to_sparse
aten::to_sparse.out
aten::to_sparse.sparse_dim
aten::to_sparse.sparse_dim_out
aten::to_sparse_bsc
aten::to_sparse_bsc.out
aten::to_sparse_bsr
aten::to_sparse_bsr.out
aten::to_sparse_csc
aten::to_sparse_csc.out
aten::to_sparse_csr
aten::to_sparse_csr.out
aten::topk
aten::topk.values
aten::transpose_
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,12 @@
("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
("aten::all_reduce", datetime.date(9999, 1, 30)),
("aten::to_sparse.out", datetime.date(2023, 12, 31)),
("aten::to_sparse.sparse_dim_out", datetime.date(2023, 12, 31)),
("aten::to_sparse_bsc.out", datetime.date(2023, 12, 31)),
("aten::to_sparse_bsr.out", datetime.date(2023, 12, 31)),
("aten::to_sparse_csc.out", datetime.date(2023, 12, 31)),
("aten::to_sparse_csr.out", datetime.date(2023, 12, 31)),
("aten::_structured_sparse_linear", datetime.date(2023, 7, 1)),

]
Expand Down
4 changes: 2 additions & 2 deletions test/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,8 +836,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.ormqr.default : {c64, c128, f64, f32},
aten.ormqr.out : {c64, c128, f64, f32},
aten.tensordot.out : {c64, i8, f64, c128, i64, bf16, f32, i32, i16, u8},
aten.to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten._to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten._to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten._ctc_loss.default : {f32, f64}, # Shape of second output depends on data.
aten._ctc_loss.Tensor : {f32, f64}, # Shape of second output depends on data.
aten._histogramdd_bin_edges.default : {f32, f64},
Expand Down
2 changes: 1 addition & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1970,7 +1970,7 @@ def test_refs_are_in_decomp_table(self, op):
"nn.functional.embedding_bag", # sometimes errors
"nn.functional.nll_loss", # sometimes errors
"nn.functional.max_pool1d", # The tensor has a non-zero number of elements
"to_sparse", # Could not run 'aten::to_sparse' with arguments from the 'Meta' backend
"to_sparse", # Could not run 'aten::_to_sparse' with arguments from the 'Meta' backend
"tensor_split", # The tensor has a non-zero number of elements, but its data is not allocated yet
"repeat_interleave", # cannot repeat_interleave a meta tensor without output_size
"_segment_reduce.lengths", # Could not run 'aten::segment_reduce' with arguments from the 'Meta' backend.
Expand Down
38 changes: 30 additions & 8 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4519,20 +4519,24 @@ def explicit_to_sparse(x):
# not implemented conversions
if from_layout in {
torch.sparse_csr, torch.sparse_csc} and to_layout in {torch.sparse_bsr, torch.sparse_bsc} and is_batch:
with self.assertRaisesRegex(RuntimeError,
r"conversion from (Csr|Csc) to (Bsr|Bsc) for batched inputs is not implemented"):
with self.assertRaisesRegex(
RuntimeError,
r"conversion from Sparse(Csr|Csc) to Sparse(Bsr|Bsc) for batched inputs is not supported"):
t.to_sparse(layout=to_layout, blocksize=blocksize)
with self.assertRaisesRegex(RuntimeError,
r"conversion from (Csr|Csc) to (Bsr|Bsc) for batched inputs is not implemented"):
with self.assertRaisesRegex(
RuntimeError,
r"conversion from Sparse(Csr|Csc) to Sparse(Bsr|Bsc) for batched inputs is not supported"):
explicit_to_sparse(t)
continue
elif from_layout is torch.sparse_coo and to_layout in {
torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and t.sparse_dim() != 2:
with self.assertRaisesRegex(
RuntimeError, "Only tensors with two sparse dimensions can be converted to the Sparse(Csr|Csc) layout"):
RuntimeError,
r"conversion from Sparse to .* for input tensors with sparse_dim\(\)!=2 is not supported"):
t.to_sparse(layout=to_layout, blocksize=blocksize)
with self.assertRaisesRegex(
RuntimeError, "Only tensors with two sparse dimensions can be converted to the Sparse(Csr|Csc) layout"):
RuntimeError,
r"conversion from Sparse to .* for input tensors with sparse_dim\(\)!=2 is not supported"):
explicit_to_sparse(t)
continue
elif from_layout in {torch.sparse_csr, torch.sparse_csc,
Expand All @@ -4548,12 +4552,12 @@ def explicit_to_sparse(x):
(torch.sparse_bsr, torch.sparse_csr), (torch.sparse_bsr, torch.sparse_csc)}:
with self.assertRaisesRegex(
RuntimeError,
r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc) expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)"
r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc): expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)"
" or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"):
t.to_sparse(layout=to_layout, blocksize=blocksize)
with self.assertRaisesRegex(
RuntimeError,
r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc) expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)"
r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc): expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)"
" or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"):
explicit_to_sparse(t)
self.skipTest('NOT IMPL')
Expand Down Expand Up @@ -4883,6 +4887,24 @@ def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activatio
k = 2 ** k * 128
run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation)

@onlyCPU
@all_sparse_layouts('layout', include_strided=True)
@dtypes(torch.double)
def test_to_sparse_identity(self, device, layout, dtype):
for dense_dim in range(4):
x_dense = torch.eye(dense_dim, dtype=dtype, device=device)
for sparse_dim_in in range(1, dense_dim):
x_sparse = x_dense.to_sparse(sparse_dim_in)
for sparse_dim_out in range(0, dense_dim):
if sparse_dim_out == sparse_dim_in:
self.assertTrue(x_sparse.to_sparse(sparse_dim_out).sparse_dim() == sparse_dim_out)
else:
with self.assertRaisesRegex(
RuntimeError,
r"to_sparse: conversion from Sparse to Sparse with sparse_dim argument !=self.sparse_dim\(\)"
" is not supported"):
x_sparse.to_sparse(sparse_dim_out)


@onlyNativeDeviceTypes
@suppress_warnings
Expand Down
10 changes: 6 additions & 4 deletions test/test_sparse_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ def is_view_of(base, other):
elif n_batchdim and dim >= n_batchdim and dim < n_batchdim + 2:
with self.assertRaisesRegex(
RuntimeError,
"selecting sparse dimensions is not implemented for batched sparse compressed tensors"):
"selecting sparse dimensions is not supported for batched sparse compressed tensors"):
torch.select_copy(sparse, dim, 0)
else:
for index in {0, sparse.shape[dim] // 2, sparse.shape[dim] - 1}:
Expand Down Expand Up @@ -1043,7 +1043,7 @@ def test_select(self, device, dtype, index_dtype, layout):
sparse[0, 0, 0, 0] = 99.0

# select from sparse dimensions without removing batch dims
msg = "selecting sparse dimensions is not implemented for batched sparse compressed tensors."
msg = "selecting sparse dimensions is not supported for batched sparse compressed tensors."
with self.assertRaisesRegex(RuntimeError, msg):
sparse.select(-2, 0)

Expand Down Expand Up @@ -1368,7 +1368,8 @@ def test_csr_to_block_csr_errors(self, device, dtype):
t = self.genSparseCSRTensor((16, 16), nnz, dtype=dtype,
device=device, index_dtype=index_dtype)

with self.assertRaisesRegex(RuntimeError, r"size \(16, 16\) with block size \(5, 5\)"):
with self.assertRaisesRegex(RuntimeError,
r"tensor sparse size \(.*,.*\) must be divisible by given blocksize \(.*,.*\)"):
block_t = t.to_sparse_bsr((5, 5))

# TODO: Support auto generation of device check for sparse tensors
Expand Down Expand Up @@ -3105,7 +3106,8 @@ def _to_from_layout(layout_a, layout_b, a):
# change of blocksize upon conversion is not yet supported.
if b.layout in block_layouts:
for block_layout in block_layouts:
with self.assertRaisesRegex(RuntimeError, "conversion from.*to.*is not implemented"):
with self.assertRaisesRegex(RuntimeError,
"conversion from.*to.*with blocksize changed from.*to.*is not supported"):
b.to_sparse(layout=block_layout, blocksize=(3, 3))

batch_dims = [(), (2,), (2, 2), (2, 2, 2)]
Expand Down
Loading

0 comments on commit 09fdea8

Please sign in to comment.