Skip to content

Commit

Permalink
completed the type annotations and added disallow_untyped_defs
Browse files · Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Feb 5, 2020
1 parent 5bab21f commit 6d459f3
Show file tree
Hide file tree
Showing 19 changed files with 831 additions and 583 deletions.
2 changes: 0 additions & 2 deletions eagerpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
-# mypy: disallow_untyped_defs

from typing import TypeVar
from os.path import join as _join
from os.path import dirname as _dirname
Expand Down
2 changes: 0 additions & 2 deletions eagerpy/astensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
-# mypy: disallow_untyped_defs

from typing import TYPE_CHECKING, Union, overload, Tuple, TypeVar, Generic, Any
import sys

Expand Down
26 changes: 12 additions & 14 deletions eagerpy/framework.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
-# mypy: disallow_untyped_defs

-from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast, Union
+from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast
from typing_extensions import Literal

-from .types import Axes, Shape, ShapeOrScalar
+from .types import Axes, AxisAxes, Shape, ShapeOrScalar

from .tensor import Tensor
from .tensor import TensorType
Expand Down Expand Up @@ -47,25 +45,25 @@ def arctanh(t: TensorType) -> TensorType:


def sum(
-    t: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return t.sum(axis=axis, keepdims=keepdims)


def mean(
-    t: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return t.mean(axis=axis, keepdims=keepdims)


def min(
-    t: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return t.min(axis=axis, keepdims=keepdims)


def max(
-    t: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return t.max(axis=axis, keepdims=keepdims)

Expand Down Expand Up @@ -110,7 +108,7 @@ def argmax(t: TensorType, axis: Optional[int] = None) -> TensorType:
return t.argmax(axis=axis)


-def argsort(t: TensorType, axis: Optional[int] = -1) -> TensorType:
+def argsort(t: TensorType, axis: int = -1) -> TensorType:
return t.argsort(axis=axis)


Expand Down Expand Up @@ -223,7 +221,7 @@ def where(condition: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> Tensor
return condition.where(x, y)


-def tile(t: TensorType, multiples: Union[Tuple[int, ...], Tensor]) -> TensorType:
+def tile(t: TensorType, multiples: Axes) -> TensorType:
return t.tile(multiples)


Expand All @@ -244,7 +242,7 @@ def stack(tensors: Sequence[TensorType], axis: int = 0) -> TensorType:
return t._stack(tensors, axis=axis)


-def squeeze(t: TensorType, axis: Optional[Axes] = None) -> TensorType:
+def squeeze(t: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
return t.squeeze(axis=axis)


Expand All @@ -270,7 +268,7 @@ def cumsum(t: TensorType, axis: Optional[int] = None) -> TensorType:
return t.cumsum(axis=axis)


-def flip(t: TensorType, axis: Optional[Axes] = None) -> TensorType:
+def flip(t: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
return t.flip(axis=axis)


Expand Down Expand Up @@ -298,13 +296,13 @@ def isinf(t: TensorType) -> TensorType:


def all(
-    t: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return t.all(axis=axis, keepdims=keepdims)


def any(
-    t: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return t.any(axis=axis, keepdims=keepdims)

Expand Down
2 changes: 0 additions & 2 deletions eagerpy/lib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
-# mypy: disallow_untyped_defs

from .tensor import TensorType


Expand Down
2 changes: 0 additions & 2 deletions eagerpy/modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
-# mypy: disallow_untyped_defs

from importlib import import_module
import inspect
from types import ModuleType
Expand Down
14 changes: 6 additions & 8 deletions eagerpy/norms.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,38 @@
-# mypy: disallow_untyped_defs

from typing import Union, Optional

from .tensor import TensorType
-from .types import Axes
+from .types import AxisAxes
from .framework import inf


def l0(
-    x: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return (x != 0).sum(axis=axis, keepdims=keepdims)


def l1(
-    x: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return x.abs().sum(axis=axis, keepdims=keepdims)


def l2(
-    x: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return x.square().sum(axis=axis, keepdims=keepdims).sqrt()


def linf(
-    x: TensorType, axis: Optional[Axes] = None, keepdims: bool = False
+    x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
return x.abs().max(axis=axis, keepdims=keepdims)


def lp(
x: TensorType,
p: Union[int, float],
-    axis: Optional[Axes] = None,
+    axis: Optional[AxisAxes] = None,
keepdims: bool = False,
) -> TensorType:
if p == 0:
Expand Down
2 changes: 0 additions & 2 deletions eagerpy/tensor/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
-# mypy: disallow_untyped_defs

from typing_extensions import final
from typing import Any, cast

Expand Down
Loading

0 comments on commit 6d459f3

Please sign in to comment.