Skip to content

Commit

Permalink
where() infer type of the other condition when one of the two is a te… (
Browse files Browse the repository at this point in the history
jonasrauber#39)

* where() infer type of the other condition when one of the two is a tensor.

* correct flake8
  • Loading branch information
eserie authored Jun 4, 2021
1 parent 6fa4c79 commit 5f7a0d9
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 2 deletions.
3 changes: 3 additions & 0 deletions eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,9 @@ def tanh(self: TensorType) -> TensorType:
def float32(self: TensorType) -> TensorType:
return self.astype(np.float32)

def float64(self: TensorType) -> TensorType:
return self.astype(np.float32)

def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType:
x, y = unwrap_(x, y)
return type(self)(np.where(self.raw, x, y))
Expand Down
3 changes: 3 additions & 0 deletions eagerpy/tensor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,9 @@ def tanh(self: TensorType) -> TensorType:
def float32(self: TensorType) -> TensorType:
return self.astype(np.float32)

def float64(self: TensorType) -> TensorType:
return self.astype(np.float64)

def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType:
x, y = unwrap_(x, y)
return type(self)(np.where(self.raw, x, y))
Expand Down
16 changes: 14 additions & 2 deletions eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,19 +535,31 @@ def sqrt(self: TensorType) -> TensorType:
def float32(self: TensorType) -> TensorType:
return self.astype(torch.float32)

def float64(self: TensorType) -> TensorType:
return self.astype(torch.float64)

def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType:

if isinstance(x, Tensor):
x_ = x.raw
elif isinstance(x, int) or isinstance(x, float):
x_ = torch.full_like(self.raw, x, dtype=torch.float32)
if isinstance(y, Tensor):
dtype = y.raw.dtype
else:
dtype = torch.float32
x_ = torch.full_like(self.raw, x, dtype=dtype)
else:
raise TypeError(
"expected x to be a Tensor, int or float"
) # pragma: no cover
if isinstance(y, Tensor):
y_ = y.raw
elif isinstance(y, int) or isinstance(y, float):
y_ = torch.full_like(self.raw, y, dtype=torch.float32)
if isinstance(x, Tensor):
dtype = x.raw.dtype
else:
dtype = torch.float32
y_ = torch.full_like(self.raw, y, dtype=dtype)
return type(self)(torch.where(self.raw, x_, y_))

def __lt__(self: TensorType, other: TensorOrScalar) -> TensorType:
Expand Down
4 changes: 4 additions & 0 deletions eagerpy/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def tanh(self: TensorType) -> TensorType:
def float32(self: TensorType) -> TensorType:
...

@abstractmethod
def float64(self: TensorType) -> TensorType:
...

@abstractmethod
def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType:
...
Expand Down
3 changes: 3 additions & 0 deletions eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,9 @@ def tanh(self: TensorType) -> TensorType:
def float32(self: TensorType) -> TensorType:
return self.astype(tf.float32)

def float64(self: TensorType) -> TensorType:
return self.astype(tf.float64)

def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType:
x, y = unwrap_(x, y)
return type(self)(tf.where(self.raw, x, y))
Expand Down
12 changes: 12 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,18 @@ def test_where_second_scalar(t: Tensor) -> Tensor:
return ep.where(t >= 3, t, 2)


@compare_all
def test_where_first_scalar64(dummy: Tensor) -> Tensor:
t = ep.arange(dummy, 60).float64().reshape((3, 4, 5))
return ep.where(t >= 3, 2, -t)


@compare_all
def test_where_second_scalar64(dummy: Tensor) -> Tensor:
t = ep.arange(dummy, 60).float64().reshape((3, 4, 5))
return ep.where(t >= 3, t, 2)


@compare_all
def test_where_both_scalars(t: Tensor) -> Tensor:
return ep.where(t >= 3, 2, 5)
Expand Down

0 comments on commit 5f7a0d9

Please sign in to comment.