
Commit 127c063
improved the example
Jonas Rauber committed Feb 14, 2020
1 parent 0e3bcac commit 127c063
Showing 1 changed file with 16 additions and 18 deletions.
README.rst (34 changes: 16 additions & 18 deletions)
@@ -7,10 +7,9 @@
 .. image:: https://img.shields.io/badge/code%20style-black-000000.svg
    :target: https://github.com/ambv/black
 
-
 .. image:: https://raw.githubusercontent.com/jonasrauber/eagerpy/master/docs/.vuepress/public/logo_small.png
    :target: https://jonasrauber.github.io/eagerpy/
-
+   :align: right
 
 =======
 EagerPy
@@ -57,34 +56,33 @@ Learn more about in the `documentation <https://jonasrauber.github.io/eagerpy/>`
 import numpy as np
 x = np.array([1., 2., 3., 4., 5., 6.])
 
 # --------------------------------------------------------------------------
 # No matter which framework you use, you can use the same code
 import eagerpy as ep
 
 # Just wrap a native tensor using EagerPy
 x = ep.astensor(x)
 
-# All of EagerPy's functionality is available as methods ...
+# All of EagerPy's functionality is available as methods
 x = x.reshape((2, 3))
-norms = x.flatten(start=1).square().sum(axis=-1).sqrt()
+x.flatten(start=1).square().sum(axis=-1).sqrt()
+# or just: x.flatten(1).norms.l2()
 
-# ... and functions
-_, grad = ep.value_and_grad(loss_fn, x)
+# and as functions (yes, gradients are also supported!)
+loss, grad = ep.value_and_grad(loss_fn, x)
 ep.clip(x + eps * grad, 0, 1)
 
 # You can even write functions that work transparently with
 # Pytorch tensors, TensorFlow tensors, JAX arrays, NumPy arrays
-# and EagerPy tensors
-def squared_a_plus_b_times_c(a, b, c):
-    (a, b, c), restore_type = ep.astensors_(a, b, c)
-    # here, a, b, c are EagerPyTensors
-    result = (a + b * c).square()
-    return restore_type(result)
-# You can call this function using any kind of tensors and the result
-# will have the same type.
+def my_universal_function(a, b, c):
+    # Convert all inputs to EagerPy tensors
+    a, b, c = ep.astensors(a, b, c)
+    # performs some computations
+    result = (a + b * c).square()
+    # and return a native tensor
+    return result.raw
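
For context, a minimal standalone sketch (not part of this commit) of what the new method-chaining lines compute, reusing the NumPy setup from the example; passing axis=-1 to norms.l2 is an assumption on my part:

import numpy as np
import eagerpy as ep

x = ep.astensor(np.array([1., 2., 3., 4., 5., 6.]))
x = x.reshape((2, 3))

# explicit chain: per-row L2 norms of the flattened tensor
norms = x.flatten(start=1).square().sum(axis=-1).sqrt()
print(norms.raw)  # approx. [3.742, 8.775]

# the shorthand referenced in the new comment (axis assumed to be -1)
norms2 = x.flatten(start=1).norms.l2(axis=-1)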
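The gradient lines rely on a loss_fn defined elsewhere in the README. A self-contained sketch, assuming a PyTorch tensor as input (plain NumPy arrays cannot provide gradients) and a toy sum-of-squares loss:

import torch
import eagerpy as ep

def loss_fn(x):
    # toy scalar loss; any function from an EagerPy tensor to a scalar works
    return x.square().sum()

x = ep.astensor(torch.tensor([1., 2., 3.]))
loss, grad = ep.value_and_grad(loss_fn, x)
print(loss.raw)  # tensor(14.)
print(grad.raw)  # tensor([2., 4., 6.])  # the gradient of sum(x^2) is 2*x

# the clipping step from the example, with a concrete eps
eps = 0.1
x_new = ep.clip(x + eps * grad, 0, 1)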
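Finally, a sketch of how the rewritten my_universal_function behaves when called with native tensors; the concrete inputs are illustrative only:

import numpy as np
import eagerpy as ep

def my_universal_function(a, b, c):
    a, b, c = ep.astensors(a, b, c)  # wrap the native tensors
    result = (a + b * c).square()    # framework-agnostic computation
    return result.raw                # unwrap to the caller's native type

a = np.array([1., 2.])
b = np.array([3., 4.])
c = np.array([2., 0.5])
out = my_universal_function(a, b, c)
print(type(out), out)  # <class 'numpy.ndarray'> [49. 16.]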
🗺 Use cases
-----------
