Skip to content

Commit

Permalink
Merge pull request scipy#13586 from WarrenWeckesser/coo-default-dtype
Browse files Browse the repository at this point in the history
BUG: sparse: Create a utility function `getdata`
  • Loading branch information
tylerjereddy authored Feb 27, 2021
2 parents 1d296b6 + 99bc23f commit f508446
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 9 deletions.
7 changes: 3 additions & 4 deletions scipy/sparse/bsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from .data import _data_matrix, _minmax_mixin
from .compressed import _cs_matrix
from .base import isspmatrix, _formats, spmatrix
from .sputils import (isshape, getdtype, to_native, upcast, get_index_dtype,
check_shape)
from .sputils import (isshape, getdtype, getdata, to_native, upcast,
get_index_dtype, check_shape)
from . import _sparsetools
from ._sparsetools import (bsr_matvec, bsr_matvecs, csr_matmat_maxnnz,
bsr_matmat, bsr_transpose, bsr_sort_indices,
Expand Down Expand Up @@ -175,8 +175,7 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False, blocksize=None):
check_contents=True)
self.indices = np.array(indices, copy=copy, dtype=idx_dtype)
self.indptr = np.array(indptr, copy=copy, dtype=idx_dtype)
self.data = np.array(data, copy=copy,
dtype=getdtype(dtype, data, float))
self.data = getdata(data, copy=copy, dtype=dtype)
if self.data.ndim != 3:
raise ValueError(
'BSR data must be 3-dimensional, got shape=%s' % (
Expand Down
7 changes: 3 additions & 4 deletions scipy/sparse/coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from .base import isspmatrix, SparseEfficiencyWarning, spmatrix
from .data import _data_matrix, _minmax_mixin
from .sputils import (upcast, upcast_char, to_native, isshape, getdtype,
get_index_dtype, downcast_intp_index, check_shape,
check_reshape_kwargs, matrix)
getdata, get_index_dtype, downcast_intp_index,
check_shape, check_reshape_kwargs, matrix)

import operator

Expand Down Expand Up @@ -155,10 +155,9 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False):
self._shape = check_shape((M, N))

idx_dtype = get_index_dtype(maxval=max(self.shape))
data_dtype = getdtype(dtype, obj, default=float)
self.row = np.array(row, copy=copy, dtype=idx_dtype)
self.col = np.array(col, copy=copy, dtype=idx_dtype)
self.data = np.array(obj, copy=copy, dtype=data_dtype)
self.data = getdata(obj, copy=copy, dtype=dtype)
self.has_canonical_format = False
else:
if isspmatrix(arg1):
Expand Down
14 changes: 13 additions & 1 deletion scipy/sparse/sputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from scipy._lib._util import prod

__all__ = ['upcast', 'getdtype', 'isscalarlike', 'isintlike',
__all__ = ['upcast', 'getdtype', 'getdata', 'isscalarlike', 'isintlike',
'isshape', 'issequence', 'isdense', 'ismatrix', 'get_sum_dtype']

supported_dtypes = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc,
Expand Down Expand Up @@ -116,6 +116,18 @@ def getdtype(dtype, a=None, default=None):
return newdtype


def getdata(obj, dtype=None, copy=False):
"""
This is a wrapper of `np.array(obj, dtype=dtype, copy=copy)`
that will generate a warning if the result is an object array.
"""
data = np.array(obj, dtype=dtype, copy=copy)
# Defer to getdtype for checking that the dtype is OK.
# This is called for the validation only; we don't need the return value.
getdtype(data.dtype)
return data


def get_index_dtype(arrays=(), maxval=None, check_contents=False):
"""
Based on input (integer) arrays `a`, determine a suitable index data
Expand Down
13 changes: 13 additions & 0 deletions scipy/sparse/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4145,6 +4145,11 @@ def test_constructor4(self):
with pytest.raises(ValueError, match=r'inconsistent shapes'):
coo_matrix([0, 11, 22, 33], shape=(4, 4))

def test_constructor_data_ij_dtypeNone(self):
data = [1]
coo = coo_matrix((data, ([0], [0])), dtype=None)
assert coo.dtype == np.array(data).dtype

@pytest.mark.xfail(run=False, reason='COO does not have a __getitem__')
def test_iterator(self):
pass
Expand Down Expand Up @@ -4339,6 +4344,14 @@ def test_constructor5(self):
# mismatching blocksize
bsr_matrix((data, indices, indptr), blocksize=(1, 1))

def test_default_dtype(self):
# As a numpy array, `values` has shape (2, 2, 1).
values = [[[1], [1]], [[1], [1]]]
indptr = np.array([0, 2], dtype=np.int32)
indices = np.array([0, 1], dtype=np.int32)
b = bsr_matrix((values, indices, indptr), blocksize=(2, 1))
assert b.dtype == np.array(values).dtype

def test_bsr_tocsr(self):
# check native conversion from BSR to CSR
indptr = array([0, 2, 2, 4])
Expand Down

0 comments on commit f508446

Please sign in to comment.