```python
import jax.numpy as jnp
import optimistix as optx

# Let's solve the ODE dy/dt = tanh(y(t)) with the implicit Euler method.
# We need to find y1 such that y1 = y0 + tanh(y1) * dt.
y0 = jnp.array(1.)
dt = jnp.array(0.1)

def fn(y, args):
    return y0 + jnp.tanh(y) * dt

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.fixed_point(fn, solver, y0)
y1 = sol.value  # satisfies y1 == fn(y1)
```
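The same implicit Euler step can equivalently be posed as a root-finding problem: find y1 such that y1 - (y0 + tanh(y1) * dt) = 0. Here is a minimal sketch of that alternative formulation (the residual function is our own rearrangement, not from the example above; `optx.root_find` and `optx.Newton` are part of Optimistix's public API):

```python
import jax.numpy as jnp
import optimistix as optx

y0 = jnp.array(1.)
dt = jnp.array(0.1)

# Residual form of the implicit Euler step: g(y1) = y1 - (y0 + tanh(y1) * dt) = 0.
def residual(y, args):
    return y - (y0 + jnp.tanh(y) * dt)

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.root_find(residual, solver, y0)
y1 = sol.value  # residual(y1) is approximately zero
```

Both formulations produce the same y1; choosing between `fixed_point` and `root_find` is mostly a matter of which solver family best fits the problem.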
## Citation

If you found this library to be useful in academic work, then please cite: [arXiv link](https://arxiv.org/abs/2402.09983)
```bibtex
@article{optimistix2024,
    title={Optimistix: modular optimisation in JAX and Equinox},
    author={Jason Rader and Terry Lyons and Patrick Kidger},
    journal={arXiv:2402.09983},
    year={2024},
}
```
## See also: other libraries in the JAX ecosystem

**Always useful**

- Equinox: neural networks and everything not already in core JAX!
- jaxtyping: type annotations for shape/dtype of arrays.

**Deep learning**

- Optax: first-order gradient (SGD, Adam, ...) optimisers.
- Orbax: checkpointing (async/multi-host/multi-device).
- Levanter: scalable+reliable training of foundation models (e.g. LLMs).
- paramax: parameterizations and constraints for PyTrees.

**Awesome JAX**

- Awesome JAX: a longer list of other JAX projects.
## Credit

Optimistix was primarily built by Jason Rader ([@packquickly](https://github.com/packquickly)). It is being co-maintained by Johanna Haffner ([@johannahaffner](https://github.com/johannahaffner)).