Skip to content

Commit

Permalink
fixed problems introduced through type annotations and added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Jan 30, 2020
1 parent 93f1e9d commit bde85ee
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 20 deletions.
6 changes: 6 additions & 0 deletions eagerpy/tensor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ def wrapper(self, *args, **kwargs):
return wrapper


def unwrap_(*args):
"""Unwraps all EagerPy tensors if they are not already unwrapped"""
result = tuple(t.tensor if istensor(t) else t for t in args)
return result[0] if len(args) == 1 else result


class AbstractBaseTensor(AbstractTensor):
def __init__(self, tensor):
self.tensor = tensor
Expand Down
11 changes: 7 additions & 4 deletions eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base import AbstractBaseTensor
from .base import unwrapin
from .base import wrapout
from .base import unwrap_

from .tensor import istensor

Expand All @@ -9,6 +10,8 @@


def assert_bool(x):
if not istensor(x):
return
if x.dtype != x.backend.bool_:
raise ValueError(f"all only supports dtype bool, consider t.bool().all()")

Expand Down Expand Up @@ -203,17 +206,17 @@ def any(self, axis=None, keepdims=False):
assert_bool(self)
return self.tensor.any(axis=axis, keepdims=keepdims)

@unwrapin
@wrapout
def logical_and(self, other):
assert_bool(self)
return self.backend.logical_and(self.tensor, other)
assert_bool(other)
return self.backend.logical_and(self.tensor, unwrap_(other))

@unwrapin
@wrapout
def logical_or(self, other):
assert_bool(self)
return self.backend.logical_or(self.tensor, other)
assert_bool(other)
return self.backend.logical_or(self.tensor, unwrap_(other))

@wrapout
def logical_not(self):
Expand Down
11 changes: 7 additions & 4 deletions eagerpy/tensor/numpy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from .base import AbstractBaseTensor
from .base import unwrapin
from .base import wrapout
from .base import unwrap_

from .tensor import istensor


def assert_bool(x):
if not istensor(x):
return
if x.dtype != x.backend.dtype("bool"):
raise ValueError(f"all only supports dtype bool, consider t.bool().all()")

Expand Down Expand Up @@ -162,17 +165,17 @@ def any(self, axis=None, keepdims=False):
assert_bool(self)
return self.tensor.any(axis=axis, keepdims=keepdims)

@unwrapin
@wrapout
def logical_and(self, other):
assert_bool(self)
return self.backend.logical_and(self.tensor, other)
assert_bool(other)
return self.backend.logical_and(self.tensor, unwrap_(other))

@unwrapin
@wrapout
def logical_or(self, other):
assert_bool(self)
return self.backend.logical_or(self.tensor, other)
assert_bool(other)
return self.backend.logical_or(self.tensor, unwrap_(other))

@wrapout
def logical_not(self):
Expand Down
11 changes: 7 additions & 4 deletions eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base import AbstractBaseTensor
from .base import wrapout
from .base import unwrapin
from .base import unwrap_

from .tensor import istensor

Expand All @@ -9,6 +10,8 @@


def assert_bool(x):
if not istensor(x):
return
if x.dtype != x.backend.bool:
raise ValueError(f"all only supports dtype bool, consider t.bool().all()")

Expand Down Expand Up @@ -245,17 +248,17 @@ def any(self, axis=None, keepdims=False):
x = x.any(i, keepdim=keepdims)
return x

@unwrapin
@wrapout
def logical_and(self, other):
assert_bool(self)
return self.tensor & other
assert_bool(other)
return self.tensor & unwrap_(other)

@unwrapin
@wrapout
def logical_or(self, other):
assert_bool(self)
return self.tensor | other
assert_bool(other)
return self.tensor | unwrap_(other)

@wrapout
def logical_not(self):
Expand Down
8 changes: 4 additions & 4 deletions eagerpy/tensor/tensor.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from abc import ABC, abstractmethod
from typing import TypeVar, Callable, Tuple, Any, overload, SupportsAbs, Sized, Sequence
from typing import TypeVar, Callable, Tuple, Any, overload, Sequence, cast
from typing_extensions import Literal


Tensor = TypeVar("Tensor", bound="AbstractTensor")


class AbstractTensor(SupportsAbs, Sized, ABC):
class AbstractTensor(ABC):
__array_ufunc__ = None

@property
def T(self: Tensor) -> Tensor:
return self.transpose()

def abs(self: Tensor) -> Tensor:
return self.__abs__()
return cast(Tensor, self.__abs__())

def pow(self: Tensor, exponent) -> Tensor:
return self.__pow__(exponent)
Expand Down Expand Up @@ -53,7 +53,7 @@ def __len__(self: Tensor) -> int:
...

@abstractmethod
def __abs__(self: Tensor) -> Tensor:
def __abs__(self):
...

@abstractmethod
Expand Down
11 changes: 7 additions & 4 deletions eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base import AbstractBaseTensor
from .base import unwrapin
from .base import wrapout
from .base import unwrap_

from .tensor import istensor

Expand Down Expand Up @@ -45,6 +46,8 @@ def wrapper(self, *args, **kwargs):


def assert_bool(x):
if not istensor(x):
return
if x.dtype != x.backend.bool:
raise ValueError(f"all only supports dtype bool, consider t.bool().all()")

