Skip to content

Commit

Permalink
improved tensorflow's getitem handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Aug 14, 2020
1 parent 947450e commit ac9b7f1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
24 changes: 21 additions & 3 deletions eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,27 @@ def __getitem__(self: TensorType, index: Any) -> TensorType:
)
if not basic:
# workaround for missing support for this in TensorFlow
# TODO: maybe convert each index individually and then stack them instead
index = tf.convert_to_tensor(index)
index = tf.transpose(index)
index = [tf.convert_to_tensor(x) for x in index]
shapes = [tuple(x.shape) for x in index]
shape = tuple(max(x) for x in zip(*shapes))
int64 = any(x.dtype == tf.int64 for x in index)
for i in range(len(index)):
t = index[i]
if int64:
t = tf.cast(t, tf.int64)
assert t.ndim == len(shape)
tiling = []
for b, k in zip(shape, t.shape):
if k == 1:
tiling.append(b)
elif k == b:
tiling.append(1)
else:
raise ValueError(
f"{tuple(t.shape)} cannot be broadcasted to {shape}"
)
index[i] = tf.tile(t, tiling)
index = tf.stack(index, axis=-1)
return type(self)(tf.gather_nd(self.raw, index))
elif (
isinstance(index, range)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,22 @@ def test_getitem_tuple_tensors(dummy: Tensor) -> Tensor:
return t[rows, indices]


@compare_all
def test_getitem_tuple_tensors_full(dummy: Tensor) -> Tensor:
t = ep.arange(dummy, 32).float32().reshape((8, 4))
rows = ep.arange(t, len(t))[:, np.newaxis].tile((1, t.shape[-1]))
cols = t.argsort(axis=-1)
return t[rows, cols]


@compare_all
def test_getitem_tuple_tensors_full_broadcast(dummy: Tensor) -> Tensor:
t = ep.arange(dummy, 32).float32().reshape((8, 4))
rows = ep.arange(t, len(t))[:, np.newaxis]
cols = t.argsort(axis=-1)
return t[rows, cols]


@compare_all
def test_getitem_tuple_range_tensor(dummy: Tensor) -> Tensor:
t = ep.arange(dummy, 32).float32().reshape((8, 4))
Expand Down

0 comments on commit ac9b7f1

Please sign in to comment.