PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code

EagerPy

https://jonasrauber.github.io/eagerpy/logo.png

What is EagerPy?

EagerPy is a Python framework that lets you write code that automatically works natively with PyTorch, TensorFlow, JAX, and NumPy.

EagerPy is also great when you work with just one framework but prefer a clean and consistent NumPy-inspired API that is fully chainable, provides extensive type annotations, and lets you write beautiful code. It often combines the best of PyTorch's API and NumPy's API.

Design goals

  • Native Performance: EagerPy operations get directly translated into the corresponding native operations.
  • Fully Chainable: All functionality is available as methods on the tensor objects and as EagerPy functions.
  • Type Checking: Catch bugs before running your code thanks to EagerPy's extensive type annotations.
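To make the first two goals concrete, here is a minimal toy sketch (an illustration of the design idea, not EagerPy's actual implementation): a thin wrapper class forwards chainable methods to a native backend, here plain NumPy, just as EagerPy dispatches to PyTorch, TensorFlow, JAX, or NumPy.

```python
import numpy as np

class Wrapped:
    """Toy tensor wrapper: each method translates directly into a native op."""

    def __init__(self, raw):
        self.raw = raw  # the native array stays untouched underneath

    def square(self):
        return Wrapped(self.raw ** 2)

    def sum(self, axis=None):
        return Wrapped(self.raw.sum(axis=axis))

    def sqrt(self):
        return Wrapped(np.sqrt(self.raw))

x = Wrapped(np.array([3.0, 4.0]))
norm = x.square().sum().sqrt()  # fully chainable, native performance
print(norm.raw)  # 5.0
```

Because every method simply calls the corresponding native operation and returns a new wrapper, chaining adds no overhead beyond the function call itself.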

Documentation

Learn more about EagerPy in the documentation.

Use cases

Foolbox Native, the latest version of Foolbox, a popular adversarial attacks library, has been rewritten from scratch using EagerPy instead of NumPy to achieve native performance on models developed in PyTorch, TensorFlow and JAX, all with one code base.

Installation

pip install eagerpy

Example

import torch
x = torch.tensor([1., 2., 3., 4., 5., 6.])

import tensorflow as tf
x = tf.constant([1., 2., 3., 4., 5., 6.])

import jax.numpy as np
x = np.array([1., 2., 3., 4., 5., 6.])

import numpy as np
x = np.array([1., 2., 3., 4., 5., 6.])

# --------------------------------------------------------------------------

# No matter which framework you use, you can use the same code
import eagerpy as ep

# Just wrap a native tensor using EagerPy
x = ep.astensor(x)

# All of EagerPy's functionality is available as methods ...
x = x.reshape((2, 3))
norms = x.flatten(start=1).square().sum(axis=-1).sqrt()

# ... and as functions (loss_fn and eps are user-defined: a function
# mapping a tensor to a scalar loss, and a step size)
_, grad = ep.value_and_grad(loss_fn, x)
ep.clip(x + eps * grad, 0, 1)

# You can even write functions that work transparently with
# PyTorch tensors, TensorFlow tensors, JAX arrays, NumPy arrays
# and EagerPy tensors
def squared_a_plus_b_times_c(a, b, c):
    (a, b, c), restore_type = ep.astensors_(a, b, c)
    # here, a, b, c are EagerPy tensors
    result = (a + b * c).square()
    return restore_type(result)

# You can call this function using any kind of tensors and the result
# will have the same type.
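For reference, the same expression computed with plain NumPy shows what squared_a_plus_b_times_c returns when called with NumPy inputs (for PyTorch inputs, the result would instead be a PyTorch tensor of the same values):

```python
import numpy as np

a = np.array([1.0, 2.0])
b = np.array([3.0, 4.0])
c = np.array([5.0, 6.0])

# (a + b * c)^2, element-wise
result = (a + b * c) ** 2
print(result)  # [256. 676.]
```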

Compatibility

We currently test with the following versions:

  • PyTorch 1.4.0
  • TensorFlow 2.1.0
  • JAX 0.1.57
  • NumPy 1.18.1
