Skip to content

Commit

Permalink
added mypy config to setup.cfg
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Feb 1, 2020
1 parent 1e285ea commit d01b735
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ flake8:

.PHONY: mypy
mypy:
mypy -p eagerpy --python-version 3.6 --ignore-missing-imports --warn-unused-ignores --warn-unused-configs --warn-return-any --warn-redundant-casts
mypy -p eagerpy

.PHONY: install
install:
Expand Down
6 changes: 3 additions & 3 deletions eagerpy/astensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def astensor(x: "torch.Tensor") -> PyTorchTensor:


@overload
def astensor(x: NativeTensor) -> Tensor:
def astensor(x: NativeTensor) -> Tensor: # type: ignore
...


def astensor(x: Union[NativeTensor, Tensor]) -> Tensor:
def astensor(x: Union[NativeTensor, Tensor]) -> Tensor: # type: ignore
if isinstance(x, Tensor):
return x
# we use the module name instead of isinstance
Expand All @@ -55,7 +55,7 @@ def astensor(x: Union[NativeTensor, Tensor]) -> Tensor:
raise ValueError(f"Unknown type: {type(x)}")


def astensors(*xs: Union[NativeTensor, Tensor]) -> Tuple[Tensor, ...]:
def astensors(*xs: Union[NativeTensor, Tensor]) -> Tuple[Tensor, ...]: # type: ignore
return tuple(astensor(x) for x in xs)


Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def unwrap_(*args) -> Any:


class BaseTensor(Tensor):
def __init__(self: TensorType, raw):
def __init__(self: TensorType, raw) -> None:
self._raw = raw

@property
Expand Down
4 changes: 2 additions & 2 deletions eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def unflatten(aux_data, children):
cls._registered = True
return super().__new__(cls)

def __init__(self, raw: "np.ndarray"):
def __init__(self, raw: "np.ndarray"): # type: ignore
global jax
global np
if jax is None:
Expand All @@ -57,7 +57,7 @@ def __init__(self, raw: "np.ndarray"):
super().__init__(raw)

@property
def raw(self) -> "np.ndarray":
def raw(self) -> "np.ndarray": # type: ignore
return super().raw

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions eagerpy/tensor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def assert_bool(x: TensorType) -> None:


class NumPyTensor(BaseTensor):
def __init__(self, raw: "np.ndarray"):
def __init__(self, raw: "np.ndarray"): # type: ignore
super().__init__(raw)

@property
def raw(self) -> "np.ndarray":
def raw(self) -> "np.ndarray": # type: ignore
return super().raw

def numpy(self: TensorType) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class PyTorchTensor(BaseTensor):
def __init__(self, raw: "torch.Tensor"):
global torch
if torch is None:
torch = import_module("torch")
torch = import_module("torch") # type: ignore
super().__init__(raw)

@property
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Tensor(ABC):
__module__ = "eagerpy"

@abstractmethod
def __init__(self, raw):
def __init__(self, raw) -> None:
...

@property
Expand Down
4 changes: 2 additions & 2 deletions eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def assert_bool(x: TensorType) -> None:


class TensorFlowTensor(BaseTensor):
def __init__(self, raw: "tf.Tensor"):
def __init__(self, raw: "tf.Tensor"): # type: ignore
global tf
if tf is None:
tf = import_module("tensorflow")
super().__init__(raw)

@property
def raw(self) -> "tf.Tensor":
def raw(self) -> "tf.Tensor": # type: ignore
return super().raw

def numpy(self: TensorType) -> Any:
Expand Down
22 changes: 22 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,28 @@ max-line-length = 80
max-complexity = 18
select = B,C,E,F,W,T4,B9

[mypy]
python_version = 3.6
warn_unused_ignores = True
warn_unused_configs = True
warn_return_any = True
warn_redundant_casts = True
warn_unreachable = True
ignore_missing_imports = False
disallow_any_unimported = True

[mypy-numpy]
ignore_missing_imports = True

[mypy-jax]
ignore_missing_imports = True

[mypy-jax.numpy]
ignore_missing_imports = True

[mypy-tensorflow]
ignore_missing_imports = True

[tool:pytest]
filterwarnings =
ignore::DeprecationWarning
Expand Down

0 comments on commit d01b735

Please sign in to comment.