- lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit solvers);
- vmappable everything (including the region of integration; see the sketch after this list);
- using a PyTree as the state;
- dense solutions;
- multiple adjoint methods for backpropagation;
- support for neural differential equations.
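As a rough illustration of those features, here is a minimal sketch of solving a simple ODE, assuming the standard Diffrax entry points (`diffeqsolve`, `ODETerm`, `Tsit5`); it is meant as an orientation rather than a complete reference:

```python
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5

# dy/dt = -y, integrated with an explicit Runge-Kutta solver (Tsit5).
def vector_field(t, y, args):
    return -y

term = ODETerm(vector_field)
solver = Tsit5()

def solve(t1, y0):
    sol = diffeqsolve(term, solver, t0=0.0, t1=t1, dt0=0.1, y0=y0)
    return sol.ys[-1]  # value of the solution at t1

y0 = jnp.array([1.0])

# vmappable everything -- including the region of integration (a batch of end times).
batched = jax.vmap(solve, in_axes=(0, None))(jnp.linspace(1.0, 2.0, 8), y0)

# Backpropagation through the solve, using the default adjoint method.
grads = jax.grad(lambda y0: solve(1.0, y0).sum())(y0)
```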
From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), which results in a small, tightly-written library.
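As a sketch of what that unified structure looks like in practice, an SDE goes through the same `diffeqsolve` entry point as an ODE; the drift and diffusion are simply expressed as terms (`ODETerm` against dt, `ControlTerm` against a `VirtualBrownianTree`) and combined with `MultiTerm`. This example is illustrative and assumes the quick-start style of the library's API:

```python
import jax.random as jr
from diffrax import (ControlTerm, Euler, MultiTerm, ODETerm, SaveAt,
                     VirtualBrownianTree, diffeqsolve)

# Scalar SDE: dy = -y dt + 0.1 t dW, driven by Brownian motion.
t0, t1 = 1.0, 3.0
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.1 * t
brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))

# One combined term: drift integrated against dt, diffusion against dW.
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))

# Dense output: the solution can then be evaluated anywhere in [t0, t1].
sol = diffeqsolve(terms, Euler(), t0, t1, dt0=0.05, y0=1.0, saveat=SaveAt(dense=True))
value_at_2 = sol.evaluate(2.0)
```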
**Always useful**

- Equinox: neural networks and everything not already in core JAX!
- jaxtyping: type annotations for shape/dtype of arrays.

**Deep learning**

- Optax: first-order gradient (SGD, Adam, ...) optimisers.
- Orbax: checkpointing (async/multi-host/multi-device).
- Levanter: scalable+reliable training of foundation models (e.g. LLMs).
- paramax: parameterizations and constraints for PyTrees.

**Scientific computing**

- Optimistix: root finding, minimisation, fixed points, and least squares.
- Lineax: linear solvers.
- BlackJAX: probabilistic+Bayesian sampling.
- sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
- PySR: symbolic regression. (Non-JAX honourable mention!)

**Awesome JAX**

- Awesome JAX: a longer list of other JAX projects.