added more type annotations to ep.* and ep.Tensor
Jonas Rauber committed Feb 5, 2020
1 parent 3f920bc commit 5bab21f
Showing 5 changed files with 122 additions and 44 deletions.
49 changes: 30 additions & 19 deletions eagerpy/framework.py
@@ -1,4 +1,6 @@
-from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast
+# mypy: disallow_untyped_defs
+
+from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast, Union
 from typing_extensions import Literal
 
 from .types import Axes, Shape, ShapeOrScalar
@@ -100,15 +102,15 @@ def maximum(x: TensorOrScalar, y: TensorOrScalar) -> Tensor:
     return x.maximum(y)
 
 
-def argmin(t: TensorType, axis=None) -> TensorType:
+def argmin(t: TensorType, axis: Optional[int] = None) -> TensorType:
     return t.argmin(axis=axis)
 
 
-def argmax(t: TensorType, axis=None) -> TensorType:
+def argmax(t: TensorType, axis: Optional[int] = None) -> TensorType:
     return t.argmax(axis=axis)
 
 
-def argsort(t: TensorType, axis=-1) -> TensorType:
+def argsort(t: TensorType, axis: Optional[int] = -1) -> TensorType:
     return t.argsort(axis=axis)


@@ -144,7 +146,7 @@ def full_like(t: TensorType, fill_value: float) -> TensorType:
     return t.full_like(fill_value)
 
 
-def onehot_like(t: TensorType, indices, *, value=1) -> TensorType:
+def onehot_like(t: TensorType, indices: TensorType, *, value: float = 1) -> TensorType:
     return t.onehot_like(indices, value=value)


@@ -187,7 +189,7 @@ def logical_or(x: TensorOrScalar, y: TensorType) -> TensorType:
     ...
 
 
-def logical_or(x: TensorOrScalar, y: TensorOrScalar):
+def logical_or(x: TensorOrScalar, y: TensorOrScalar) -> Tensor:
     if not isinstance(x, Tensor):
         return cast(Tensor, y).logical_or(x)
     return x.logical_or(y)
@@ -221,7 +223,7 @@ def where(condition: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> Tensor
     return condition.where(x, y)
 
 
-def tile(t: TensorType, multiples) -> TensorType:
+def tile(t: TensorType, multiples: Union[Tuple[int, ...], Tensor]) -> TensorType:
     return t.tile(multiples)


@@ -246,35 +248,44 @@ def squeeze(t: TensorType, axis: Optional[Axes] = None) -> TensorType:
     return t.squeeze(axis=axis)
 
 
-def expand_dims(t: TensorType, axis: int = None) -> TensorType:
+def expand_dims(t: TensorType, axis: int) -> TensorType:
     return t.expand_dims(axis=axis)
 
 
 def full(t: TensorType, shape: ShapeOrScalar, value: float) -> TensorType:
     return t.full(shape, value)
 
 
-def index_update(t: TensorType, indices, values) -> TensorType:
+def index_update(t: TensorType, indices: Any, values: TensorOrScalar) -> TensorType:
     return t.index_update(indices, values)
 
 
-def arange(t: TensorType, start: int, stop: int = None, step: int = None) -> TensorType:
+def arange(
+    t: TensorType, start: int, stop: Optional[int] = None, step: Optional[int] = None
+) -> TensorType:
     return t.arange(start, stop, step)
 
 
-def cumsum(t: TensorType, axis=None) -> TensorType:
+def cumsum(t: TensorType, axis: Optional[int] = None) -> TensorType:
     return t.cumsum(axis=axis)
 
 
 def flip(t: TensorType, axis: Optional[Axes] = None) -> TensorType:
     return t.flip(axis=axis)
 
 
-def meshgrid(t: TensorType, *tensors, indexing="xy") -> Tuple[TensorType, ...]:
+def meshgrid(
+    t: TensorType, *tensors: TensorType, indexing: str = "xy"
+) -> Tuple[TensorType, ...]:
     return t.meshgrid(*tensors, indexing=indexing)
 
 
-def pad(t: TensorType, paddings, mode="constant", value=0) -> TensorType:
+def pad(
+    t: TensorType,
+    paddings: Tuple[Tuple[int, int], ...],
+    mode: str = "constant",
+    value: float = 0,
+) -> TensorType:
     return t.pad(paddings, mode=mode, value=value)


@@ -304,37 +315,37 @@ def crossentropy(logits: TensorType, labels: TensorType) -> TensorType:
 
 @overload
 def value_and_grad_fn(
-    t: TensorType, f: Callable
+    t: TensorType, f: Callable[..., TensorType]
 ) -> Callable[..., Tuple[TensorType, TensorType]]:
     ...
 
 
 @overload
 def value_and_grad_fn(
-    t: TensorType, f: Callable, has_aux: Literal[False]
+    t: TensorType, f: Callable[..., TensorType], has_aux: Literal[False]
 ) -> Callable[..., Tuple[TensorType, TensorType]]:
     ...
 
 
 @overload
 def value_and_grad_fn(
-    t: TensorType, f: Callable, has_aux: Literal[True]
+    t: TensorType, f: Callable[..., Tuple[TensorType, Any]], has_aux: Literal[True]
 ) -> Callable[..., Tuple[TensorType, Any, TensorType]]:
     ...
 
 
-def value_and_grad_fn(t, f, has_aux=False):
+def value_and_grad_fn(t: Any, f: Any, has_aux: bool = False) -> Any:
     return t._value_and_grad_fn(f, has_aux=has_aux)
 
 
 def value_and_grad(
-    f: Callable, t: TensorType, *args, **kwargs
+    f: Callable[..., TensorType], t: TensorType, *args: Any, **kwargs: Any
 ) -> Tuple[TensorType, TensorType]:
     return t.value_and_grad(f, *args, **kwargs)
 
 
 def value_aux_and_grad(
-    f: Callable, t: TensorType, *args, **kwargs
+    f: Callable[..., Tuple[TensorType, Any]], t: TensorType, *args: Any, **kwargs: Any
 ) -> Tuple[TensorType, Any, TensorType]:
     return t.value_aux_and_grad(f, *args, **kwargs)

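For orientation, a minimal call-site sketch of the functions annotated above (not part of the commit; assumes a PyTorch backend and made-up values):

    # Sketch only: exercises the new signatures of ep.argmin, ep.pad, and ep.tile.
    import torch
    import eagerpy as ep

    t = ep.astensor(torch.randn(2, 3))

    i = ep.argmin(t, axis=1)  # axis is now Optional[int]
    p = ep.pad(t, ((0, 0), (1, 1)), mode="constant", value=0.0)  # typed paddings
    m = ep.tile(t, (2, 1))  # multiples: Union[Tuple[int, ...], Tensor]
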
12 changes: 7 additions & 5 deletions eagerpy/tensor/pytorch.py
@@ -1,4 +1,4 @@
-from typing import Tuple, cast, Union, Any, TypeVar, TYPE_CHECKING, Iterable
+from typing import Tuple, cast, Union, Any, TypeVar, TYPE_CHECKING, Iterable, Optional
 import numpy as np
 from importlib import import_module

@@ -135,13 +135,13 @@ def maximum(self: TensorType, other) -> TensorType:
             other = torch.full_like(self.raw, other)
         return type(self)(torch.max(self.raw, other))
 
-    def argmin(self: TensorType, axis=None) -> TensorType:
+    def argmin(self: TensorType, axis: Optional[int] = None) -> TensorType:
         return type(self)(self.raw.argmin(dim=axis))
 
-    def argmax(self: TensorType, axis=None) -> TensorType:
+    def argmax(self: TensorType, axis: Optional[int] = None) -> TensorType:
         return type(self)(self.raw.argmax(dim=axis))
 
-    def argsort(self: TensorType, axis=-1) -> TensorType:
+    def argsort(self: TensorType, axis: Optional[int] = -1) -> TensorType:
         return type(self)(self.raw.argsort(dim=axis))
 
     def uniform(self: TensorType, shape, low=0.0, high=1.0) -> TensorType:
@@ -176,7 +176,9 @@ def zeros_like(self: TensorType) -> TensorType:
     def full_like(self: TensorType, fill_value) -> TensorType:
         return type(self)(torch.full_like(self.raw, fill_value))
 
-    def onehot_like(self: TensorType, indices: TensorType, *, value=1) -> TensorType:
+    def onehot_like(
+        self: TensorType, indices: TensorType, *, value: float = 1
+    ) -> TensorType:
         if self.ndim != 2:
             raise ValueError("onehot_like only supported for 2D tensors")
         if indices.ndim != 1:
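As context for the stricter onehot_like signature, a hedged usage sketch (not from the commit; tensors are illustrative):

    # Sketch only: onehot_like per the new annotations (indices: TensorType,
    # value: float); the 2D/1D checks above are enforced at runtime.
    import torch
    import eagerpy as ep

    template = ep.astensor(torch.zeros(3, 10))     # must be 2D
    labels = ep.astensor(torch.tensor([2, 7, 4]))  # must be 1D
    onehot = template.onehot_like(labels, value=1.0)
    # onehot.raw is a 3x10 torch.Tensor with 1.0 at columns 2, 7, and 4
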
64 changes: 47 additions & 17 deletions eagerpy/tensor/tensor.py
@@ -1,3 +1,5 @@
+# mypy: disallow_untyped_defs
+
 from abc import ABCMeta, abstractmethod
 from typing import TypeVar, Callable, Tuple, Any, overload, Sequence, Union, Optional
 from typing_extensions import Literal, final
@@ -235,15 +237,15 @@ def maximum(self: TensorType, other: TensorOrScalar) -> TensorType:
         ...
 
     @abstractmethod
-    def argmin(self: TensorType, axis=None) -> TensorType:
+    def argmin(self: TensorType, axis: Optional[int] = None) -> TensorType:
         ...
 
     @abstractmethod
-    def argmax(self: TensorType, axis=None) -> TensorType:
+    def argmax(self: TensorType, axis: Optional[int] = None) -> TensorType:
         ...
 
     @abstractmethod
-    def argsort(self: TensorType, axis=-1) -> TensorType:
+    def argsort(self: TensorType, axis: Optional[int] = -1) -> TensorType:
         ...
 
     @abstractmethod
@@ -279,7 +281,9 @@ def full_like(self: TensorType, fill_value: float) -> TensorType:
         ...
 
     @abstractmethod
-    def onehot_like(self: TensorType, indices, *, value=1) -> TensorType:
+    def onehot_like(
+        self: TensorType, indices: TensorType, *, value: float = 1
+    ) -> TensorType:
         ...
 
     @abstractmethod
@@ -351,7 +355,9 @@ def log1p(self: TensorType) -> TensorType:
         ...
 
     @abstractmethod
-    def tile(self: TensorType, multiples) -> TensorType:
+    def tile(
+        self: TensorType, multiples: Union[Tuple[int, ...], "Tensor"]
+    ) -> TensorType:
         ...
 
     @abstractmethod
@@ -367,37 +373,49 @@ def squeeze(self: TensorType, axis: Optional[Axes] = None) -> TensorType:
         ...
 
     @abstractmethod
-    def expand_dims(self: TensorType, axis: int = None) -> TensorType:
+    def expand_dims(self: TensorType, axis: int) -> TensorType:
         ...
 
     @abstractmethod
     def full(self: TensorType, shape: ShapeOrScalar, value: float) -> TensorType:
         ...
 
     @abstractmethod
-    def index_update(self: TensorType, indices, values) -> TensorType:
+    def index_update(
+        self: TensorType, indices: Any, values: TensorOrScalar
+    ) -> TensorType:
         ...
 
     @abstractmethod
     def arange(
-        self: TensorType, start: int, stop: int = None, step: int = None
+        self: TensorType,
+        start: int,
+        stop: Optional[int] = None,
+        step: Optional[int] = None,
     ) -> TensorType:
         ...
 
     @abstractmethod
-    def cumsum(self: TensorType, axis=None) -> TensorType:
+    def cumsum(self: TensorType, axis: Optional[int] = None) -> TensorType:
         ...
 
     @abstractmethod
     def flip(self: TensorType, axis: Optional[Axes] = None) -> TensorType:
         ...
 
     @abstractmethod
-    def meshgrid(self: TensorType, *tensors, indexing="xy") -> Tuple[TensorType, ...]:
+    def meshgrid(
+        self: TensorType, *tensors: TensorType, indexing: str = "xy"
+    ) -> Tuple[TensorType, ...]:
         ...
 
     @abstractmethod
-    def pad(self: TensorType, paddings, mode="constant", value=0) -> TensorType:
+    def pad(
+        self: TensorType,
+        paddings: Tuple[Tuple[int, int], ...],
+        mode: str = "constant",
+        value: float = 0,
+    ) -> TensorType:
         ...
 
     @abstractmethod
@@ -414,24 +432,33 @@ def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
 
     @overload
     def _value_and_grad_fn(
-        self: TensorType, f: Callable
+        self: TensorType, f: Callable[..., TensorType]
    ) -> Callable[..., Tuple[TensorType, TensorType]]:
         ...
 
     @overload  # noqa: F811 (waiting for pyflakes > 2.1.1)
     def _value_and_grad_fn(
-        self: TensorType, f: Callable, has_aux: Literal[False]
+        self: TensorType, f: Callable[..., TensorType], has_aux: Literal[False]
    ) -> Callable[..., Tuple[TensorType, TensorType]]:
         ...
 
     @overload  # noqa: F811 (waiting for pyflakes > 2.1.1)
     def _value_and_grad_fn(
-        self: TensorType, f: Callable, has_aux: Literal[True]
+        self: TensorType,
+        f: Callable[..., Tuple[TensorType, Any]],
+        has_aux: Literal[True],
    ) -> Callable[..., Tuple[TensorType, Any, TensorType]]:
         ...
 
     @abstractmethod  # noqa: F811 (waiting for pyflakes > 2.1.1)
-    def _value_and_grad_fn(self, f, has_aux=False):
+    def _value_and_grad_fn(
+        self: TensorType,
+        f: Union[Callable[..., TensorType], Callable[..., Tuple[TensorType, Any]]],
+        has_aux: bool = False,
+    ) -> Union[
+        Callable[..., Tuple[TensorType, TensorType]],
+        Callable[..., Tuple[TensorType, Any, TensorType]],
+    ]:
         ...
 
     @abstractmethod
@@ -457,13 +484,16 @@ def pow(self: TensorType, exponent: float) -> TensorType:
 
     @final
     def value_and_grad(
-        self: TensorType, f, *args, **kwargs
+        self: TensorType, f: Callable[..., TensorType], *args: Any, **kwargs: Any
    ) -> Tuple[TensorType, TensorType]:
         return self._value_and_grad_fn(f, has_aux=False)(self, *args, **kwargs)
 
     @final
     def value_aux_and_grad(
-        self: TensorType, f, *args, **kwargs
+        self: TensorType,
+        f: Callable[..., Tuple[TensorType, Any]],
+        *args: Any,
+        **kwargs: Any,
    ) -> Tuple[TensorType, Any, TensorType]:
         return self._value_and_grad_fn(f, has_aux=True)(self, *args, **kwargs)

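A hedged sketch of how the Literal-based overloads resolve for callers of the public API (the loss functions below are hypothetical, and gradient support depends on the backend):

    # Sketch only: has_aux=False vs. has_aux=True select different return types.
    # Note f receives the tensor as its first argument, per the calls above.
    import torch
    import eagerpy as ep

    x = ep.astensor(torch.arange(4.0))

    def loss(x: ep.Tensor) -> ep.Tensor:
        return (x * x).sum()

    value, grad = x.value_and_grad(loss)  # Tuple[Tensor, Tensor]

    def loss_with_aux(x: ep.Tensor):
        l = (x * x).sum()
        return l, x.mean()  # (loss, auxiliary value)

    value, aux, grad = x.value_aux_and_grad(loss_with_aux)  # Tuple[Tensor, Any, Tensor]
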
1 change: 1 addition & 0 deletions setup.cfg
@@ -14,6 +14,7 @@ warn_unreachable = True
 ignore_missing_imports = False
 disallow_any_unimported = True
 disallow_untyped_calls = True
+no_implicit_optional = True
 
 [mypy-numpy.*]
 ignore_missing_imports = True
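The new flag is why the axis: int = None style defaults above became explicit Optionals. A minimal sketch of what mypy's no_implicit_optional = True accepts and rejects (function names are hypothetical):

    # Sketch only: with no_implicit_optional, a None default no longer implies
    # Optional; the annotation must say so.
    from typing import Optional

    def ok(axis: Optional[int] = None) -> None:  # accepted
        ...

    # def bad(axis: int = None) -> None:  # rejected: incompatible default
    #     ...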