Skip to content

Commit

Permalink
added disallow_untyped_calls to mypy config
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Feb 1, 2020
1 parent a6e8e38 commit e046157
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 9 deletions.
6 changes: 3 additions & 3 deletions eagerpy/modules.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from importlib import import_module
import inspect
from types import ModuleType
from typing import Any
from typing import Any, Callable
import functools

from .astensor import astensor


def wrap(f):
def wrap(f: Callable) -> Callable:
@functools.wraps(f)
def wrapper(*args, **kwargs):
result = f(*args, **kwargs)
Expand All @@ -24,7 +24,7 @@ class ModuleWrapper(ModuleType):
"""A wrapper for modules that delays the import until it is needed
and wraps the output of functions as EagerPy tensors"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
if self.__doc__ is None:
self.__doc__ = f"EagerPy wrapper of the '{self.__name__}' module"
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def raw(self) -> "np.ndarray": # type: ignore
return super().raw

@classmethod
def _get_subkey(cls):
def _get_subkey(cls) -> Any:
if cls.key is None:
cls.key = jax.random.PRNGKey(0)
cls.key, subkey = jax.random.split(cls.key)
Expand Down
6 changes: 3 additions & 3 deletions eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,9 @@ def flip(self: TensorType, axis=None) -> TensorType:
def meshgrid(self: TensorType, *tensors, indexing="xy") -> Tuple[TensorType, ...]:
tensors = unwrap_(tensors)
if indexing == "ij" or len(tensors) == 0:
outputs = torch.meshgrid(self.raw, *tensors)
outputs = torch.meshgrid(self.raw, *tensors) # type: ignore
elif indexing == "xy":
outputs = torch.meshgrid(tensors[0], self.raw, *tensors[1:])
outputs = torch.meshgrid(tensors[0], self.raw, *tensors[1:]) # type: ignore
else:
raise ValueError( # pragma: no cover
f"Valid values for indexing are 'xy' and 'ij', got {indexing}"
Expand Down Expand Up @@ -365,7 +365,7 @@ def isnan(self: TensorType) -> TensorType:
return type(self)(torch.isnan(self.raw))

def isinf(self: TensorType) -> TensorType:
return type(self)(torch.isinf(self.raw))
return type(self)(torch.isinf(self.raw)) # type: ignore

def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
if self.ndim != 2:
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ warn_redundant_casts = True
warn_unreachable = True
ignore_missing_imports = False
disallow_any_unimported = True
disallow_untyped_calls = True

[mypy-numpy.*]
ignore_missing_imports = True
Expand Down
5 changes: 3 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import eagerpy as ep
from eagerpy import Tensor
from typing import Callable


# make sure there are no undecorated tests in the "special tests" section below
Expand Down Expand Up @@ -344,10 +345,10 @@ def test_fn(*args, **kwargs):
return test_fn


def compare_allclose(*args, rtol=1e-07, atol=0):
def compare_allclose(*args, rtol: float = 1e-07, atol: float = 0):
"""A decorator to simplify writing test functions"""

def compare_allclose_inner(f):
def compare_allclose_inner(f: Callable) -> Callable:
@functools.wraps(f)
def test_fn(*args, **kwargs):
assert len(args) == 0
Expand Down

0 comments on commit e046157

Please sign in to comment.