To run the example notebooks, first `pip install tensorflow_datasets`.
TensorFlow Implementation
The main class that runs the distributed Lanczos algorithm is LanczosExperiment. The Jupyter notebook demonstrates how to use this class.
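For context, the core primitive that LanczosExperiment distributes is the Hessian-vector product. The following is a minimal single-device sketch of that operation in TensorFlow 2; the helper name here is my own, not part of this repo:

```python
import tensorflow as tf

def hessian_vector_product(loss_fn, params, vector):
    """Computes (d^2 loss / d params^2) @ vector without forming the Hessian."""
    with tf.GradientTape() as outer_tape:
        with tf.GradientTape() as inner_tape:
            loss = loss_fn(params)
        grad = inner_tape.gradient(loss, params)
        # Inner product <grad, vector>; differentiating it yields H @ vector.
        gv = tf.reduce_sum(grad * vector)
    return outer_tape.gradient(gv, params)

# Toy example: loss = sum(a_i * x_i^2) has Hessian diag(2 * a_i).
a = tf.constant([1.0, 2.0, 3.0])
x = tf.Variable([1.0, 1.0, 1.0])
v = tf.constant([1.0, 0.0, 0.0])
print(hessian_vector_product(lambda p: tf.reduce_sum(a * p ** 2), x, v))
# -> [2., 0., 0.]
```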
In addition to single-machine setups (potentially with multiple GPUs), this implementation is also suitable for multi-GPU, multi-worker setups. The crucial step is manually partitioning the input data across the available GPUs, as sketched below.
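One straightforward way to do that partitioning with tf.data is to give each worker a disjoint shard of the dataset. This is a generic sketch, not code from this repo; the worker-count and index constants are placeholders:

```python
import tensorflow as tf

NUM_WORKERS = 4   # assumed: total number of workers / GPUs
WORKER_INDEX = 0  # assumed: this worker's index, in [0, NUM_WORKERS)

# Each worker sees every NUM_WORKERS-th example, so the Hessian-vector
# products it accumulates cover a disjoint slice of the data.
dataset = (tf.data.Dataset.range(1024)
           .shard(NUM_WORKERS, WORKER_INDEX)
           .batch(32))
```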
The algorithm outputs two numpy files, tridiag_1 and lanczos_vec_1, which contain the tridiagonal matrix and the Lanczos vectors, respectively. The tridiagonal matrix can then be used to generate spectral densities using tridiag_to_density.
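For intuition, the tridiagonal-to-density conversion amounts to eigendecomposing the tridiagonal matrix and smoothing the resulting weighted spikes with Gaussians. Below is a minimal NumPy sketch of that idea; the actual arguments of tridiag_to_density may differ, and the `.npy` filename is an assumption:

```python
import numpy as np

def tridiag_to_density_sketch(tridiag, grid, sigma_squared=1e-5):
    """Gaussian-smoothed spectral density from a Lanczos tridiagonal matrix.

    The eigenvalues of the tridiagonal matrix approximate Hessian
    eigenvalues; the squared first components of its eigenvectors give
    the corresponding quadrature weights.
    """
    eigs, vecs = np.linalg.eigh(tridiag)
    weights = vecs[0, :] ** 2
    # Sum of Gaussians centered at the Ritz values.
    density = np.zeros_like(grid)
    for eig, w in zip(eigs, weights):
        density += w * np.exp(-(grid - eig) ** 2 / (2 * sigma_squared))
    return density / np.sqrt(2 * np.pi * sigma_squared)

tridiag = np.load('tridiag_1.npy')  # assumed on-disk filename
grid = np.linspace(-5.0, 5.0, 10000)
density = tridiag_to_density_sketch(tridiag, grid)
```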
Jax Implementation
The Jax version is fantastic for fast experimentation (especially in conjunction with trax). The Jupyter notebook demonstrates how to run Lanczos in Jax.
The main function is lanczos_alg, which returns a tridiagonal matrix and Lanczos vectors. The tridiagonal matrix can then be used to generate spectral densities using tridiag_to_density.
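To give a flavor of wiring this up, the snippet below builds the Hessian-vector product operator that a Lanczos routine consumes. The final call is commented out because the exact lanczos_alg signature (argument names and order) is an assumption; check the notebook for the real one:

```python
import jax
import jax.numpy as jnp

def loss(params):
    # Toy quadratic loss with Hessian diag(2, 4, 6, 8, 10);
    # substitute your model's loss here.
    return jnp.sum(jnp.arange(1.0, 6.0) * params ** 2)

params = jnp.ones(5)

def hvp(v):
    # Hessian-vector product via forward-over-reverse differentiation,
    # without ever materializing the full Hessian.
    return jax.jvp(jax.grad(loss), (params,), (v,))[1]

key = jax.random.PRNGKey(0)
# Assumed signature; verify against the repo before use:
# tridiag, lanczos_vecs = lanczos_alg(hvp, dim=params.size, order=90, rng_key=key)
```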
Differences between implementations
The TensorFlow version performs the Hessian-vector product accumulation and the actual Lanczos algorithm in float64, whereas the Jax version performs all calculations in float32.
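If float32 precision is a concern on the Jax side, note that Jax defaults to 32-bit arithmetic and 64-bit support must be enabled explicitly. This is standard Jax configuration rather than anything specific to this repo:

```python
import jax

# Jax silently downcasts to float32 unless 64-bit mode is enabled
# before any arrays are created.
jax.config.update("jax_enable_x64", True)
```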
The TensorFlow version targets multi-worker distributed setups, whereas the Jax version targets single worker (potentially multi-GPU) setups.