This repo contains the code for the paper "Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs".
- [2025.06.13] Released the refactored code and added support for the Qwen model.
- [2025.03.12] Released the inference code implemented using PyTorch (support for FlashMLA inference requires additional development time).
- [2025.03.04] The four MLA checkpoints ($d_{kv}$ = 8/16/32/128) derived from SmolLM-135M/360M/1B7 are publicly available.
- [2025.03.03] The four MLA checkpoints ($d_{kv}$ = 16/32/64/256) derived from Llama-2-7B are publicly available.
- [2025.02.21] The MHA2MLA paper is publicly available: https://arxiv.org/abs/2502.14837
- [2025.02.19] Released the first version of the MHA2MLA code, providing usage code for fine-tuning and evaluating Llama.
- Provide the code for incorporating the projection matrix and for inference. Thanks to DeepSeek for open-sourcing the FlashMLA inference framework; it is theoretically possible to save even more GPU memory with it. Let's see how economical MHA2MLA + FlashMLA (+ quantized KV cache) can be!
- Release the code of MHA2MLA based on HuggingFace Transformers.
First, download the datasets:
- smollm-corpus(fineweb-edu-dedup, cosmopedia-v2, python-edu): https://huggingface.co/datasets/HuggingFaceTB/smollm-corpus
- open-web-math: https://huggingface.co/datasets/open-web-math/open-web-math
- stackoverflow: https://huggingface.co/datasets/bigcode/stackoverflow-clean
Second, process the datasets following the Nanotron nanoset guide: https://github.com/huggingface/nanotron/blob/main/docs/nanoset.md.
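If it helps, the datasets can be fetched with the `huggingface-cli` tool from `huggingface_hub`; the snippet below is a minimal sketch, and the local target directories are illustrative rather than required by the repo:

```bash
# Download the raw datasets from the Hugging Face Hub (local paths are illustrative).
pip install -U "huggingface_hub[cli]"
huggingface-cli download HuggingFaceTB/smollm-corpus --repo-type dataset --local-dir ./data/smollm-corpus
huggingface-cli download open-web-math/open-web-math --repo-type dataset --local-dir ./data/open-web-math
huggingface-cli download bigcode/stackoverflow-clean --repo-type dataset --local-dir ./data/stackoverflow-clean
```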
Install PyTorch and the other required packages:

```bash
conda create -n mha2mla python=3.11
conda activate mha2mla
pip install torch==2.4.0 torchvision==0.19.0
pip install -r requirements.txt
```
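As an optional sanity check, you can confirm that the environment sees the expected PyTorch build and your GPUs:

```bash
# Print the installed PyTorch version, CUDA availability, and visible GPU count.
python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.cuda.device_count())"
```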
First, prepare the configuration files, using `135M_4GPU.yaml` as a reference. The MHA2MLA-specific options are documented in the `arguments.py` file.
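A simple way to start is to copy the reference config and edit it for your run; the file names below follow the ones used in this README, and the editor invocation is only an example:

```bash
# Copy the reference config and adapt it to your setup;
# the MHA2MLA-specific fields are documented in arguments.py.
cp ./cfgs/SmolLM1-135M-4GPU.yml ./cfgs/my_run.yml
${EDITOR:-vi} ./cfgs/my_run.yml
```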
Then, use the following command for fine-tuning:

```bash
torchrun --nproc_per_node 4 \
    ./src/mha2mla/run_train.py \
    --cfg_file ./cfgs/SmolLM1-135M-4GPU.yml
```
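The `--nproc_per_node` value should match the number of available GPUs. For multi-node training, torchrun's standard rendezvous flags can be added; the sketch below assumes 2 nodes with 4 GPUs each and that the config file matches the total data-parallel size:

```bash
# Example 2-node launch (run with the appropriate --node_rank on each node);
# MASTER_ADDR/MASTER_PORT must point at the rank-0 node.
torchrun --nnodes 2 --node_rank 0 \
    --nproc_per_node 4 \
    --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \
    ./src/mha2mla/run_train.py \
    --cfg_file ./cfgs/SmolLM1-135M-4GPU.yml
```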
If you want to use the partial-RoPE version `2-norm`, you need to compute the `qk_tensor` first. You can obtain the `qk_tensor` with the following command:

```bash
torchrun --nproc_per_node 1 \
    ./src/mha2mla/2_norm.py \
    --config_file ./cfgs/SmolLM1-135M-8GPU.yaml
```
For evaluation, you can use the following command:

```bash
accelerate launch --multi_gpu --num_processes=4 \
    ./eval/eval.py \
    accelerate \
    --model_args "pretrained=${model_name_or_path},revision=main,dtype=bfloat16,max_length=2048" \
    --override_batch_size 48 \
    --custom_tasks "./eval/tasks.py" \
    --tasks "./eval/smollm1_base.txt" \
    --output_dir "./eval_results/"
```
For the LongBench baseline evaluation (bfloat16 KV cache, no quantization), you can use the following command:

```bash
torchrun --nproc_per_node=4 \
    ./eval/longbench.py \
    --model_path ${model_name_or_path} \
    --tokenizer_path ${model_name_or_path} \
    --longbench True \
    --lb_max_tokens 2048 \
    --lb_batch_size 16 \
    --output_dir ./longbench/bf16 \
    --dtype "bfloat16"
```
If you want to use the quantized KV cache, you can use the following command:

```bash
torchrun --nproc_per_node=4 \
    ./eval/longbench.py \
    --model_path ${model_name_or_path} \
    --tokenizer_path ${model_name_or_path} \
    --longbench True \
    --lb_max_tokens 2048 \
    --lb_batch_size 16 \
    --output_dir ./longbench/${model_name_or_path}_hqq_int4 \
    --dtype "bfloat16" \
    --cache_implementation "quantized" \
    --backend "HQQ" \
    --nbits 4 \
    --residual_length 128
```
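The `--backend` and `--nbits` flags mirror the quantized-cache options in HuggingFace Transformers (HQQ or quanto backends). Assuming the script simply forwards these values (an assumption, not verified here), a more aggressive 2-bit HQQ configuration would look like:

```bash
# Same LongBench run with a 2-bit HQQ-quantized KV cache (illustrative settings).
torchrun --nproc_per_node=4 \
    ./eval/longbench.py \
    --model_path ${model_name_or_path} \
    --tokenizer_path ${model_name_or_path} \
    --longbench True \
    --lb_max_tokens 2048 \
    --lb_batch_size 16 \
    --output_dir ./longbench/${model_name_or_path}_hqq_int2 \
    --dtype "bfloat16" \
    --cache_implementation "quantized" \
    --backend "HQQ" \
    --nbits 2 \
    --residual_length 128
```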
```bibtex
@misc{ji2025economicalinferenceenablingdeepseeks,
      title={Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs},
      author={Tao Ji and Bin Guo and Yuanbin Wu and Qipeng Guo and Lixing Shen and Zhan Chen and Xipeng Qiu and Qi Zhang and Tao Gui},
      year={2025},
      eprint={2502.14837},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.14837},
}
```