Skip to content

Commit

Permalink
Documentation improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
mblondel committed Jan 19, 2022
1 parent 4ada57a commit 61de873
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 12 deletions.
17 changes: 17 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Unconstrained
:toctree: _autosummary

jaxopt.GradientDescent
jaxopt.LBFGS
jaxopt.ScipyMinimize

Constrained
Expand Down Expand Up @@ -78,6 +79,14 @@ Linear system solving
jaxopt.linear_solve.solve_bicgstab
jaxopt.IterativeRefinement

Nonlinear least squares
-----------------------

.. autosummary::
:toctree: _autosummary

jaxopt.GaussNewton

Root finding
------------

Expand Down Expand Up @@ -107,3 +116,11 @@ Implicit differentiation
jaxopt.implicit_diff.custom_fixed_point
jaxopt.implicit_diff.root_jvp
jaxopt.implicit_diff.root_vjp

Line search
-----------

.. autosummary::
:toctree: _autosummary

jaxopt.BacktrackingLineSearch
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ New features

- :class:`jaxopt.LBFGS`.
- :class:`jaxopt.BacktrackingLineSearch`.
- :class:`jaxopt.GaussNewton`.

Bug fixes and enhancements
~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
19 changes: 7 additions & 12 deletions jaxopt/_src/gauss_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""GaussNewton algorithm in JAX."""
"""Gauss-Newton algorithm in JAX."""

from typing import Any
from typing import Callable
Expand All @@ -38,22 +38,18 @@ class GaussNewtonState(NamedTuple):
gradient: Any
aux: Optional[Any] = None


@dataclass(eq=False)
class GaussNewton(base.IterativeSolver):
"""Gauss-Newton nonlinear least-squares solver.
This solver finds the optimal parameters via a least-squares optimization
Given the residual function f(x): R^m -> R^n, `gauss_newton` finds a
local minimum of the cost function F(x):
```
argmin_x F(x) = 0.5 * sum(f_i(x)**2), i = 0, ..., m - 1
f(x) = func(x, *args, **kwargs)
```
Given the residual function ``f(x): R^m -> R^n``, where ``f(x) =
residual_fun(x, *args, **kwargs)``, ``GaussNewton`` finds a local minimum of
the cost function ``argmin_x 0.5 * sum(f(x) ** 2)``.
Attributes:
residual_fun: a smooth function of the form ``residual_fun(x, *args, **kwargs)``.
residual_fun: a smooth function of the form
``residual_fun(x, *args, **kwargs)``.
maxiter: maximum number of iterations.
tol: tolerance.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
Expand Down Expand Up @@ -130,7 +126,6 @@ def update(self,
return base.OptStep(params=params, state=state)

def __post_init__(self):

if self.has_aux:
self._fun = lambda *a, **kw: self.residual_fun(*a, **kw)[0]
self._fun_with_aux = self.fun
Expand Down
26 changes: 26 additions & 0 deletions jaxopt/_src/scipy_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,21 @@ class ScipyMinimize(ScipyWrapper):
Attributes:
fun: a smooth function of the form `fun(x, *args, **kwargs)`.
method: the `method` argument for `scipy.optimize.minimize`.
Should be one of
* 'Nelder-Mead'
* 'Powell'
* 'CG'
* 'BFGS'
* 'Newton-CG'
* 'L-BFGS-B'
* 'TNC'
* 'COBYLA'
* 'SLSQP'
* 'trust-constr'
* 'dogleg'
* 'trust-ncg'
* 'trust-exact'
* 'trust-krylov'
tol: the `tol` argument for `scipy.optimize.minimize`.
options: the `options` argument for `scipy.optimize.minimize`.
dtype: if not None, cast all NumPy arrays to this dtype. Note that some
Expand Down Expand Up @@ -374,6 +389,17 @@ class ScipyRootFinding(ScipyWrapper):
`optimality_fun(x, *args, **kwargs)` whose root is to be found. It must
return as output a PyTree with structure identical to x.
method: the `method` argument for `scipy.optimize.root`.
Should be one of
* 'hybr'
* 'lm'
* 'broyden1'
* 'broyden2'
* 'anderson'
* 'linearmixing'
* 'diagbroyden'
* 'excitingmixing'
* 'krylov'
* 'df-sane'
tol: the `tol` argument for `scipy.optimize.root`.
options: the `options` argument for `scipy.optimize.root`.
dtype: if not None, cast all NumPy arrays to this dtype. Note that some
Expand Down

0 comments on commit 61de873

Please sign in to comment.