Add support for topk #24

Merged
merged 6 commits into from
Dec 16, 2020

mglisse
Contributor

@mglisse mglisse commented Nov 4, 2020

Fix #23.
This adds a minimal version of topk. In the future, it would be good to add some version of the dim and largest arguments from torch.topk (I use them), but I wanted to start with the basic functionality. The numpy version is a bit more complicated than one might hope, and jax complicates it further by not supporting argpartition and silently behaving differently in take.
I wasn't sure where to document this: I searched for "sort" in docs/ and only found argsort in api/tensor.md, so I added topk there.
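
To illustrate why the NumPy version takes more than one call, here is a minimal sketch of a top-k along the last axis built on `np.argpartition`. It is not the code merged in this PR: the function name `topk_last_axis` and the argument check are assumptions made for illustration. A backend without `argpartition`, as described for JAX above, could presumably fall back to a full `argsort` at higher cost.

```python
import numpy as np


def topk_last_axis(x, k, sorted=True):
    """Return the k largest values along the last axis and their indices.

    Hypothetical helper for illustration; not eagerpy's implementation.
    """
    x = np.asarray(x)
    n = x.shape[-1]
    if not 0 < k <= n:
        raise ValueError("k must satisfy 0 < k <= size of the last axis")

    # argpartition moves the indices of the k largest entries into the
    # last k slots, in no particular order.
    idx = np.argpartition(x, n - k, axis=-1)[..., n - k:]
    vals = np.take_along_axis(x, idx, axis=-1)

    if sorted:
        # Order the k candidates from largest to smallest.
        order = np.argsort(vals, axis=-1)[..., ::-1]
        idx = np.take_along_axis(idx, order, axis=-1)
        vals = np.take_along_axis(vals, order, axis=-1)
    return vals, idx


x = np.array([[3.0, 1.0, 4.0, 1.5], [9.0, 2.0, 6.0, 5.0]])
values, indices = topk_last_axis(x, 2)
```

Only the k selected candidates are sorted, which is the usual reason to prefer `argpartition` over sorting the whole axis.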

@codecov

codecov bot commented Nov 4, 2020

Codecov Report

Merging #24 into master will not change coverage.
The diff coverage is 100.00%.

@@            Coverage Diff            @@
##            master       #24   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files           16        16           
  Lines         1714      1739   +25     
=========================================
+ Hits          1714      1739   +25     
Impacted Files                 Coverage Δ
eagerpy/tensor/tensor.py       100.00% <ø> (ø)
eagerpy/framework.py           100.00% <100.00%> (ø)
eagerpy/tensor/jax.py          100.00% <100.00%> (ø)
eagerpy/tensor/numpy.py        100.00% <100.00%> (ø)
eagerpy/tensor/pytorch.py      100.00% <100.00%> (ø)
eagerpy/tensor/tensorflow.py   100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 5eddc85...d781fdb.

@jonasrauber
Owner

Thanks, this is great! I will add some minor change requests, looking forward to merging it.

Owner

@jonasrauber jonasrauber left a comment

Could you make the stylistic changes I commented on (and add the better test)? Besides that, it looks good to me 👍

    def topk(
        self: TensorType, k: int, sorted: bool = True
    ) -> Tuple[TensorType, TensorType]:
        pair = tf.math.top_k(self.raw, k, sorted=sorted)
Owner

You can use tuple unpacking instead: values, indices = tf.math….
It's more descriptive and avoids tuple indexing in the next line.

    def topk(
        self: TensorType, k: int, sorted: bool = True
    ) -> Tuple[TensorType, TensorType]:
        pair = self.raw.topk(k, sorted=sorted)
Owner

You can use tuple unpacking instead: values, indices = self.raw….
It's more descriptive and avoids tuple indexing in the next line.
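
The tuple-unpacking suggestion made for both backends can be sketched without any framework. In this sketch `raw_topk` is a hypothetical pure-Python stand-in for a backend call that returns a (values, indices) pair; it is not part of eagerpy, TensorFlow, or PyTorch.

```python
from typing import List, NamedTuple, Tuple


class TopK(NamedTuple):
    values: List[float]
    indices: List[int]


def raw_topk(data: List[float], k: int) -> TopK:
    # Hypothetical stand-in for a backend top-k call.
    order = sorted(range(len(data)), key=lambda i: data[i], reverse=True)[:k]
    return TopK([data[i] for i in order], order)


def topk_indexed(data: List[float], k: int) -> Tuple[List[float], List[int]]:
    # Style under review: the reader has to remember what [0] and [1] mean.
    pair = raw_topk(data, k)
    return pair[0], pair[1]


def topk_unpacked(data: List[float], k: int) -> Tuple[List[float], List[int]]:
    # Suggested style: the names document the contents of the pair.
    values, indices = raw_topk(data, k)
    return values, indices
```

Both functions return the same result; the difference is only in how readable the intermediate step is.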

@@ -1206,6 +1206,18 @@ def test_sort(dummy: Tensor) -> Tensor:
    return ep.sort(t)


@compare_all
def test_topk_0(dummy: Tensor) -> Tensor:
Owner

Could you rename the tests to end with _values and _indices instead of _0 and _1?

@@ -1206,6 +1206,18 @@ def test_sort(dummy: Tensor) -> Tensor:
    return ep.sort(t)


@compare_all
def test_topk_0(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 27).float32().reshape((3, 3, 3))
Owner

I think it would be good to test on something that's not already sorted.

Contributor Author

I had copied from the tests for sort and argsort where one is already sorted and the other reverse sorted, so I tested values on sorted and indices on reverse sorted. But I've changed it to something more complicated.
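
A small framework-free sketch of why sorted input makes a weak test: a deliberately wrong top-k that assumes ascending input agrees with a reference on sorted data and only disagrees once the input is shuffled. The names `topk_reference` and `topk_buggy` are hypothetical and are not taken from the eagerpy test suite.

```python
def topk_reference(row, k):
    # Straightforward reference: rank all positions by value, descending.
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
    return [row[i] for i in order], order


def topk_buggy(row, k):
    # Wrong on purpose: assumes the row is already in ascending order,
    # so it simply returns the last k positions in reverse.
    n = len(row)
    order = list(range(n - 1, n - 1 - k, -1))
    return [row[i] for i in order], order


ascending = [0.0, 1.0, 2.0]
shuffled = [1.0, 2.0, 0.0]

# The bug is invisible on sorted input and exposed on shuffled input.
agrees_on_sorted = topk_buggy(ascending, 2) == topk_reference(ascending, 2)
agrees_on_shuffled = topk_buggy(shuffled, 2) == topk_reference(shuffled, 2)
```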

@compare_all
def test_topk_0(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 27).float32().reshape((3, 3, 3))
    return ep.topk(t, 2)[0]
Owner

Maybe you can also use tuple unpacking here:
values, _ = ep.topk(t, 2) and return values
(Same for the other test)

@mglisse
Contributor Author

mglisse commented Dec 16, 2020

Hello, are there still changes needed?

@jonasrauber jonasrauber merged commit 9568566 into jonasrauber:master Dec 16, 2020
@jonasrauber
Owner

@mglisse Thanks, great work. Merged.


Successfully merging this pull request may close these issues.

topk
2 participants