Skip to content

Commit

Permalink
Add support for topk (jonasrauber#24)
Browse files Browse the repository at this point in the history
* Add support for topk

* Fix return type

* reformat with black

* style tweaks: tuple unpacking

* Testcase that is not already sorted

* Use integers for more determinism
  • Loading branch information
mglisse authored Dec 16, 2020
1 parent a63c881 commit 9568566
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ Tensor.argmax(self:~TensorType, axis:Union[int, NoneType]=None) -> ~TensorType
Tensor.argsort(self:~TensorType, axis:int=-1) -> ~TensorType
```

## topk
```python
Tensor.topk(self:~TensorType, k:int, sorted:bool=True) -> Tuple[~TensorType, ~TensorType]
```

## uniform
```python
Tensor.uniform(self:~TensorType, shape:Union[Tuple[int, ...], int], low:float=0.0, high:float=1.0) -> ~TensorType
Expand Down
4 changes: 4 additions & 0 deletions eagerpy/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def sort(t: TensorType, axis: int = -1) -> TensorType:
return t.sort(axis=axis)


def topk(t: TensorType, k: int, sorted: bool = True) -> Tuple[TensorType, TensorType]:
return t.topk(k, sorted=sorted)


def uniform(
t: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0
) -> TensorType:
Expand Down
14 changes: 14 additions & 0 deletions eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,20 @@ def argsort(self: TensorType, axis: int = -1) -> TensorType:
def sort(self: TensorType, axis: int = -1) -> TensorType:
return type(self)(self.raw.sort(axis=axis))

def topk(
self: TensorType, k: int, sorted: bool = True
) -> Tuple[TensorType, TensorType]:
# argpartition not yet implemented
# wrapping indexing not supported in take()
n = self.raw.shape[-1]
idx = np.take(np.argsort(self.raw), np.arange(n - k, n), axis=-1)
val = np.take_along_axis(self.raw, idx, axis=-1)
if sorted:
perm = np.flip(np.argsort(val, axis=-1), axis=-1)
idx = np.take_along_axis(idx, perm, axis=-1)
val = np.take_along_axis(self.raw, idx, axis=-1)
return type(self)(val), type(self)(idx)

def uniform(
self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0
) -> TensorType:
Expand Down
11 changes: 11 additions & 0 deletions eagerpy/tensor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ def argsort(self: TensorType, axis: int = -1) -> TensorType:
def sort(self: TensorType, axis: int = -1) -> TensorType:
return type(self)(np.sort(self.raw, axis=axis))

def topk(
self: TensorType, k: int, sorted: bool = True
) -> Tuple[TensorType, TensorType]:
idx = np.take(np.argpartition(self.raw, k - 1), np.arange(-k, 0), axis=-1)
val = np.take_along_axis(self.raw, idx, axis=-1)
if sorted:
perm = np.flip(np.argsort(val, axis=-1), axis=-1)
idx = np.take_along_axis(idx, perm, axis=-1)
val = np.take_along_axis(self.raw, idx, axis=-1)
return type(self)(val), type(self)(idx)

def uniform(
self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0
) -> TensorType:
Expand Down
6 changes: 6 additions & 0 deletions eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ def argsort(self: TensorType, axis: int = -1) -> TensorType:
def sort(self: TensorType, axis: int = -1) -> TensorType:
return type(self)(self.raw.sort(dim=axis).values) # type: ignore

def topk(
self: TensorType, k: int, sorted: bool = True
) -> Tuple[TensorType, TensorType]:
values, indices = self.raw.topk(k, sorted=sorted)
return type(self)(values), type(self)(indices)

def uniform(
self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0
) -> TensorType:
Expand Down
6 changes: 6 additions & 0 deletions eagerpy/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,12 @@ def argsort(self: TensorType, axis: int = -1) -> TensorType:
def sort(self: TensorType, axis: int = -1) -> TensorType:
...

@abstractmethod
def topk(
self: TensorType, k: int, sorted: bool = True
) -> Tuple[TensorType, TensorType]:
...

@abstractmethod
def uniform(
self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0
Expand Down
6 changes: 6 additions & 0 deletions eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def argsort(self: TensorType, axis: Optional[int] = -1) -> TensorType:
def sort(self: TensorType, axis: Optional[int] = -1) -> TensorType:
return type(self)(tf.sort(self.raw, axis=axis))

def topk(
self: TensorType, k: int, sorted: bool = True
) -> Tuple[TensorType, TensorType]:
values, indices = tf.math.top_k(self.raw, k, sorted=sorted)
return type(self)(values), type(self)(indices)

@samedevice
def uniform(
self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0
Expand Down
14 changes: 14 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,20 @@ def test_sort(dummy: Tensor) -> Tensor:
return ep.sort(t)


@compare_all
def test_topk_values(dummy: Tensor) -> Tensor:
t = (ep.arange(dummy, 27).reshape((3, 3, 3)) ** 2 * 10000 % 1234).float32()
values, _ = ep.topk(t, 2)
return values


@compare_all
def test_topk_indices(dummy: Tensor) -> Tensor:
t = -(ep.arange(dummy, 27).reshape((3, 3, 3)) ** 2 * 10000 % 1234).float32()
_, indices = ep.topk(t, 2)
return indices


@compare_all
def test_transpose(dummy: Tensor) -> Tensor:
t = ep.arange(dummy, 8).float32().reshape((2, 4))
Expand Down

0 comments on commit 9568566

Please sign in to comment.