Skip to content

Commit

Permalink
upgraded mypy and renamed np to jnp to follow new jax guidelines
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Oct 11, 2021
1 parent 0c5796f commit 15f269c
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 68 deletions.
133 changes: 68 additions & 65 deletions eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Callable,
Type,
)
from collections.abc import Sequence
from typing_extensions import Literal
from importlib import import_module
import numpy as onp
Expand All @@ -28,12 +29,12 @@
if TYPE_CHECKING:
# for static analyzers
import jax
import jax.numpy as np
import jax.numpy as jnp
from .extensions import NormsMethods # noqa: F401
else:
# lazy import in JAXTensor
jax = None
np = None
jnp = None


# stricter TensorType to support additional internal methods
Expand All @@ -51,7 +52,7 @@ def getitem_preprocess(x: Any) -> Any:
if isinstance(x, range):
x = list(x)
if isinstance(x, list):
return np.asarray(x)
return jnp.asarray(x)
elif isinstance(x, Tensor):
return x.raw
else:
Expand All @@ -74,24 +75,24 @@ def __new__(cls: Type["JAXTensor"], *args: Any, **kwargs: Any) -> "JAXTensor":
def flatten(t: JAXTensor) -> Tuple[Any, None]:
return ((t.raw,), None)

def unflatten(aux_data: None, children: Tuple) -> JAXTensor:
def unflatten(aux_data: Any, children: Sequence) -> JAXTensor:
return cls(*children)

jax.tree_util.register_pytree_node(cls, flatten, unflatten)
cls._registered = True
return cast(JAXTensor, super().__new__(cls))

def __init__(self, raw: "np.ndarray"): # type: ignore
def __init__(self, raw: "jnp.ndarray"):
global jax
global np
global jnp
if jax is None:
jax = import_module("jax")
np = import_module("jax.numpy")
jax = import_module("jax") # type: ignore
jnp = import_module("jax.numpy")
super().__init__(raw)

@property
def raw(self) -> "np.ndarray": # type: ignore
return super().raw
def raw(self) -> "jnp.ndarray":
return cast("jnp.ndarray", super().raw)

@classmethod
def _get_subkey(cls) -> Any:
Expand Down Expand Up @@ -121,13 +122,13 @@ def astype(self: TensorType, dtype: Any) -> TensorType:
return type(self)(self.raw.astype(dtype))

def clip(self: TensorType, min_: float, max_: float) -> TensorType:
return type(self)(np.clip(self.raw, min_, max_))
return type(self)(jnp.clip(self.raw, min_, max_))

def square(self: TensorType) -> TensorType:
return type(self)(np.square(self.raw))
return type(self)(jnp.square(self.raw))

def arctanh(self: TensorType) -> TensorType:
return type(self)(np.arctanh(self.raw))
return type(self)(jnp.arctanh(self.raw))

