Skip to content

Commit

Permalink
added all methods to AbstractTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Jan 29, 2020
1 parent 91a647a commit 7b22e97
Show file tree
Hide file tree
Showing 4 changed files with 406 additions and 33 deletions.
3 changes: 2 additions & 1 deletion eagerpy/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base import Tensor # noqa: F401
from .tensor import Tensor # noqa: F401

from .pytorch import PyTorchTensor # noqa: F401
from .tensorflow import TensorFlowTensor # noqa: F401
from .numpy import NumPyTensor # noqa: F401
Expand Down
38 changes: 6 additions & 32 deletions eagerpy/tensor/base.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
from abc import ABC
import functools
from typing import TypeVar


class AbstractTensor(ABC):
__array_ufunc__ = None

def __init__(self, tensor):
self.tensor = tensor


Tensor = TypeVar("Tensor", bound=AbstractTensor)


def istensor(x):
return isinstance(x, AbstractTensor)
from .tensor import AbstractTensor
from .tensor import istensor


def wrapout(f):
Expand All @@ -35,7 +22,10 @@ def wrapper(self, *args, **kwargs):
return wrapper


class AbstractBaseTensor(AbstractTensor, ABC):
class AbstractBaseTensor(AbstractTensor):
def __init__(self, tensor):
self.tensor = tensor

def __repr__(self):
lines = self.tensor.__repr__().split("\n")
prefix = self.__class__.__name__ + "("
Expand All @@ -60,9 +50,6 @@ def __getitem__(self, index):
def dtype(self):
return self.tensor.dtype

def abs(self):
return self.__abs__()

def __bool__(self):
return self.tensor.__bool__()

Expand Down Expand Up @@ -166,9 +153,6 @@ def __ge__(self, other):
def __pow__(self, exponent):
return self.tensor.__pow__(exponent)

def pow(self, exponent):
return self.__pow__(exponent)

@wrapout
def sign(self):
return self.backend.sign(self.tensor)
Expand Down Expand Up @@ -199,13 +183,3 @@ def matmul(self, other):
@property
def ndim(self):
return self.tensor.ndim

@property
def T(self):
return self.transpose()

def value_and_grad(self, f, *args, **kwargs):
return self._value_and_grad_fn(f)(self, *args, **kwargs)

def value_aux_and_grad(self, f, *args, **kwargs):
return self._value_and_grad_fn(f, has_aux=True)(self, *args, **kwargs)
Loading

0 comments on commit 7b22e97

Please sign in to comment.