The evaluation framework for training-free sparse attention in LLMs
This repository serves two main purposes:
- Reproducing results from our paper "The Sparse Frontier: Sparse Attention Trade-offs in Transformer LLMs".
- Providing a starting point for your own training-free sparse attention research and development.
The Problem: vLLM is a highly optimized framework supporting hundreds of models, but its extensive codebase makes integrating custom sparse attention patterns extremely challenging. Researchers face a difficult choice: build from scratch with limited model support, use Hugging Face, where adding support means navigating a different implementation file for each model, or dive into vLLM's complex internals.
Our Solution: We provide a clean abstraction that lets you focus on your sparse attention logic while automatically inheriting all of vLLM's optimizations and model compatibility. Here's what makes our framework unique:
- 🎯 Elegant vLLM Integration: Seamless sparse attention integration through our `AttentionHandler` that intercepts vLLM's execution flow. Write your sparse attention in 50 lines, not 5000—evaluate on 100 models, not 1. By implementing sparse attention in our framework, you automatically gain compatibility with all models supported by vLLM, from small 7B models to large 405B+ models across different architectures (Llama, Qwen, Mistral, etc.).
- ⚡ State-of-the-art Baselines: 6 representative SOTA patterns spanning key design dimensions for both inference phases—sparse prefilling (Vertical-Slash, Block-Sparse, FlexPrefill), sparse decoding (Quest), and KV cache compression (SnapKV, Ada-SnapKV)—with optimized Triton implementations.
- 🔬 Comprehensive Evaluation: 9 diverse tasks covering retrieval, multi-hop reasoning, and information aggregation with rigorous sequence length control and standardized preprocessing.
- 🧪 Research-Grade Extensibility: Clean modular architecture with abstract base classes designed for rapid prototyping of novel sparse attention patterns and tasks.
If you're new to sparse attention and want to understand how these patterns work, we recommend starting with our companion repository: nano-sparse-attention. It provides clean, educational PyTorch implementations with interactive Jupyter notebooks for experimenting and learning before diving into the optimized implementations here. Originally created for the NeurIPS 2024 Dynamic Sparsity Workshop, it serves as an excellent starting point for understanding sparse attention fundamentals.
Follow these steps to set up the environment and prepare for running experiments:
- Create Virtual Environment and Install Dependencies: Set up a dedicated Python environment and install the required packages, including compiling custom CUDA kernels.
```bash
# Create a virtual environment using Python 3.11
python3.11 -m venv .venv

# Activate the virtual environment
source .venv/bin/activate

# Upgrade pip and install essential build/utility tools
pip install --no-cache-dir --upgrade pip setuptools wheel psutil ninja

# Install PyTorch
pip install --no-cache-dir torch==2.5.1

# Install the sparse_frontier project in editable mode
pip install --no-cache-dir -e .

# Compile custom CUDA kernels (for MInference attention)
# Adjust MAX_JOBS based on your system core count for faster compilation
MAX_JOBS=8 python compile.py build_ext --inplace --build-lib ./sparse_frontier/modelling/attention/minference
```
For reference, the complete list of dependencies used in our experiments is available in `./assets/pipfreeze.txt`. We tested the codebase on both A100 and H100 GPUs.
- Configure Paths: Modify the default configuration file to specify where data, results, and checkpoints should be stored on your system (see the sketch after this list).
  - Edit the `paths` section in `configs/default.yaml`.
- Download Model Checkpoints: Obtain the pre-trained model weights you intend to evaluate from the Hugging Face Hub. We prefer doing this manually, as it gives us better control over what is downloaded and where (see the example after this list).
  - Ensure the final directory structure for the downloaded checkpoints matches the format expected by the corresponding model configuration file (e.g., as defined in `configs/model/qwen_7b.yaml`). The `model.path` variable in these configs should point to the directory containing the model files.
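A minimal sketch of what the `paths` section might contain (the key names below are illustrative; keep the ones that already exist in `configs/default.yaml`):

```yaml
# Illustrative only -- use the key names already present in configs/default.yaml
paths:
  data: /scratch/sparse_frontier/data                # generated task samples
  results: /scratch/sparse_frontier/results          # predictions and evaluation metrics
  checkpoints: /scratch/sparse_frontier/checkpoints  # downloaded model weights
```

One way to fetch weights manually is the Hugging Face CLI, e.g. for the 7B model (point `--local-dir` at the directory referenced by `model.path` in the model config):

```bash
huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir /scratch/sparse_frontier/checkpoints/Qwen2.5-7B-Instruct
```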
Experiments are launched using the main script `sparse_frontier.main`. The framework uses Hydra for configuration management. All configurations are stored in YAML files within the `sparse_frontier/configs/` directory, organized into three main categories:
- `attention/`: Configurations for different sparse attention mechanisms (dense, quest, snapkv, etc.)
- `task/`: Configurations for evaluation tasks (RULER, QA, Story tasks)
- `model/`: Configurations for different model architectures (Qwen2.5-7B to 72B)
The execution pipeline typically involves three stages, controlled by the `mode` parameter (defaulting to `all`):
- Preparation (`preparation.py`): Generates and saves task-specific data based on the selected `task` configuration. Tasks are defined in `sparse_frontier/tasks/` (inheriting from `AbstractTask` and `AbstractSample`) and registered in `TASK_REGISTRY`.
- Prediction (`prediction.py`): Runs the specified `model` with the chosen `attention` mechanism on the prepared data, saving the model outputs. Attention mechanisms are implemented in `sparse_frontier/modelling/attention/` and registered in `ATTENTION_REGISTRY`.
- Evaluation (`evaluation.py`): Compares the predictions against the gold answers using the task's specific evaluation logic and saves the final metrics.
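For example, the stages can be run separately by overriding `mode` on the command line. The values below assume `mode` accepts the stage names listed above; check the Hydra configs (e.g. `configs/default.yaml`) for the exact accepted values:

```bash
# Assumes mode accepts the stage names above -- verify against the configs
python -m sparse_frontier.main mode=preparation task=ruler_niah model=qwen_7b
python -m sparse_frontier.main mode=prediction task=ruler_niah model=qwen_7b attention=quest
python -m sparse_frontier.main mode=evaluation task=ruler_niah model=qwen_7b attention=quest
```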
Quick Start Examples:
```bash
# Basic experiment with command line overrides
python -m sparse_frontier.main task=ruler_niah model=qwen_7b attention=quest samples=50

# Override attention parameters
python -m sparse_frontier.main attention=quest attention.args.token_budget=2048
```
For detailed configuration documentation, see CONFIGURATION.md.
Note: The current framework implementation supports only batch size = 1. This limitation stems from our initial experiments with methods that had kernels supporting only BS=1. Since then, we have followed a simple heuristic: for a given (Model Size, Method, Sequence Length) combination, we find the minimum tensor parallelism (TP) that provides sufficient total GPU memory to handle the evaluation, then use our intra-node scheduler to distribute BS=1 evaluations across the node's GPUs. For the majority of our evaluations, we achieved >95% GPU utilization. Nevertheless, higher throughput and GPU utilization could likely be achieved with larger TP and BS>1. We plan to support batch size > 1 in the next release.
Instead of wrestling with vLLM's complex internals, we provide a clean abstraction layer that lets you focus on your sparse attention logic.
Our integration works by intercepting vLLM's attention execution at the FlashAttention level. When you register your sparse attention pattern, our framework:
- Patches vLLM's FlashAttention forward method - The `vllm_patched_forward` function in `vllm_model.py` replaces vLLM's default attention computation.
- Routes attention calls through our handler - The `AttentionHandler` from `handler.py` manages layer state, prefill vs. decode phases, and KV cache updates.
- Executes your sparse attention - Your implementation receives the same tensors vLLM would use, but with your custom attention logic.
The `swap_vllm_attention` function is registered as a vLLM plugin in `setup.py`, ensuring all tensor-parallel workers automatically load our custom implementation. This provides seamless tensor parallelism support without any additional configuration.
The integration is automatic - when you register your attention pattern, it becomes available to all vLLM-compatible models without any additional setup.
```python
from sparse_frontier.modelling.attention.abstract_attention import AbstractAttention


class MySparseAttention(AbstractAttention):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Your initialization logic

    def __call__(self, queries, keys, values, layer_idx):
        # Your prefill attention logic (uses dense prefill if not implemented)
        return attention_output

    def decode(self, query, keys, values, k_cache, v_cache, cache_seqlens, output, layer_idx):
        # Your decoding logic (uses dense decoding if not implemented)
        pass

    def kv_compress(self, queries, keys, values):
        # Your KV compression logic (leaves the KV cache intact if not implemented)
        return compressed_keys, compressed_values, seq_lens
```
That's it! No need to browse the huge vLLM codebase or worry about inference state handling, etc.
Examples can be found in `kv_compression.py` (SnapKV and Ada-SnapKV), `efficient_prefilling.py` (Vertical-Slash, Block-Sparse, and FlexPrefill), and `efficient_decoding.py` (Quest).
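To make the skeleton above concrete, here is a toy pattern that restricts prefill attention to a causal local window. It is a minimal sketch, not one of the shipped baselines, and the `(num_heads, seq_len, head_dim)` layout is an assumption; consult the built-in implementations above for the exact tensors the handler passes in.

```python
import torch

from sparse_frontier.modelling.attention.abstract_attention import AbstractAttention


class LocalWindowAttention(AbstractAttention):
    """Toy pattern: causal prefill attention restricted to the last `window` keys.

    Illustrative sketch only -- it assumes (num_heads, seq_len, head_dim)
    tensors; check the built-in patterns for the actual shapes.
    """

    def __init__(self, window: int = 1024, **kwargs):
        super().__init__(**kwargs)
        self.window = window

    def __call__(self, queries, keys, values, layer_idx):
        # Dense scores, then mask out everything outside the causal local
        # window before the softmax. decode() and kv_compress() are not
        # overridden, so the framework's dense fallbacks apply.
        head_dim = queries.shape[-1]
        q_len, k_len = queries.shape[-2], keys.shape[-2]
        scores = queries @ keys.transpose(-1, -2) / head_dim ** 0.5

        # Query i (offset by any key prefix) may attend to keys j with
        # i - window < j <= i.
        q_pos = torch.arange(q_len, device=queries.device)[:, None] + (k_len - q_len)
        k_pos = torch.arange(k_len, device=queries.device)[None, :]
        allowed = (k_pos <= q_pos) & (k_pos > q_pos - self.window)
        scores = scores.masked_fill(~allowed, torch.finfo(scores.dtype).min)

        return torch.softmax(scores, dim=-1) @ values
```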
Once you've implemented your sparse attention pattern, registration is a simple two-step process:
1. Register in the Attention Registry
Add your attention class to the `ATTENTION_REGISTRY` dictionary in `sparse_frontier/modelling/attention/registry.py`:
```python
from .your_module import MySparseAttention  # Import your implementation

ATTENTION_REGISTRY = {
    ...
    'my_sparse_attention': MySparseAttention,  # Add your pattern here
}
```
2. Create Configuration File
Create a YAML configuration file at `configs/attention/my_sparse_attention.yaml` that defines your attention mechanism and its parameters:
```yaml
# @package _global_
attention:
  name: my_sparse_attention
  args:
    sparsity_ratio: 0.1
```
The configuration structure follows the pattern used by existing attention mechanisms. The `name` field must match your registry key, and `args` contains all the parameters that will be passed to your attention class constructor.
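In other words, with the config above the framework roughly does the following (a simplified view; the actual wiring goes through the attention handler):

```python
# Simplified sketch of how the YAML maps to your class -- not the literal framework code.
from sparse_frontier.modelling.attention.registry import ATTENTION_REGISTRY

attention_cls = ATTENTION_REGISTRY["my_sparse_attention"]  # looked up via attention.name
attention = attention_cls(sparsity_ratio=0.1)              # constructed from attention.args
```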
3. Run Evaluation
Your sparse attention pattern is now ready to use! Test it with any model and task:
```bash
# Basic evaluation with your new attention pattern
python -m sparse_frontier.main task=ruler_niah model=qwen_7b attention=my_sparse_attention

# Override specific attention parameters
python -m sparse_frontier.main attention=my_sparse_attention attention.args.sparsity_ratio=0.05
```
Result: Your sparse attention works with any vLLM-compatible model and benefits from all vLLM optimizations.
Experimental data generation is handled by task-specific modules located in `sparse_frontier/tasks/`. Each task implements `AbstractTask` and `AbstractSample` subclasses (defined in `sparse_frontier/tasks/abstract_*.py`) to define input/output creation and task-specific evaluation. Tasks are registered in `sparse_frontier/tasks/registry.py` and selected via configuration (e.g., `task=your_task_name`). The `preparation.py` script orchestrates the generation process based on the configuration, saving the formatted samples. See existing tasks like `QATask` (`qa_task.py`) or the RULER tasks (`niah.py`, `cwe.py`, `vt.py`) for implementation examples.
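The sketch below shows roughly what a new task might look like. The method names are placeholders, not the real interface; follow the actual `AbstractTask` / `AbstractSample` definitions in `sparse_frontier/tasks/abstract_*.py` and an existing task such as `qa_task.py` for the required methods and sample format.

```python
# Hypothetical sketch of a new task -- method and field names are placeholders.


class MySample:  # in practice, subclass AbstractSample
    def __init__(self, context: str, question: str, answer: str):
        # One formatted (input, gold answer) pair at a controlled sequence length.
        self.input_text = f"{context}\n\nQuestion: {question}"
        self.gold_answer = answer


class MyTask:  # in practice, subclass AbstractTask
    def generate_samples(self, num_samples: int) -> list[MySample]:
        # Placeholder: build num_samples synthetic samples here.
        return [MySample("...", "Who wrote it?", "Alice") for _ in range(num_samples)]

    def evaluate(self, predictions: list[str], samples: list[MySample]) -> dict:
        # Placeholder metric: exact-match accuracy against the gold answers.
        correct = sum(p.strip() == s.gold_answer for p, s in zip(predictions, samples))
        return {"exact_match": correct / max(len(samples), 1)}


# After implementing the real subclasses, add the task to TASK_REGISTRY in
# sparse_frontier/tasks/registry.py and select it with: task=my_task
```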
In this repository, we evaluate 6 sparse attention patterns:
| Pattern | Source |
|---|---|
| Vertical-Slash / Block-Sparse | Microsoft |
| FlexPrefill | ByteDance-Seed |
| SnapKV | FasterDecoding |
| Ada-SnapKV | FFY0 |
| Quest | MIT-HAN-Lab |
We either re-implement these patterns based on the original code or borrow implementations, including kernels (for Vertical-Slash and Block-Sparse), from MInference.
Our evaluation framework includes the following tasks:
- RULER Tasks: Re-implementation of the NIAH, VT, and CWE tasks from NVIDIA/RULER.
- QA Tasks:
  - TOEFL and QuALITY datasets from LC-VS-RAG
  - SQuAD dataset from NVIDIA/RULER
- Novel Story Tasks: Narrative tasks developed specifically for this project.
If you found the repository useful, consider citing our paper:
```bibtex
@article{nawrot2025sparsefrontier,
  title={The Sparse Frontier: Sparse Attention Trade-offs in Transformer LLMs},
  author={Piotr Nawrot and Robert Li and Renjie Huang and Sebastian Ruder and Kelly Marchisio and Edoardo M. Ponti},
  year={2025},
  journal={arXiv:2504.17768},
  url={https://arxiv.org/abs/2504.17768},
}
```
If you have any questions, feel free to raise a GitHub issue or contact me directly at piotr.nawrot@ed.ac.uk.