Fine-tune existing MHA LLMs to DeepSeek MLA for up to ~97% KV-cache savings with 0.6–1% data

February 20, 2025 · 7 min

Overview

Production Readiness

0.7

Novelty Score

0.6

Cost Impact Score

0.8

Citation Count

0

Authors

Tao Ji, Bin Guo, Yuanbin Wu, Qipeng Guo, Lixing Shen, Zhan Chen, Xipeng Qiu, Qi Zhang, Tao Gui

Why It Matters For Business

MHA2MLA lets teams cut KV cache memory by more than 90% (92.19% on Llama2-7B at d_kv=64 with Int4 quantization) while keeping near-original quality. That lowers GPU memory needs and cost for long-context inference, and it requires only a small fine-tuning budget.

Summary TLDR

The paper introduces MHA2MLA, a fine-tuning recipe that converts pre-trained Multi-Head Attention (MHA) LLMs into DeepSeek’s Multi-Head Latent Attention (MLA). It relies on two key moves: partial-RoPE, which removes rotary position embeddings from the dimensions that contribute least to attention scores, and a joint SVD factorization of the key and value projections, which builds a compact latent KV cache. Across models from 135M to 13B parameters, MHA2MLA recovers performance using only 0.6%–1% of the pretraining tokens, reduces KV cache size by up to 96.87% (when combined with Int4 quantization), and keeps quality losses small (for example, a LongBench drop of about 1 point, 27.4 → 26.4, for the 7B model at d_kv=64 with Int4). That makes it practical to retrofit many deployed models for cheaper long-context inference.

Problem Statement

KV cache memory during autoregressive inference grows linearly with sequence length and becomes the bottleneck for long-context use. DeepSeek's MLA compresses the KV cache into a latent space, but it is architecturally different from standard MHA, so converting well-trained MHA models to MLA without retraining from scratch is hard and costly. The paper asks: can pretrained MHA/GQA models be adapted to MLA cheaply, using only a small amount of data?

Main Contribution

MHA2MLA: a data-efficient full-parameter fine-tuning pipeline to migrate pretrained MHA/GQA models to MLA.

Partial-RoPE: remove RoPE from selected dimensions using a contribution-aware (S_2-norm) selection.

SVDjoint: joint SVD factorization of keys and values to build a shared low-rank latent KV representation.

Demonstration across five scales (135M–13B) that fine-tuning needs only 0.6%–1% of pretraining tokens.

Compatibility with KV-cache quantization (e.g., Int4), giving compound memory savings of up to 96.87%.
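The joint factorization named above can be illustrated with a minimal NumPy sketch. It assumes the RoPE-free key projection and the value projection are available as dense matrices. The function name `svd_joint`, the argument names, and the even split of singular values between the two factors are illustrative choices here, not the paper's reference code.

```python
import numpy as np

def svd_joint(w_k_nope, w_v, d_kv):
    """Factor [W_k_nope, W_v] through one shared rank-d_kv latent (a sketch).

    w_k_nope: (d_model, d_k) key projection for the RoPE-free dimensions
    w_v:      (d_model, d_v) value projection
    Returns (w_down, w_up_k, w_up_v) such that
      w_down @ w_up_k approximates w_k_nope and
      w_down @ w_up_v approximates w_v.
    """
    d_k = w_k_nope.shape[1]
    joint = np.concatenate([w_k_nope, w_v], axis=1)   # (d_model, d_k + d_v)
    u, s, vt = np.linalg.svd(joint, full_matrices=False)
    root = np.sqrt(s[:d_kv])                          # split singular values evenly
    w_down = u[:, :d_kv] * root                       # (d_model, d_kv)
    w_up = root[:, None] * vt[:d_kv]                  # (d_kv, d_k + d_v)
    return w_down, w_up[:, :d_k], w_up[:, d_k:]
```

With these factors, `x @ w_down` is the only per-token vector that needs caching for the RoPE-free part; keys and values are recovered from it with the two up-projections. Truncating to `d_kv` is lossy in general, which is why fine-tuning follows the conversion.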

Key Findings

MHA2MLA adapts pretrained MHA/GQA models using a tiny fraction of data.

Numbers: 0.6%–1% of pretraining tokens used for fine-tuning

KV cache size can be cut dramatically while keeping quality near baseline.

Numbers: up to -96.87% KV cache (d_kv=16 + Int4) with a LongBench drop of about 2.4%

Contribution-aware RoPE selection + joint SVD is the strongest configuration.

Numbers: SVDjoint > SVDsplit by +0.91% (135M) and +1.38% (1B7) on average

Larger models tolerate stronger compression with less quality loss.

Numbers: with the KV cache compressed to 18.75% of its original size, the drops are 135M -2.24%, 7B -0.30%, 13B -0.23%

Some RoPE removal strategies cause convergence failures or big drops.

Numbers: S_low strategy drops -5.25% (135M) and -1.87% (1B7); higher convergence risk
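To make the idea of contribution-aware RoPE selection concrete, here is a hedged sketch that ranks rotary frequency pairs by the product of query and key 2-norms and keeps the top `r`. It assumes pair `i` occupies dimensions `(2i, 2i+1)` and that scores are averaged over sample vectors from one head. The function `select_rope_pairs` and both assumptions are illustrative and may differ from how the paper computes its S_2-norm score.

```python
import numpy as np

def select_rope_pairs(q, k, r):
    """Pick the r rotary frequency pairs with the largest norm-product score.

    q, k: (n_samples, head_dim) query/key vectors for ONE head, before RoPE.
    Assumption: rotary pair i is made of dimensions (2i, 2i+1).
    Returns the sorted indices of the pairs that keep RoPE; every other
    pair would have its rotation removed (NoPE).
    """
    n, d = q.shape
    q_pairs = q.reshape(n, d // 2, 2)
    k_pairs = k.reshape(n, d // 2, 2)
    # ||q_i|| * ||k_i|| upper-bounds the pair's dot-product contribution.
    score = np.linalg.norm(q_pairs, axis=-1) * np.linalg.norm(k_pairs, axis=-1)
    mean_score = score.mean(axis=0)                   # (head_dim // 2,)
    keep = np.argsort(mean_score)[-r:]
    return np.sort(keep)
```

A frequency-only rule such as S_low picks the same pairs for every head, whereas a score like this one can select different pairs per head.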

Results

Fine-tuning data fraction

Value: 0.6%–1% of pretraining tokens

Baseline: pretraining tokens

KV cache reduction (Llama2-7B)

Value: -92.19% (d_kv=64 + Int4 HQQ)

Baseline: BF16

Max KV cache reduction (combined + quant)

Value: -96.87% (d_kv=16 + Int4 Quanto)

Baseline: BF16

LongBench quality (7B)

Value: MHA BF16: 27.4 → MHA2MLA d_kv=64 + Int4: 26.4

Baseline: BF16 MHA 27.4

Commonsense avg change (Llama2-7B, d_kv=32)

Value: 59.50 → 59.20

Baseline: original MHA
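The reported reductions can be sanity-checked with simple per-head arithmetic. The sketch below assumes a head dimension of 128 (as in Llama2-7B), 16 retained RoPE key dimensions per head, and a pure bit-width ratio that ignores quantization scales and zero-points. The value `rope_dim=16` and the per-head accounting are assumptions chosen because they reproduce the figures above, not values quoted from the paper.

```python
def kv_cache_reduction(d_kv, bits=16, head_dim=128, rope_dim=16, base_bits=16):
    """Percent KV-cache reduction per head versus an MHA cache at base_bits.

    Assumptions (hypothetical, for illustration only):
      - MHA caches head_dim key dims plus head_dim value dims per head.
      - The converted model caches d_kv latent dims plus rope_dim RoPE key dims.
      - Quantization overhead (scales, zero-points) is ignored.
    """
    mha_bits = 2 * head_dim * base_bits        # keys + values at baseline precision
    mla_bits = (d_kv + rope_dim) * bits        # latent + retained RoPE key dims
    return 100.0 * (1.0 - mla_bits / mha_bits)
```

Under these assumptions, `kv_cache_reduction(64, bits=4)` gives 92.1875 and `kv_cache_reduction(16, bits=4)` gives 96.875, in line with the reported -92.19% and -96.87%. In BF16, `kv_cache_reduction(32)` gives 81.25, which leaves 18.75% of the original cache.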

What To Try In 7 Days

Run S_2-norm partial-RoPE + SVDjoint on a 7B checkpoint and measure KV cache size vs baseline.

Try d_kv=64 (moderate) and compare LongBench or your long-context task before quantization.

If stable, combine with Int4 KV quantization and profile memory/latency and quality trade-offs.
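Before wiring in a real quantization backend, the Int4 step can be prototyped offline on dumped latent KV tensors to get a first read on the error it introduces. The following is a generic group-wise asymmetric 4-bit scheme in NumPy. It is not the HQQ or Quanto algorithm, and the function names and `group_size` default are illustrative assumptions.

```python
import numpy as np

def quantize_int4(x, group_size=32):
    """Asymmetric 4-bit quantization with one scale/offset per group.

    x: float array whose total size is divisible by group_size.
    Returns (codes, scale, lo); codes are integers in [0, 15] stored as uint8
    (a real kernel would pack two codes per byte).
    """
    groups = x.astype(np.float32).reshape(-1, group_size)
    lo = groups.min(axis=1, keepdims=True)
    hi = groups.max(axis=1, keepdims=True)
    scale = np.maximum(hi - lo, 1e-8) / 15.0          # 16 levels per group
    codes = np.clip(np.round((groups - lo) / scale), 0, 15).astype(np.uint8)
    return codes, scale, lo

def dequantize_int4(codes, scale, lo, shape):
    """Map 4-bit codes back to float32 and restore the original shape."""
    return (codes.astype(np.float32) * scale + lo).reshape(shape)
```

Comparing `dequantize_int4(*quantize_int4(kv), kv.shape)` against the original tensor gives a rough per-layer error estimate. Task-level quality still has to be measured end to end.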

Optimization Features

Token Efficiency

  • Reduces memory per token during long-context decoding

Infra Optimization

  • Lower GPU memory and potential for fewer GPUs per request

Model Optimization

  • Low-rank KV projection (SVDjoint)

System Optimization

  • Matrix merging for NoPE part to reduce inference ops
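One way to read "matrix merging" for the NoPE part is the absorption trick known from MLA: because no rotation sits between the query projection and the key up-projection, the two can be multiplied together offline. The sketch below checks that identity numerically. The function name and shapes are assumptions for illustration, and the paper's actual merging may cover further matrices.

```python
import numpy as np

def merge_query_with_key_up(w_q_nope, w_up_k):
    """Fold the key up-projection into the query projection (NoPE dims only).

    w_q_nope: (d_model, d_k) query projection for the RoPE-free dimensions
    w_up_k:   (d_kv, d_k)    up-projection from the cached latent to keys
    Returns a (d_model, d_kv) matrix so that attention logits can be
    computed directly against the cached latent vectors.
    """
    return w_q_nope @ w_up_k.T

# Small demonstration with random weights and inputs.
rng = np.random.default_rng(0)
w_q = rng.normal(size=(16, 8))      # hypothetical sizes
w_up_k = rng.normal(size=(4, 8))
x = rng.normal(size=(3, 16))        # 3 query tokens
c = rng.normal(size=(5, 4))         # 5 cached latent vectors

logits_unmerged = (x @ w_q) @ (c @ w_up_k).T          # rebuild keys, then score
logits_merged = (x @ merge_query_with_key_up(w_q, w_up_k)) @ c.T
```

Both paths produce the same logits, but the merged one never materializes full-width keys at decode time. The same trick does not apply to the dimensions that keep RoPE, since the position-dependent rotation sits between the two matrices.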

Training Optimization

  • Data-efficient full-parameter fine-tuning (0.6%–1% data)

Inference Optimization

  • KV cache compression into latent vectors (MLA)
  • Compatible with Int4 KV quantization

Reproducibility

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Validated on models up to 13B; not tested systematically on very large models (e.g., Llama3) due to compute limits.
  • Depends on access to model weights and ability to run full-parameter fine-tuning.
  • DeepSeek’s MLA tensor-parallel inference framework is not open-sourced, limiting deployment testing for >7B models.

When Not To Use

  • If you cannot fine-tune model weights (no access or policy restrictions).
  • When absolute no-drop quality is required for small models under extreme compression.
  • If you lack the infra to run SVD-based matrix operations during conversion.

Failure Modes

  • Aggressive RoPE removal using low-frequency retention (S_low) can cause convergence failure and large quality drops.
  • Very aggressive quantization (2-bit) can collapse generation quality despite KV reduction.
  • Approximation errors from SVDsplit (vs SVDjoint) reduce accuracy more on some sizes.

Core Entities

Models

  • Llama2-7B
  • Llama2-13B
  • SmolLM-135M
  • SmolLM-360M
  • SmolLM-1B7

Metrics

  • KV cache memory reduction (%)
  • Accuracy
  • LongBench average score
  • Fine-tuning tokens as fraction of pretraining tokens

Datasets

  • LongBench
  • MMLU
  • ARC
  • PIQA
  • HellaSwag
  • OpenBookQA
  • Winogrande
  • SmolLM pretraining corpus (fineweb-edu-dedup, cosmopedia-v2, python-edu, open-web-math, StackOverflo

Benchmarks

  • LongBench
  • Commonsense reasoning suite (MMLU, ARC, PIQA, HellaSwag, OpenBookQA, Winogrande)