Optax is a gradient processing and optimization library for JAX.
Optax is designed to facilitate research by providing building blocks
that can be easily recombined in custom ways.
Our goals are to:
Provide simple, well-tested, efficient implementations of core components.
Improve research productivity by letting researchers easily combine low-level
ingredients into custom optimizers (or other gradient processing components).
Accelerate adoption of new ideas by making it easy for anyone to contribute.
We favor small, composable building blocks that can be effectively
combined into custom solutions. Others may build upon these basic components
to create more complex abstractions. Whenever reasonable, implementations prioritize
readability, and structure code to match standard equations rather than to maximize code reuse.
An initial prototype of this library was made available in JAX's experimental
folder as jax.experimental.optix. Given the wide adoption of optix across
DeepMind, and after a few iterations on the API, optix was eventually moved
out of experimental into a standalone open-source library and renamed optax.
Optax contains implementations of many popular optimizers and
loss functions.
For example, the following code snippet uses the Adam optimizer from optax.adam
and the squared error loss from optax.l2_loss. We initialize the optimizer
state using the init function and the model's params.
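import jax.numpy as jnp
import optax
learning_rate, num_weights = 1e-2, 3  # illustrative values (assumed)
optimizer = optax.adam(learning_rate)
# Obtain the `opt_state` that contains statistics for the optimizer.
params = {'w': jnp.ones((num_weights,))}
opt_state = optimizer.init(params)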
To write the update loop, we need a loss function that can be differentiated by
JAX (with jax.grad in this example) to obtain the gradients.
The gradients are then transformed by optimizer.update into the updates
that should be applied to the current parameters to obtain the new ones;
optax.apply_updates is a convenience utility that does this.
We welcome issue reports and pull requests that fix issues or improve
existing functionality. If you are interested in adding a feature such as a new
optimizer, please open an issue first! We are focused on making optax more
flexible, versatile and easy to use when you define your own optimizers.
Source code
You can check the latest sources with the following command.
optimistix: nonlinear solvers: root finding, minimisation, fixed points, and least squares.
matfree: matrix-free methods useful to study curvature dynamics in deep learning.
Citing Optax
This repository is part of the DeepMind JAX Ecosystem; to cite Optax,
please use the following citation:
@software{deepmind2020jax,
title = {The {D}eep{M}ind {JAX} {E}cosystem},
author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
url = {https://github.com/google-deepmind},
year = {2020},
}