Merge pull request google#85 from mblondel:version_0.1.1
PiperOrigin-RevId: 404245958
JAXopt authors committed Oct 19, 2021
2 parents 9a318e6 + f629fe0 commit 9692f90
Showing 4 changed files with 39 additions and 23 deletions.
29 changes: 21 additions & 8 deletions docs/changelog.rst
@@ -1,14 +1,27 @@
Changelog
=========

Version 0.1.1 (development version)
-----------------------------------

- :class:`jaxopt.ArmijoSGD`
- :ref:`sphx_glr_auto_examples_fixed_point_deep_equilibrium_model.py`
- :ref:`sphx_glr_auto_examples_deep_learning_plot_sgd_solvers.py`


Version 0.1.1
-------------

New features
~~~~~~~~~~~~

- Added solver :class:`jaxopt.ArmijoSGD` (see the usage sketch after this list)
- Added example :ref:`sphx_glr_auto_examples_fixed_point_deep_equilibrium_model.py`
- Added example :ref:`sphx_glr_auto_examples_deep_learning_plot_sgd_solvers.py`
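
A minimal usage sketch for the new solver (not part of this commit), assuming the
standard jaxopt iterative-solver API; the least-squares objective, data shapes and
``maxiter`` value are hypothetical::

  import jax.numpy as jnp
  import jaxopt

  # Hypothetical least-squares objective; ``data`` is an (X, y) minibatch.
  def loss(params, data):
      X, y = data
      return jnp.mean((X @ params - y) ** 2)

  X = jnp.ones((32, 5))
  y = jnp.zeros(32)
  init = jnp.zeros(5)

  # ArmijoSGD selects its stepsize with an Armijo line search at each update.
  solver = jaxopt.ArmijoSGD(fun=loss, maxiter=100)
  params, state = solver.run(init, data=(X, y))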

Bug fixes
~~~~~~~~~

- Allow non-jittable proximity operators in :class:`jaxopt.ProximalGradient` (see the sketch after this list)
- Raise an exception if a quadratic program is infeasible or unbounded
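
A minimal sketch of the :class:`jaxopt.ProximalGradient` API referenced above (not
part of this commit). It uses the built-in, jittable ``prox_lasso`` operator; the
objective, data shapes and regularization strength are hypothetical, and a prox
written with plain Python/NumPy operations would be passed the same way::

  import jax.numpy as jnp
  import jaxopt
  from jaxopt.prox import prox_lasso

  # Hypothetical smooth part of a lasso-style objective; ``data`` is (X, y).
  def least_squares(w, data):
      X, y = data
      return 0.5 * jnp.mean((X @ w - y) ** 2)

  X = jnp.ones((16, 4))
  y = jnp.zeros(16)
  w0 = jnp.zeros(4)

  # The prox and its hyperparameter (here, the L1 strength) are passed separately.
  pg = jaxopt.ProximalGradient(fun=least_squares, prox=prox_lasso)
  w, state = pg.run(w0, hyperparams_prox=0.1, data=(X, y))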

Contributors
~~~~~~~~~~~~

Fabian Pedregosa, Louis Bethune, Mathieu Blondel.

Version 0.1 (initial release)
-----------------------------

14 changes: 8 additions & 6 deletions examples/deep_learning/plot_sgd_solvers.py
@@ -13,8 +13,8 @@
# limitations under the License.

r"""
Comparison of different GD algorithms.
======================================
Comparison of different SGD algorithms.
=======================================
The purpose of this example is to illustrate the power
of adaptive stepsize algorithms.
@@ -29,9 +29,11 @@
* SGD with constant stepsize
* RMSprop
The reported ``training loss`` is an estimation of the true training loss based on the current minibatch.
This experiment was conducted without momentum, with popular default values for learning rate.
The reported ``training loss`` is an estimation of the true training loss based
on the current minibatch.
This experiment was conducted without momentum, with popular default values for
learning rate.
"""

from absl import flags
@@ -112,7 +114,7 @@ def main(argv):
# manual flags parsing to avoid conflicts between absl.app.run and sphinx-gallery
flags.FLAGS(argv)
FLAGS = flags.FLAGS

train_ds, ds_info = load_dataset(FLAGS.dataset, FLAGS.batch_size)

# Initialize parameters.
17 changes: 9 additions & 8 deletions examples/fixed_point/deep_equilibrium_model.py
@@ -15,12 +15,13 @@
"""
Deep Equilibrium (DEQ) model in Flax with Anderson acceleration.
================================================================
This implementation is strongly inspired by the Pytorch code snippets in [3].
A similar model called "implicit deep learning" is also proposed in [2].
In practice BatchNormalization and initialization of weights in convolutions are
important to ensure convergence.
[1] Bai, S., Kolter, J.Z. and Koltun, V., 2019. Deep Equilibrium Models.
Advances in Neural Information Processing Systems, 32, pp.690-701.
@@ -136,7 +137,7 @@ def block_apply(z, x, block_params):
solver = self.fixed_point_solver(fixed_point_fun=block_apply)
def batch_run(x, block_params):
return solver.run(x, x, block_params)[0]

return jax.vmap(batch_run, in_axes=(0,None), out_axes=0)(x, block_params)


@@ -147,7 +148,7 @@ class FullDEQ(nn.Module):
fixed_point_solver: Callable

@nn.compact
def __call__(self, x, train):
x = nn.Conv(features=self.channels, kernel_size=(3,3), use_bias=True, padding='SAME')(x)
x = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5)(x)
block = ResNetBlock(self.channels, self.channels_bottleneck)
@@ -261,7 +262,7 @@ def jitted_update(params, state, batch_stats, data):
batch_stats = state.aux['batch_stats']
print_accuracy(params, state)
params, state = jitted_update(params, state, batch_stats, next(train_ds))


if __name__ == "__main__":
app.run(main)
2 changes: 1 addition & 1 deletion jaxopt/version.py
@@ -14,4 +14,4 @@

"""JAXopt version."""

__version__ = "0.1"
__version__ = "0.1.1"
