Deploying long-context LLMs is costly because the key-value (KV) cache of transformer models grows linearly with the context length. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. kvpress implements multiple KV cache compression methods and benchmarks using 🤗 transformers, aiming to simplify the development of new methods for researchers and developers in this field.
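The 330GB figure can be reproduced with a back-of-the-envelope calculation. The sketch below assumes the public Llama 3.1-70B configuration (80 layers, 8 key-value heads with grouped-query attention, head dimension 128) and ignores any framework overhead; the helper kv_cache_bytes is hypothetical and not part of kvpress.

```python
# Back-of-the-envelope KV cache size for a decoder-only transformer.
# Assumed Llama 3.1-70B config: 80 layers, 8 KV heads (GQA), head dim 128.


def kv_cache_bytes(num_tokens: int, num_layers: int, num_kv_heads: int,
                   head_dim: int, bytes_per_value: int = 2) -> int:
    """Return the KV cache size in bytes.

    The leading factor 2 accounts for storing both keys and values;
    bytes_per_value=2 corresponds to float16.
    """
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * num_tokens


per_token = kv_cache_bytes(1, num_layers=80, num_kv_heads=8, head_dim=128)
total = kv_cache_bytes(1_000_000, num_layers=80, num_kv_heads=8, head_dim=128)
print(f"{per_token / 1024:.0f} KiB per token")  # 320 KiB per token
print(f"{total / 1e9:.1f} GB for 1M tokens")     # 327.7 GB for 1M tokens
```

At roughly 320 KiB per token, every extra token of context adds a fixed amount of memory, which is exactly the linear growth that KV cache compression targets.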
pip install kvpress
For a local installation with all dev dependencies, use uv:
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --all-groups
Advanced installation settings
To install optional packages, you can use uv. To install with flash attention, just run:
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra flash-attn
To install with dependencies for evaluation, run:
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra eval
KVPress provides a set of "presses" that compress the KV cache during the prefilling phase. Each press is associated with a compression_ratio attribute that measures the compression of the cache. The easiest way to use a press is through our custom KVPressTextGenerationPipeline. It is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported, and it handles chat templates and tokenization for you:
from transformers import pipeline
from kvpress import ExpectedAttentionPress
device = "cuda:0"
model = "meta-llama/Llama-3.1-8B-Instruct"
model_kwargs = {"attn_implementation": "flash_attention_2"}
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
context = "A very long text you want to compress once and for all"
question = "\nA question about the compressed context" # optional
press = ExpectedAttentionPress(compression_ratio=0.5)
answer = pipe(context, question=question, press=press)["answer"]
In the snippet above, compression is applied only to the context tokens, so you can evaluate the compression for different questions. Check the Wikipedia notebook demo for a more detailed example (also available on Colab here).
Decoding Compression
By default, KVPress applies compression during the pre-filling phase. As a new (experimental) feature, we now support decoding compression via the `DecodingPress` wrapper. `DecodingPress` compresses the KV cache periodically during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters:
- base_press: Any ScorerPress (e.g., KnormPress, CriticalKVPress)
- compression_interval: Steps between compressions (default: 10)
- target_size: Target size of the cache after compression (default: 1024)
- hidden_states_buffer_size: Number of hidden states to buffer before compression (default: 128). Some presses don't need buffered hidden states and can set this to 0.
Unlike prefilling presses, which are configured with a compression ratio, DecodingPress uses a target_size to compress the cache: the cache is compressed every compression_interval steps, and the compression ratio is automatically computed such that the size of the cache after compression equals target_size.
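To make the target_size logic concrete, here is a minimal simulation of the schedule described above. It assumes the press prunes down to exactly target_size tokens every compression_interval steps and that each generated token adds one KV pair; ratio_for_target and simulate are hypothetical helpers, and the real DecodingPress may round differently or account for the hidden-states buffer.

```python
# Hypothetical simulation of the DecodingPress schedule (not the kvpress code).


def ratio_for_target(cache_len: int, target_size: int) -> float:
    """Fraction of cached tokens to drop so that `target_size` tokens remain."""
    if cache_len <= target_size:
        return 0.0  # cache already small enough: nothing to prune
    return 1.0 - target_size / cache_len


def simulate(prefill_len: int, num_steps: int,
             compression_interval: int = 10, target_size: int = 1024) -> list:
    """Return the cache length after each decoding step."""
    cache_len = prefill_len
    history = []
    for step in range(1, num_steps + 1):
        cache_len += 1  # each generated token appends one KV pair
        if step % compression_interval == 0:
            ratio = ratio_for_target(cache_len, target_size)
            # round() guards against floating-point error in cache_len * (1 - ratio)
            cache_len = round(cache_len * (1 - ratio))
        history.append(cache_len)
    return history


history = simulate(prefill_len=1000, num_steps=25, compression_interval=10, target_size=512)
print(history[9], history[19], history[-1])  # 512 512 517
```

With a 1,000-token prompt, target_size=512 and compression_interval=10, the cache drops to 512 tokens at steps 10 and 20 and grows by one token per step in between, so after the first compression it never exceeds target_size + compression_interval - 1 tokens.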
An example for decoding compression:
from transformers import pipeline
from kvpress import KnormPress
from kvpress import DecodingPress
# Initialize the pipeline
device = "cuda:0"
model = "meta-llama/Llama-3.1-8B-Instruct"
model_kwargs = {"attn_implementation": "flash_attention_2"}
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
# Create a decoding press that compresses every 10 steps to 512 tokens
decoding_press = DecodingPress(
    base_press=KnormPress(),
    compression_interval=10,
    target_size=512
)
# Use with pipeline
context = "A very long text you want to compress during generation"
question = "Tell me a long story about this context"
response = pipe(context, question=question, press=decoding_press)["answer"]
Not all existing presses are fully compatible with DecodingPress due to fundamental differences in how compression works during decoding versus prefilling. In particular, only ScorerPress subclasses are supported as base presses.
All current presses are training-free and inherit from BasePress (source).
Several presses inherit from ScorerPress (source) and rely on a score to prune the KV pairs with lowest importance:
- RandomPress (source): random score
- KnormPress (source, paper): inverse norm of the key
- SnapKVPress (source, paper): average attention weight of the last queries
- ExpectedAttentionPress (source, notebook): expected attention weight during the generation phase
- StreamingLLMPress (source, paper): keep only the initial and recent tokens
- TOVAPress (source, paper): attention weight of the last query averaged across heads
- ObservedAttentionPress (source, paper): average attention weight observed during the pre-filling phase
- QFilterPress (source, paper): project the key representations onto the main SVD component of the query vectors to approximate the attention scores
- PyramidKVPress (source, paper): maintain pyramid-like cache sizes, allocating more cache budget to lower layers and less to higher layers
- LagKVPress (source, paper): leverage the KV lag-relative information to compress. It is query-free, attention-weight-free, and compatible with flash attention.
- KeyDiffPress (source, paper): evicts tokens based solely on key similarity.
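The presses above share a score-then-prune pattern. The sketch below illustrates it in plain PyTorch with a KnormPress-style score: the negated L2 norm of each key, which ranks tokens the same way as the inverse norm. It assumes keys and values of shape (batch, num_kv_heads, seq_len, head_dim) and keeps int(seq_len * (1 - compression_ratio)) tokens per head; knorm_scores and prune_by_score are hypothetical names, not the kvpress API.

```python
import torch


def knorm_scores(keys: torch.Tensor) -> torch.Tensor:
    """Score each cached token: low key norm -> high importance (KnormPress-style)."""
    return -keys.norm(dim=-1)  # (batch, heads, seq_len)


def prune_by_score(keys: torch.Tensor, values: torch.Tensor,
                   scores: torch.Tensor, compression_ratio: float):
    """Keep the highest-scoring tokens of each head and drop the rest."""
    seq_len, head_dim = keys.shape[2], keys.shape[3]
    n_kept = int(seq_len * (1 - compression_ratio))
    # Indices of the top-scoring tokens, re-sorted to preserve their original order
    idx = scores.topk(n_kept, dim=-1).indices.sort(dim=-1).values
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, head_dim)
    return keys.gather(2, idx), values.gather(2, idx)


torch.manual_seed(0)
keys = torch.randn(1, 2, 10, 4)    # (batch, kv_heads, seq_len, head_dim)
values = torch.randn(1, 2, 10, 4)
k, v = prune_by_score(keys, values, knorm_scores(keys), compression_ratio=0.5)
print(k.shape)  # torch.Size([1, 2, 5, 4])
```

Swapping knorm_scores for another scoring function (random scores, attention weights of the last queries, and so on) changes the selection criterion while the pruning step stays the same.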
Some presses rely on a different logic:
- ThinKPress (source, paper): compress the dimensions of the keys based on the channel attention score on the last queries
- SimLayerKVPress (source, paper): identify "lazy" layers and apply the StreamingLLM approach to them
- DuoAttentionPress (source, paper): split heads into retrieval heads (no compression) and streaming heads (StreamingLLM approach)
- FinchPress (source, paper): similar to SnapKV with a dynamic window size and key-value re-rotation
- KVzipPress (source, paper): identifies redundant KV pairs through context reconstruction. Achieves near-lossless compression at the cost of multiple forward passes.
Finally, we provide wrapper presses that can be combined with other presses:
- AdaKVPress (source, paper): prune the bottom scores of any ScorerPress across all heads, achieving head-wise compression
- PerLayerCompressionPress (source): compress each layer with a different compression ratio (experimental)
- ComposedPress (source): compose multiple presses together by chaining their forward hooks
- KeyRerotationPress (source): rerotate pruned keys to have continuous RoPE embeddings
- ChunkKVPress (source, paper): compresses by selecting important chunks, preserving semantic coherence
- ChunkPress (source, paper): compress the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences
- CriticalKVPress and CriticalAdaKVPress (source, paper): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
- BlockPress (source, paper): segments the input sequence into non-overlapping blocks and compresses iteratively.
- DecodingPress (source): allows compression during decoding; see the Decoding Compression section in this README.
- PrefillDecodingPress (source): allows compression both during prefilling and during decoding.
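The head-wise selection performed by AdaKVPress can be sketched as a single bottom-k over all (head, token) pairs of a layer, so heads with many high-scoring tokens keep more entries than others. This is a simplified illustration that assumes scores come from some ScorerPress with shape (batch, num_heads, seq_len); the actual AdaKVPress may include details (such as a minimum budget per head) that this sketch omits, and adakv_style_mask is a hypothetical name.

```python
import torch


def adakv_style_mask(scores: torch.Tensor, compression_ratio: float) -> torch.Tensor:
    """Return a boolean keep-mask of shape (batch, heads, seq_len).

    The lowest scores are pruned over the flattened (head, token) axis rather than
    per head, so the number of kept tokens can differ from head to head.
    """
    batch, heads, seq_len = scores.shape
    flat = scores.reshape(batch, heads * seq_len)
    n_pruned = int(heads * seq_len * compression_ratio)
    pruned_idx = flat.topk(n_pruned, dim=-1, largest=False).indices
    mask = torch.ones_like(flat, dtype=torch.bool)
    mask[torch.arange(batch).unsqueeze(1), pruned_idx] = False
    return mask.reshape(batch, heads, seq_len)


# Head 0 has uniformly high scores, head 1 uniformly low ones
scores = torch.tensor([[[0.9, 0.8, 0.7, 0.6],
                        [0.1, 0.2, 0.3, 0.4]]])
mask = adakv_style_mask(scores, compression_ratio=0.5)
print(mask.sum(dim=-1))  # tensor([[4, 0]])
```

A per-head selection with the same ratio would keep 2 tokens in each head; the global selection instead spends the whole budget on head 0.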
For a detailed list of existing KV cache compression methods, check Awesome-KV-Cache-Compression or Awesome-LLM-Compression.
We provide a simple CLI to evaluate the performance of different presses on several long-context datasets.
- Accuracy: Test your method on popular benchmarks directly using our CLI. For a broader comparison, check out our public Hugging Face Leaderboard, where you can see how various methods stack up against each other.
- Speed and Memory: The speed_and_memory notebook can help you measure peak memory usage and total time gain.
Please refer to the evaluation directory in this repo for more details and results.
Below we report the average performance on the RULER dataset with 4k context length for different presses, from our
We support KV cache quantization through the transformers QuantizedCache class (see HF blog post). To use it, simply pass a cache object to your pipeline:
from transformers import QuantizedCacheConfig, QuantoQuantizedCache
config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)
pipe(..., cache=cache)
By default, the DynamicCache is used (no quantization).
Important
To use the QuantizedCache, you need to install additional dependencies (e.g. pip install optimum-quanto).
We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the new_press.ipynb notebook for a step-by-step guide.
Some presses depend on the model architecture (e.g. ExpectedAttentionPress or SnapKVPress), hence they might not work with all models. We tested support for LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM, and Gemma3ForCausalLM, but many other models might be supported out of the box because their implementations in transformers are often similar.
kvpress supports multi-GPU inference through accelerate:
pipe = pipeline("kv-press-text-generation", model=model, device_map="auto")
Memory usage should be reduced by around compression_ratio * kv_cache_size. As the KV cache is smaller, decoding should also be faster. You can measure peak memory usage gain and total time gain using this notebook.
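As a rough worked example of the compression_ratio * kv_cache_size estimate, the snippet below assumes the Llama-3.1-8B configuration used in the examples above (32 layers, 8 KV heads, head dimension 128, float16) and a 128k-token context. It counts only the KV cache; model weights and activations are unaffected, so the end-to-end reduction in peak memory will be smaller.

```python
# Assumed Llama-3.1-8B config: 32 layers, 8 KV heads, head dim 128, float16 (2 bytes).
layers, kv_heads, head_dim, bytes_per_value = 32, 8, 128, 2
context_len = 128 * 1024  # 131,072 tokens
compression_ratio = 0.5

# Keys and values (factor 2) for every layer, KV head and channel
bytes_per_token = 2 * layers * kv_heads * head_dim * bytes_per_value
full_cache = bytes_per_token * context_len

# A press keeps roughly (1 - compression_ratio) of the context tokens
kept_tokens = int(context_len * (1 - compression_ratio))
compressed_cache = bytes_per_token * kept_tokens
saved = full_cache - compressed_cache

print(f"full: {full_cache / 2**30:.0f} GiB, "
      f"compressed: {compressed_cache / 2**30:.0f} GiB, "
      f"saved: {saved / 2**30:.0f} GiB")
# full: 16 GiB, compressed: 8 GiB, saved: 8 GiB
```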
A press registers a forward hook (press.forward_hook method) to each attention layer during the pre-filling phase. Registration can be applied using the press as a context manager (press.__call__ method):
import torch
from transformers import AutoModelForCausalLM
from kvpress import KnormPress

device = "cuda:0"
ckpt = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(ckpt).to(device)
press = KnormPress(compression_ratio=0.4)

# model.dummy_inputs["input_ids"] is a small (3, 5) batch of token ids
inputs = model.dummy_inputs["input_ids"].to(device)

# Without a press: keys have shape (batch, num_kv_heads, seq_len, head_dim)
with torch.no_grad():
    print(model(inputs).past_key_values[0][0].shape)
# torch.Size([3, 8, 5, 128])

# With the press as a context manager: 3 of the 5 tokens are kept
with torch.no_grad(), press(model):
    print(model(inputs).past_key_values[0][0].shape)
# torch.Size([3, 8, 3, 128])
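The mechanism behind press(model) follows the standard PyTorch forward-hook pattern. The generic sketch below, which is not the kvpress implementation, registers a hook on every matching submodule when the with block is entered and removes it on exit. The toy hook simply rescales outputs, whereas a real press would compress the KV cache inside each attention layer's hook; apply_hook and halve_output are hypothetical names.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def apply_hook(model: nn.Module, hook, module_type=nn.Linear):
    """Register `hook` on every submodule of type `module_type`; remove it on exit."""
    handles = [m.register_forward_hook(hook)
               for m in model.modules() if isinstance(m, module_type)]
    try:
        yield model
    finally:
        for handle in handles:  # always unregister, even if the forward pass fails
            handle.remove()


def halve_output(module, inputs, output):
    # Returning a value from a forward hook replaces the module's output
    return output * 0.5


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 3))
x = torch.ones(1, 3)
with torch.no_grad():
    base = model(x)
    with apply_hook(model, halve_output):
        hooked = model(x)  # hook active: outputs are halved
    after = model(x)       # hook removed: original behaviour restored
```

Tying registration to a context manager guarantees the hooks are removed when the block exits, so the model behaves normally outside it.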
In fact, you can use model.generate with a press by using the press as a context manager:
with press(model):
    outputs = model.generate(inputs)
However, the generate method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (e.g. for use cases such as chat or document question answering). Finally, the generate method does not support generating answers to multiple questions at once.
PrefillDecodingPress combines separate presses for the prefilling and decoding phases.
Parameters:
- prefilling_press: Press used during the prefilling phase
- decoding_press: Press used during the decoding phase
from transformers import pipeline
from kvpress import CriticalKVPress, KnormPress
from kvpress import DecodingPress, PrefillDecodingPress
# Initialize the pipeline
device = "cuda:0"
model = "meta-llama/Llama-3.1-8B-Instruct"
model_kwargs = {"attn_implementation": "flash_attention_2"}
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
# Different strategies for prefill vs decoding
prefill_press = CriticalKVPress(KnormPress())
decoding_press = DecodingPress(
    base_press=KnormPress(compression_ratio=0.2),
    compression_interval=5,
    target_size=256
)
# Combine them
combined_press = PrefillDecodingPress(
prefilling_press=prefill_press,
decoding_press=decoding_press
)
context = "A very long context that will be compressed during prefill"
question = "Generate a detailed analysis that will be compressed during decoding"
response = pipe(context, question=question, press=combined_press)["answer"]