This repository contains code for training and evaluating the models in the paper *Likelihood-Based Diffusion Language Models*.
## Installing requirements
This codebase requires PyTorch 2.0 and a few fused CUDA kernels that need to be installed manually. Most of the dependencies can be installed automatically:
```
pip install -r requirements.txt
```
Install FlashAttention with fused MLP and rotary embedding kernels:
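A typical sequence looks like this; the kernel subdirectories below match the layout of the FlashAttention repository, but treat the exact paths as assumptions and consult the upstream README if they fail to build:

```
pip install flash-attn

# The fused MLP and rotary embedding kernels are packaged as separate
# extensions under csrc/ in the FlashAttention repository.
pip install "git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=csrc/fused_dense_lib"
pip install "git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=csrc/rotary"
```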
## Computing likelihoods

This repository supports computing zero-shot likelihoods on six datasets: Penn Treebank, enwik8, text8, WikiText-2, WikiText-103, and the One Billion Word corpus.
To compute likelihoods for one of these datasets, specify the dataset path in the corresponding constant at the top of lib/datasets.py, then run the evaluation command (e.g. for WikiText103):
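The invocation below is a sketch: the entry-point name and dataset argument are assumptions, so check the repository's scripts for the exact form.

```
# Hypothetical evaluation command; substitute the actual script name
# and dataset option used by this repository.
python likelihood.py dataset=wikitext103
```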
## Training models

Download and extract the OpenWebText2 dataset, then update the OPENWEBTEXT2_DATA_DIR constant in lib/datasets.py with the path to the extracted files.
Run the OpenWebText2 preprocessing script:
```
python -m misc.owt2_preprocess
```
Run the training script:
```
python train.py
```
By default, this trains a small model (16 layers, model dimension 384, sequence length 256, 92K steps at batch size 256), which should take under a day on a single 80GB A100. You can change these hyperparameters by passing different options to train.py.
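For example (the flag names below are illustrative assumptions, not documented options; check train.py for the names it actually accepts):

```
# Hypothetical override of the default hyperparameters:
# a deeper, wider model trained for more steps.
python train.py --n_blocks 24 --dim 1024 --steps 200000
```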
If you don't have enough memory to train the model with default settings, you can enable gradient accumulation. The following commands should produce equivalent results:
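A sketch of such an equivalent pair, assuming hypothetical batch_size and grad_accum_steps options (the actual option names may differ):

```
# One optimizer step per full batch of 256 sequences.
python train.py --batch_size 256 --grad_accum_steps 1

# Half the per-step memory: micro-batches of 128, accumulated over
# 2 steps before each optimizer update. The effective batch size is
# still 256, so training should match the command above.
python train.py --batch_size 128 --grad_accum_steps 2
```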