Skip to content

Commit 1eb4e52

Browse files
anandoleecopybara-github
authored andcommitted
Python Proto scalar repeated numpy binding.
Add the __array__ magic method to scalar repeated fields. The default dtype returned by numpy/tensorflow calls such as np.array(), np.asarray() and tf.reshape() will change:

Cpp_Type | Old np.array() dtype | Old tf.reshape() dtype | New dtype
float | float64/float64 | float32/float32 | float32
double | float64/float64 | float32/float32 | float64
int32 | int64/float64 | int32/float32 | int32
int64 | int64/float64 | int32/float32 | int64
uint32 | int64/float64 | int32/float32 | uint32
uint64 | int64/float64 | int32/float32 | uint64
bool | bool/float64 | bool/float32 | bool
enum | int64/float64 | int32/float32 | int32
string | U/float64 | string/float32 | U (Unicode string)
bytes | S/float64 | string/float32 | S (Byte string)

Performance improvement:
For non-string repeated fields:
- upb/cpp are about 50-100 times faster.
- pure python is ~3x faster.
For string repeated fields: the numbers are within noise; no regression (maybe a slight win).

PiperOrigin-RevId: 836264273
1 parent 787a221 commit 1eb4e52

File tree

5 files changed

+822
-4
lines changed

5 files changed

+822
-4
lines changed

python/google/protobuf/internal/containers.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
_K = TypeVar('_K')
4141
_V = TypeVar('_V')
4242

43+
from google.protobuf.descriptor import FieldDescriptor
4344

4445
class BaseContainer(Sequence[_T]):
4546
"""Base container class."""
@@ -104,12 +105,13 @@ class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
104105
"""Simple, type-checked, list-like container for holding repeated scalars."""
105106

106107
# Disallows assignment to other attributes.
107-
__slots__ = ['_type_checker']
108+
__slots__ = ['_type_checker', '_field']
108109

109110
def __init__(
110111
self,
111112
message_listener: Any,
112113
type_checker: Any,
114+
field: Any = None,
113115
) -> None:
114116
"""Args:
115117
@@ -121,6 +123,7 @@ def __init__(
121123
"""
122124
super().__init__(message_listener)
123125
self._type_checker = type_checker
126+
self._field = field
124127

125128
def append(self, value: _T) -> None:
126129
"""Appends an item to the list. Similar to list.append()."""
@@ -202,14 +205,47 @@ def __deepcopy__(
202205
unused_memo: Any = None,
203206
) -> 'RepeatedScalarFieldContainer[_T]':
204207
clone = RepeatedScalarFieldContainer(
205-
copy.deepcopy(self._message_listener), self._type_checker)
208+
copy.deepcopy(self._message_listener), self._type_checker, self._field
209+
)
206210
clone.MergeFrom(self)
207211
return clone
208212

209213
def __reduce__(self, **kwargs) -> NoReturn:
  """Rejects pickling of this container.

  Raises:
    pickle.PickleError: Always; the message tells the caller to convert the
      container to a list first.
  """
  reason = "Can't pickle repeated scalar fields, convert to list first"
  raise pickle.PickleError(reason)
212216

217+
def __array__(self, dtype=None, copy=None):
  """Returns the container's values as a newly allocated NumPy array.

  Implements NumPy's ``__array__`` protocol, so ``np.array(container)`` and
  ``np.asarray(container)`` take their dtype from the field's declared type
  rather than from the Python values.

  Args:
    dtype: Optional NumPy dtype for the result. When None, the dtype is
      looked up from ``self._field``. If the container has no field
      descriptor (``self._field`` is None), NumPy infers the dtype from the
      values.
    copy: Accepted so the signature matches the NumPy 2 ``__array__``
      protocol. It is not consulted: the result is always a fresh copy.

  Returns:
    A ``numpy.ndarray`` holding a copy of ``self._values``.

  Raises:
    SystemError: If the field's type has no scalar NumPy equivalent (i.e. it
      is not a numeric, bool, enum, string or bytes field).
  """
  # Local import: numpy is only required when this method is actually called.
  import numpy as np

  field = self._field
  # A container built without a descriptor (the constructor default) has
  # nothing to map from, so leave dtype as None and let NumPy infer it.
  if dtype is None and field is not None:
    dtype_by_cpp_type = {
        FieldDescriptor.CPPTYPE_INT32: np.int32,
        FieldDescriptor.CPPTYPE_INT64: np.int64,
        FieldDescriptor.CPPTYPE_UINT32: np.uint32,
        FieldDescriptor.CPPTYPE_UINT64: np.uint64,
        FieldDescriptor.CPPTYPE_DOUBLE: np.float64,
        FieldDescriptor.CPPTYPE_FLOAT: np.float32,
        # np.bool_ exists in every NumPy release; the plain np.bool alias is
        # absent from NumPy 1.24-1.26.
        FieldDescriptor.CPPTYPE_BOOL: np.bool_,
        FieldDescriptor.CPPTYPE_ENUM: np.int32,
    }
    dtype = dtype_by_cpp_type.get(field.cpp_type)
    if dtype is None:
      # Not a numeric/bool/enum cpp_type: tell bytes and string apart by the
      # declared field type.
      if field.type == FieldDescriptor.TYPE_BYTES:
        dtype = 'S'
      elif field.type == FieldDescriptor.TYPE_STRING:
        dtype = str
      else:
        raise SystemError(
            'Code should never reach here: message type detected in'
            ' RepeatedScalarFieldContainer'
        )
  return np.array(self._values, dtype=dtype, copy=True)
248+
213249

214250
# TODO: Constrain T to be a subtype of Message.
215251
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):

0 commit comments

Comments
 (0)