Expand Down Expand Up @@ -272,17 +275,17 @@ def any(self, axis=None, keepdims=False):
assert_bool(self)
return self.backend.reduce_any(self.tensor, axis=axis, keepdims=keepdims)

@unwrapin
@wrapout
def logical_and(self, other):
assert_bool(self)
return self.backend.logical_and(self.tensor, other)
assert_bool(other)
return self.backend.logical_and(self.tensor, unwrap_(other))

@unwrapin
@wrapout
def logical_or(self, other):
assert_bool(self)
return self.backend.logical_or(self.tensor, other)
assert_bool(other)
return self.backend.logical_or(self.tensor, unwrap_(other))

@wrapout
def logical_not(self):
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ exclude_lines =
# Don't complain if tests don't hit defensive assertion code:
@abstractmethod
@overload
97 changes: 97 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,83 @@ def test_transpose_1d(dummy: Tensor):
assert (ep.transpose(t) == t).all()


def test_onehot_like_raises(dummy: Tensor):
t = ep.arange(dummy, 18).float32().reshape((6, 3))
indices = ep.arange(t, 6) // 2
ep.onehot_like(t, indices)

t = ep.arange(dummy, 90).float32().reshape((6, 3, 5))
indices = ep.arange(t, 6) // 2
with pytest.raises(ValueError):
ep.onehot_like(t, indices)

t = ep.arange(dummy, 18).float32().reshape((6, 3))
indices = ep.arange(t, 6).reshape((6, 1)) // 2
with pytest.raises(ValueError):
ep.onehot_like(t, indices)

t = ep.arange(dummy, 18).float32().reshape((6, 3))
indices = ep.arange(t, 5) // 2
with pytest.raises(ValueError):
ep.onehot_like(t, indices)


def test_tile_raises(t: Tensor):
ep.tile(t, (3,) * t.ndim)
with pytest.raises(ValueError):
ep.tile(t, (3,) * (t.ndim - 1))


def test_pad_raises(dummy: Tensor):
t = ep.arange(dummy, 120).reshape((2, 3, 4, 5)).float32()
ep.pad(t, ((0, 0), (0, 0), (2, 3), (1, 2)), mode="constant")
with pytest.raises(ValueError):
ep.pad(t, ((0, 0), (2, 3), (1, 2)), mode="constant")
with pytest.raises(ValueError):
ep.pad(t, ((0, 0), (0, 0, 1, 2), (2, 3), (1, 2)), mode="constant")
with pytest.raises(ValueError):
ep.pad(t, ((0, 0), (0, 0), (2, 3), (1, 2)), mode="foo")


@pytest.mark.parametrize("f", [ep.logical_and, ep.logical_or])
def test_logical_and_nonboolean(t: Tensor, f):
t = t.float32()
f(t > 1, t > 1)
with pytest.raises(ValueError):
f(t, t > 1)
with pytest.raises(ValueError):
f(t > 1, t)
with pytest.raises(ValueError):
f(t, t)


def test_crossentropy_raises(dummy: Tensor):
t = ep.arange(dummy, 50).reshape((10, 5)).float32()
t = t / t.max()
ep.crossentropy(t, t.argmax(axis=-1))

t = ep.arange(dummy, 150).reshape((10, 5, 3)).float32()
t = t / t.max()
with pytest.raises(ValueError):
ep.crossentropy(t, t.argmax(axis=-1))

t = ep.arange(dummy, 50).reshape((10, 5)).float32()
t = t / t.max()
with pytest.raises(ValueError):
ep.crossentropy(t, t.argmax(axis=-1)[:8])


def test_matmul_raise(dummy: Tensor):
t = ep.arange(dummy, 8).float32().reshape((2, 4))
ep.matmul(t, t.T)
with pytest.raises(ValueError):
ep.matmul(t, t[0])
with pytest.raises(ValueError):
ep.matmul(t[0], t)
with pytest.raises(ValueError):
ep.matmul(t[0], t[0])


###############################################################################
# special tests
# - decorated with compare_*
Expand Down Expand Up @@ -457,6 +534,11 @@ def test_all_keepdims(t: Tensor):
return ep.all(t > 3, axis=0, keepdims=True)


@compare_all
def test_all_none_keepdims(t: Tensor):
return ep.all(t > 3, axis=None, keepdims=True)


@compare_all
def test_any(t: Tensor):
return ep.any(t > 3)
Expand All @@ -472,6 +554,11 @@ def test_any_keepdims(t: Tensor):
return ep.any(t > 3, axis=0, keepdims=True)


@compare_all
def test_any_none_keepdims(t: Tensor):
return ep.any(t > 3, axis=None, keepdims=True)


@compare_all
def test_min(t: Tensor):
return ep.min(t)
Expand All @@ -487,6 +574,11 @@ def test_min_keepdims(t: Tensor):
return ep.min(t, axis=0, keepdims=True)


@compare_all
def test_min_none_keepdims(t: Tensor):
return ep.min(t, axis=None, keepdims=True)


@compare_all
def test_max(t: Tensor):
return ep.max(t)
Expand All @@ -502,6 +594,11 @@ def test_max_keepdims(t: Tensor):
return ep.max(t, axis=0, keepdims=True)


@compare_all
def test_max_none_keepdims(t: Tensor):
return ep.max(t, axis=None, keepdims=True)


@compare_allclose
def test_exp(t: Tensor):
return ep.exp(t)
Expand Down

0 comments on commit bde85ee

Please sign in to comment.