-
-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for topk #24
Conversation
Codecov Report
@@ Coverage Diff @@
## master #24 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 16 16
Lines 1714 1739 +25
=========================================
+ Hits 1714 1739 +25
Continue to review full report at Codecov.
|
Thanks, this is great! I will add some minor change requests, looking forward to merging it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you make the stylistic changes I commented on (and add the better test)? Besides that, it looks good to me 👍
eagerpy/tensor/tensorflow.py
Outdated
def topk( | ||
self: TensorType, k: int, sorted: bool = True | ||
) -> Tuple[TensorType, TensorType]: | ||
pair = tf.math.top_k(self.raw, k, sorted=sorted) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use tuple unpacking instead: values, indices = tf.math…
.
It's more descriptive and avoids tuple indexing in the next line.
eagerpy/tensor/pytorch.py
Outdated
def topk( | ||
self: TensorType, k: int, sorted: bool = True | ||
) -> Tuple[TensorType, TensorType]: | ||
pair = self.raw.topk(k, sorted=sorted) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use tuple unpacking instead: values, indices = self.raw…
.
It's more descriptive and avoids tuple indexing in the next line.
tests/test_main.py
Outdated
@@ -1206,6 +1206,18 @@ def test_sort(dummy: Tensor) -> Tensor: | |||
return ep.sort(t) | |||
|
|||
|
|||
@compare_all | |||
def test_topk_0(dummy: Tensor) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you rename the tests to end with _values
and _indices
instead of 0
and 1
.
tests/test_main.py
Outdated
@@ -1206,6 +1206,18 @@ def test_sort(dummy: Tensor) -> Tensor: | |||
return ep.sort(t) | |||
|
|||
|
|||
@compare_all | |||
def test_topk_0(dummy: Tensor) -> Tensor: | |||
t = ep.arange(dummy, 27).float32().reshape((3, 3, 3)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be good to test on something that's not already sorted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had copied from the tests for sort
and argsort
where one is already sorted and the other reverse sorted, so I tested values on sorted and indices on reverse sorted. But I've changed it to something more complicated.
tests/test_main.py
Outdated
@compare_all | ||
def test_topk_0(dummy: Tensor) -> Tensor: | ||
t = ep.arange(dummy, 27).float32().reshape((3, 3, 3)) | ||
return ep.topk(t, 2)[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe you can also use tuple unpacking here:
values, _ = ep.topk(t, 2)
and return values
(Same for the other test)
Hello, are there still changes needed? |
@mglisse Thanks, great work. Merged. |
Fix #23.
This adds a minimal version of topk. In the future, it would be good to add some version of the
dim
andlargest
arguments fromtorch.topk
(I use them), but I wanted to start with the basic functionality. The numpy version is a bit more complicated than one might hope, and jax complicates it further by not supportingargpartition
and silently behaving differently intake
.I wasn't sure about the documentation, I searched for "sort" in docs/ and only found argsort in api/tensor.md, so I added topk there...