def sum(
self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
Expand All @@ -142,7 +143,7 @@ def prod(
def mean(
self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
) -> TensorType:
if self.raw.dtype not in [np.float16, np.float32, np.float64]:
if self.raw.dtype not in [jnp.float16, jnp.float32, jnp.float64]:
raise ValueError(
f"Can only calculate the mean of floating types. Got {self.raw.dtype} instead."
)
Expand All @@ -159,10 +160,10 @@ def max(
return type(self)(self.raw.max(axis=axis, keepdims=keepdims))

def minimum(self: TensorType, other: TensorOrScalar) -> TensorType:
return type(self)(np.minimum(self.raw, unwrap1(other)))
return type(self)(jnp.minimum(self.raw, unwrap1(other)))

def maximum(self: TensorType, other: TensorOrScalar) -> TensorType:
return type(self)(np.maximum(self.raw, unwrap1(other)))
return type(self)(jnp.maximum(self.raw, unwrap1(other)))

def argmin(self: TensorType, axis: Optional[int] = None) -> TensorType:
return type(self)(self.raw.argmin(axis=axis))
Expand All @@ -182,12 +183,12 @@ def topk(
# argpartition not yet implemented
# wrapping indexing not supported in take()
n = self.raw.shape[-1]
idx = np.take(np.argsort(self.raw), np.arange(n - k, n), axis=-1)
val = np.take_along_axis(self.raw, idx, axis=-1)
idx = jnp.take(jnp.argsort(self.raw), jnp.arange(n - k, n), axis=-1)
val = jnp.take_along_axis(self.raw, idx, axis=-1)
if sorted:
perm = np.flip(np.argsort(val, axis=-1), axis=-1)
idx = np.take_along_axis(idx, perm, axis=-1)
val = np.take_along_axis(self.raw, idx, axis=-1)
perm = jnp.flip(jnp.argsort(val, axis=-1), axis=-1)
idx = jnp.take_along_axis(idx, perm, axis=-1)
val = jnp.take_along_axis(self.raw, idx, axis=-1)
return type(self)(val), type(self)(idx)

def uniform(
Expand All @@ -206,22 +207,24 @@ def normal(
shape = (shape,)

subkey = self._get_subkey()
return type(self)(jax.random.normal(subkey, shape) * stddev + mean)
return type(self)(
cast("jnp.ndarray", jax.random.normal(subkey, shape) * stddev + mean)
)

def ones(self: TensorType, shape: ShapeOrScalar) -> TensorType:
return type(self)(np.ones(shape, dtype=self.raw.dtype))
return type(self)(jnp.ones(shape, dtype=self.raw.dtype))

def zeros(self: TensorType, shape: ShapeOrScalar) -> TensorType:
return type(self)(np.zeros(shape, dtype=self.raw.dtype))
return type(self)(jnp.zeros(shape, dtype=self.raw.dtype))

def ones_like(self: TensorType) -> TensorType:
return type(self)(np.ones_like(self.raw))
return type(self)(jnp.ones_like(self.raw))

def zeros_like(self: TensorType) -> TensorType:
return type(self)(np.zeros_like(self.raw))
return type(self)(jnp.zeros_like(self.raw))

def full_like(self: TensorType, fill_value: float) -> TensorType:
return type(self)(np.full_like(self.raw, fill_value))
return type(self)(jnp.full_like(self.raw, fill_value))

def onehot_like(
self: TensorType, indices: TensorType, *, value: float = 1
Expand All @@ -232,31 +235,31 @@ def onehot_like(
raise ValueError("onehot_like requires 1D indices")
if len(indices) != len(self):
raise ValueError("length of indices must match length of tensor")
x = np.arange(self.raw.shape[1]).reshape(1, -1)
x = jnp.arange(self.raw.shape[1]).reshape(1, -1)
indices = indices.raw.reshape(-1, 1)
return type(self)((x == indices) * value)

def from_numpy(self: TensorType, a: Any) -> TensorType:
return type(self)(np.asarray(a))
return type(self)(jnp.asarray(a))

def _concatenate(
self: TensorType, tensors: Iterable[TensorType], axis: int = 0
) -> TensorType:
# concatenates only "tensors", but not "self"
tensors_ = unwrap_(*tensors)
return type(self)(np.concatenate(tensors_, axis=axis))
return type(self)(jnp.concatenate(tensors_, axis=axis))

def _stack(
self: TensorType, tensors: Iterable[TensorType], axis: int = 0
) -> TensorType:
# stacks only "tensors", but not "self"
tensors_ = unwrap_(*tensors)
return type(self)(np.stack(tensors_, axis=axis))
return type(self)(jnp.stack(tensors_, axis=axis))

def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType:
if axes is None:
axes = tuple(range(self.ndim - 1, -1, -1))
return type(self)(np.transpose(self.raw, axes=axes))
return type(self)(jnp.transpose(self.raw, axes=axes))

def all(
self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False
Expand All @@ -273,37 +276,37 @@ def any(
def logical_and(self: TensorType, other: TensorOrScalar) -> TensorType:
assert_bool(self)
assert_bool(other)
return type(self)(np.logical_and(self.raw, unwrap1(other)))
return type(self)(jnp.logical_and(self.raw, unwrap1(other)))

def logical_or(self: TensorType, other: TensorOrScalar) -> TensorType:
assert_bool(self)
assert_bool(other)
return type(self)(np.logical_or(self.raw, unwrap1(other)))
return type(self)(jnp.logical_or(self.raw, unwrap1(other)))

def logical_not(self: TensorType) -> TensorType:
assert_bool(self)
return type(self)(np.logical_not(self.raw))
return type(self)(jnp.logical_not(self.raw))

def exp(self: TensorType) -> TensorType:
return type(self)(np.exp(self.raw))
return type(self)(jnp.exp(self.raw))

def log(self: TensorType) -> TensorType:
return type(self)(np.log(self.raw))
return type(self)(jnp.log(self.raw))

def log2(self: TensorType) -> TensorType:
return type(self)(np.log2(self.raw))
return type(self)(jnp.log2(self.raw))

def log10(self: TensorType) -> TensorType:
return type(self)(np.log10(self.raw))
return type(self)(jnp.log10(self.raw))

def log1p(self: TensorType) -> TensorType:
return type(self)(np.log1p(self.raw))
return type(self)(jnp.log1p(self.raw))

def tile(self: TensorType, multiples: Axes) -> TensorType:
multiples = unwrap1(multiples)
if len(multiples) != self.ndim:
raise ValueError("multiples requires one entry for each dimension")
return type(self)(np.tile(self.raw, multiples))
return type(self)(jnp.tile(self.raw, multiples))

def softmax(self: TensorType, axis: int = -1) -> TensorType:
return type(self)(jax.nn.softmax(self.raw, axis=axis))
Expand All @@ -323,12 +326,12 @@ def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
return type(self)(self.raw.squeeze(axis=axis))

def expand_dims(self: TensorType, axis: int) -> TensorType:
return type(self)(np.expand_dims(self.raw, axis=axis))
return type(self)(jnp.expand_dims(self.raw, axis=axis))

def full(self: TensorType, shape: ShapeOrScalar, value: float) -> TensorType:
if not isinstance(shape, Iterable):
shape = (shape,)
return type(self)(np.full(shape, value, dtype=self.raw.dtype))
return type(self)(jnp.full(shape, value, dtype=self.raw.dtype))

def index_update(
self: TensorType, indices: Any, values: TensorOrScalar
Expand All @@ -344,19 +347,19 @@ def arange(
stop: Optional[int] = None,
step: Optional[int] = None,
) -> TensorType:
return type(self)(np.arange(start, stop, step))
return type(self)(jnp.arange(start, stop, step))

def cumsum(self: TensorType, axis: Optional[int] = None) -> TensorType:
return type(self)(self.raw.cumsum(axis=axis))

def flip(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
return type(self)(np.flip(self.raw, axis=axis))
return type(self)(jnp.flip(self.raw, axis=axis))

def meshgrid(
self: TensorType, *tensors: TensorType, indexing: str = "xy"
) -> Tuple[TensorType, ...]:
tensors = unwrap_(*tensors)
outputs = np.meshgrid(self.raw, *tensors, indexing=indexing)
outputs = jnp.meshgrid(self.raw, *tensors, indexing=indexing)
return tuple(type(self)(out) for out in outputs)

def pad(
Expand All @@ -381,16 +384,16 @@ def pad(
raise NotImplementedError # pragma: no cover
if mode == "constant":
return type(self)(
np.pad(self.raw, paddings, mode=mode, constant_values=value)
jnp.pad(self.raw, paddings, mode=mode, constant_values=value)
)
else:
return type(self)(np.pad(self.raw, paddings, mode=mode))
return type(self)(jnp.pad(self.raw, paddings, mode=mode))

def isnan(self: TensorType) -> TensorType:
return type(self)(np.isnan(self.raw))
return type(self)(jnp.isnan(self.raw))

def isinf(self: TensorType) -> TensorType:
return type(self)(np.isinf(self.raw))
return type(self)(jnp.isinf(self.raw))

def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
if self.ndim != 2:
Expand All @@ -402,15 +405,15 @@ def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
# otherwise exp(logits) might become too large or too small
logits = self.raw
logits = logits - logits.max(axis=1, keepdims=True)
e = np.exp(logits)
s = np.sum(e, axis=1)
ces = np.log(s) - np.take_along_axis(
logits, labels.raw[:, np.newaxis], axis=1
e = jnp.exp(logits)
s = jnp.sum(e, axis=1)
ces = jnp.log(s) - jnp.take_along_axis(
logits, labels.raw[:, jnp.newaxis], axis=1
).squeeze(axis=1)
return type(self)(ces)

def slogdet(self: TensorType) -> Tuple[TensorType, TensorType]:
sign, logabsdet = np.linalg.slogdet(self.raw)
sign, logabsdet = jnp.linalg.slogdet(self.raw)
return type(self)(sign), type(self)(logabsdet)

@overload
Expand Down Expand Up @@ -480,23 +483,23 @@ def value_and_grad( # type: ignore
return value_and_grad

def sign(self: TensorType) -> TensorType:
return type(self)(np.sign(self.raw))
return type(self)(jnp.sign(self.raw))

def sqrt(self: TensorType) -> TensorType:
return type(self)(np.sqrt(self.raw))
return type(self)(jnp.sqrt(self.raw))

def tanh(self: TensorType) -> TensorType:
return type(self)(np.tanh(self.raw))
return type(self)(jnp.tanh(self.raw))

def float32(self: TensorType) -> TensorType:
return self.astype(np.float32)
return self.astype(jnp.float32)

def float64(self: TensorType) -> TensorType:
return self.astype(np.float32)
return self.astype(jnp.float32)

def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType:
x, y = unwrap_(x, y)
return type(self)(np.where(self.raw, x, y))
return type(self)(jnp.where(self.raw, x, y))

def __lt__(self: TensorType, other: TensorOrScalar) -> TensorType:
return type(self)(self.raw.__lt__(unwrap1(other)))
Expand All @@ -505,10 +508,10 @@ def __le__(self: TensorType, other: TensorOrScalar) -> TensorType:
return type(self)(self.raw.__le__(unwrap1(other)))

def __eq__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore
return type(self)(self.raw.__eq__(unwrap1(other)))
return type(self)(cast("jnp.ndarray", self.raw.__eq__(unwrap1(other))))

def __ne__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore
return type(self)(self.raw.__ne__(unwrap1(other)))
return type(self)(cast("jnp.ndarray", self.raw.__ne__(unwrap1(other))))

def __gt__(self: TensorType, other: TensorOrScalar) -> TensorType:
return type(self)(self.raw.__gt__(unwrap1(other)))
Expand All @@ -528,7 +531,7 @@ def take_along_axis(self: TensorType, index: TensorType, axis: int) -> TensorTyp
raise NotImplementedError(
"take_along_axis is currently only supported for the last axis"
)
return type(self)(np.take_along_axis(self.raw, index.raw, axis=axis))
return type(self)(jnp.take_along_axis(self.raw, index.raw, axis=axis))

def bool(self: TensorType) -> TensorType:
return self.astype(np.bool_)
return self.astype(jnp.bool_)
4 changes: 2 additions & 2 deletions eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def argsort(self: TensorType, axis: int = -1) -> TensorType:
return type(self)(self.raw.argsort(dim=axis))

def sort(self: TensorType, axis: int = -1) -> TensorType:
return type(self)(self.raw.sort(dim=axis).values) # type: ignore
return type(self)(self.raw.sort(dim=axis).values)

def topk(
self: TensorType, k: int, sorted: bool = True
Expand Down Expand Up @@ -459,7 +459,7 @@ def isnan(self: TensorType) -> TensorType:
return type(self)(torch.isnan(self.raw))

def isinf(self: TensorType) -> TensorType:
return type(self)(torch.isinf(self.raw)) # type: ignore
return type(self)(torch.isinf(self.raw))

def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
if self.ndim != 2:
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ pytest-cov>=2.8.1
coverage>=5.0.3
codecov>=2.0.15
coveralls>=1.10.0
mypy>=0.761
mypy>=0.910
pre-commit>=1.21.0
pydoc-markdown==2.0.5

0 comments on commit 15f269c

Please sign in to comment.