High-performance, extensible, chainable optimizers for PyTorch.
Major Release: HeavyBall 2.0.0 introduces a comprehensive documentation overhaul, enhanced interactive visualizations, and new additions including the chainable architecture and schedule-free optimizers. This release also brings expanded testing, stability improvements, and detailed theory notes with practical examples.
Read the full release notes | Quick Start Guide
- Lightning-Fast Training: Batched `foreach` operations deliver significant speedups on large models.
- Adaptive & Extensible: Built-in AdamW, RMSprop, Schedule-Free algorithms, and PaLM-inspired schedules.
- Plug-and-Play: Drop-in replacements for `torch.optim` with seamless integration.
- Customizable: Chainable API lets you compose optimizers and transforms (MARS correction, cautious updates, orthogonal updates).
- Battle-Tested: Extensive benchmarks and real-world examples included.
- New in v2.0.0: Foreach-optimized PSGD variants (`ForeachPSGDKron`, `ForeachCachedPSGDKron`) with substantial speedups
- New in v2.0.0: Schedule-Free optimizers (`ForeachSFAdamW`, `SFAdaGrad`) that eliminate learning rate scheduling (see the example after the basic usage snippet below)
- New in v2.0.0: Chainable architecture for composing complex optimization pipelines
- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM`, …
- Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling
- Enhanced BF16/FP16 support for mixed-precision training
- Comprehensive benchmark suite and interactive visualizations
- Detailed documentation with theoretical foundations and practical examples
Install:

```bash
pip install heavyball
```
Basic usage:

```python
import torch
from torch import nn
from heavyball import ForeachAdamW

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = ForeachAdamW(model.parameters(), lr=1e-3)

# dataloader is assumed to yield (data, target) batches.
for data, target in dataloader:
    optimizer.zero_grad()
    output = model(data)
    loss = torch.nn.functional.cross_entropy(output, target)
    loss.backward()
    optimizer.step()
```
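Because HeavyBall optimizers follow the standard `torch.optim` interface, switching to a schedule-free variant is a one-line change. The sketch below assumes `ForeachSFAdamW` accepts the same `(params, lr=...)` constructor arguments as `ForeachAdamW`; consult the API docs for optimizer-specific options.

```python
from heavyball import ForeachSFAdamW

# Schedule-free variant: no learning-rate scheduler is created or stepped.
optimizer = ForeachSFAdamW(model.parameters(), lr=1e-3)

for data, target in dataloader:
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()
```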
Experience HeavyBall optimizers in action with our interactive visualization:
- Open the demo: Simply open `index.html` in your web browser
- View performance: Watch real-time optimizer trajectory visualization showing loss convergence
- Export results: Save visualizations as PNG images with custom filenames
- Run tests: Use `npm test` to run the Playwright test suite for the demo
The interactive demo provides an intuitive way to understand optimizer behavior and performance characteristics without writing code.
Reproduce benchmarks with:
```bash
python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
```
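The full sweep above is expensive. For a quick smoke test, the same flags can be reused with a much smaller search space; the reduced values below are illustrative, not tuned:

```bash
python3 -m benchmark.run_all_benchmarks --opt AdamW --opt ForeachSFAdamW \
    --dtype float32 --steps 10000 --trials 10 --parallelism 16 --seeds 1 \
    --difficulties trivial --difficulties easy --timeout 60
```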
We welcome contributions! Please check the issue tracker and follow these steps:
- Fork the repo and create a feature branch.
- Install dev dependencies: `pip install -e .[dev]`.
- Run tests: `pytest`.
- Submit a pull request.
BSD 3-Clause license; see the LICENSE file.
Made by the HeavyBall team.