Diffusion-RWKV: Scaling RWKV-Like Architectures for Diffusion Models Official PyTorch Implementation
This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for our paper on scalable diffusion models with RWKV-like architectures, named Diffusion-RWKV.
It builds a series of architectures adapted from the RWKV model used in NLP, with the requisite modifications for diffusion models applied to image generation tasks.
1. Environments
Python 3.10
conda create -n your_env_name python=3.10
Requirements file
pip install -r requirements.txt
Install mmcv-full and mmcls
pip install -U openmim
mim install mmcv-full==1.7.0
pip install mmcls==0.25.0
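Taken together, a typical setup sequence might look like the following sketch; the environment name drwkv is a placeholder, not one used by the repo:

```shell
# Create and activate a fresh conda environment (name is illustrative)
conda create -n drwkv python=3.10 -y
conda activate drwkv

# Install the project's pinned dependencies
pip install -r requirements.txt

# Install mmcv-full and mmcls at the versions listed above
pip install -U openmim
mim install mmcv-full==1.7.0
pip install mmcls==0.25.0
```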
2. Training
We provide a training script for Diffusion-RWKV in train.py. This script can be used to train unconditional and class-conditional Diffusion-RWKV models, and it can be easily modified to support other types of conditioning.
To launch DRWKV-H/2 (256x256) latent-space training with N GPUs on one node:
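A minimal sketch of such a launch with torchrun is below; the train.py arguments (--model, --image-size, --data-path) are assumptions modeled on common diffusion training scripts, and the real flag names may differ:

```shell
# Distributed training on one node with N GPUs (replace N with the GPU count).
# The train.py flags shown are hypothetical; see train.py for the actual options.
torchrun --nnodes=1 --nproc_per_node=N train.py \
  --model DRWKV-H/2 \
  --image-size 256 \
  --data-path /path/to/train/data
```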
There are several additional options; see train.py for details.
Example training scripts for our experiments can be found in the script directory.
For convenience, the pre-trained Diffusion-RWKV models can be downloaded directly from Hugging Face.
3. Evaluation
We include a sample.py script that samples images from a Diffusion-RWKV model. In addition, the test.py script supports evaluation of other metrics, e.g., FLOPs and model parameters.
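As a sketch, sampling and metric evaluation might be invoked as follows; the flag names (--ckpt in particular) are assumptions and should be checked against the scripts:

```shell
# Sample images from a trained Diffusion-RWKV model.
# --model, --image-size, and --ckpt are illustrative; see sample.py for the real flags.
python sample.py --model DRWKV-H/2 --image-size 256 --ckpt /path/to/checkpoint.pt

# Report FLOPs and parameter counts (flags illustrative; see test.py).
python test.py --model DRWKV-H/2 --image-size 256
```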