Overview
Production Readiness: 0.8
Novelty Score: 0.65
Cost Impact Score: 0.8
Citation Count: 6
Why It Matters For Business
Keyformer cuts memory traffic and latency for long-context generation without retraining, lowering inference cost and enabling higher throughput on existing GPU servers.
Summary TLDR
Keyformer is an inference-only technique that trims the KV cache by keeping a recent window of tokens plus a small set of scored "key" tokens. It selects key tokens using Gumbel-based logit regularization and per-layer score accumulation. On the GPT-J, Cerebras-GPT, and MPT model families, Keyformer reduces KV cache size by up to 50%, cuts KV data movement by ~2.9×, achieves a 2.1× latency speedup, and raises token throughput by up to 2.4×, while matching or slightly exceeding baseline ROUGE scores on summarization and holding accuracy on few-shot tasks. The method requires no retraining, and the code is released.
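A minimal sketch of the selection step, assuming per-token scores have already been accumulated for one layer. `select_kv_indices` is a hypothetical helper written for illustration, not a function from the Keyformer repository, and it works on plain Python lists rather than GPU tensors.

```python
from typing import List, Sequence


def select_kv_indices(scores: Sequence[float], budget: int, recent: int) -> List[int]:
    """Return the token positions to keep in a KV cache of size `budget`.

    Hypothetical helper: the last `recent` positions are always kept (recent
    window); the remaining `budget - recent` slots go to the highest-scoring
    earlier tokens ("key tokens").
    """
    n = len(scores)
    if budget >= n:
        return list(range(n))  # cache already fits, nothing to evict
    if not 0 <= recent <= budget:
        raise ValueError("recent window must fit inside the budget")
    window_start = n - recent
    num_key = budget - recent
    # Rank tokens outside the recent window by score; ties go to the earlier position.
    ranked = sorted(range(window_start), key=lambda i: (-scores[i], i))
    key_tokens = sorted(ranked[:num_key])
    # Keep original positions in order so positional information is preserved.
    return key_tokens + list(range(window_start, n))


scores = [0.9, 0.1, 0.5, 0.05, 0.7, 0.2, 0.3, 0.4]  # toy accumulated scores
print(select_kv_indices(scores, budget=5, recent=2))  # [0, 2, 4, 6, 7]
```

Returning indices in their original order mirrors the design choice of keeping tokens at their original positions rather than compacting them.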
Problem Statement
Autoregressive generation stores the keys and values of past tokens in a KV cache. For long contexts, moving this cache dominates GPU memory bandwidth and latency. Existing system-level optimizations accelerate compute but do not address the growing KV cache, and many KV-reduction methods require retraining. What is needed is an inference-time way to shrink the KV cache without hurting accuracy.
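To make the problem concrete, the arithmetic below sizes the KV cache for GPT-J-6B, assuming its public configuration (28 layers, hidden size 4096), standard multi-head attention, fp16 storage, and a 2048-token context. `kv_cache_bytes` is a hypothetical helper; real deployments add allocator overhead on top of this figure.

```python
def kv_cache_bytes(num_layers: int, hidden_size: int, seq_len: int,
                   batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes needed to store keys and values for every layer and token.

    Factor 2 = one key vector plus one value vector per token per layer;
    hidden_size = num_heads * head_dim for standard multi-head attention.
    """
    return 2 * num_layers * hidden_size * seq_len * batch_size * bytes_per_elem


# Assumed GPT-J-6B config: 28 layers, hidden size 4096; fp16 = 2 bytes per element.
per_token = kv_cache_bytes(28, 4096, seq_len=1)
full = kv_cache_bytes(28, 4096, seq_len=2048)
half = full // 2  # a 50% KV budget
print(per_token)       # 458752 bytes (448 KiB) per token
print(full / 2**20)    # 896.0 MiB per sequence
print(half / 2**20)    # 448.0 MiB with a 50% budget
```

Each decoding step reads the whole cache, so halving the cache roughly halves the KV bytes moved per generated token.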
Main Contribution
- A practical inference-time algorithm (Keyformer) that keeps a recent window plus scored key tokens to maintain a fixed KV cache budget without fine-tuning.
- A new score function that adds Gumbel logit regularization and a temperature schedule to correct the distribution shift caused by token removal.
- Per-layer accumulation of token scores, plus design choices that preserve accuracy: keeping tokens at their original positions, scoring per layer, and sizing the recent window at 20–30% of the KV cache budget.
- Empirical results on GPT-J, Cerebras-GPT, and MPT showing up to 50% KV cache reduction, a 2.1× latency speedup, and up to 2.4× higher token throughput.
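The score function can be sketched as follows, assuming one layer's attention logits for the current query are available as a list. The linear shape of the τ schedule (rising from 1 to 2 across generation, as in What To Try In 7 Days), the function names, and the exact point where noise is injected are illustrative assumptions; the released Keyformer code may differ.

```python
import math
import random
from typing import List, Sequence


def tau_at(step: int, total_steps: int, tau_init: float = 1.0, tau_end: float = 2.0) -> float:
    """Assumed linear temperature schedule from tau_init to tau_end over generation."""
    if total_steps <= 1:
        return tau_end
    frac = min(max(step / (total_steps - 1), 0.0), 1.0)
    return tau_init + frac * (tau_end - tau_init)


def gumbel_scores(logits: Sequence[float], tau: float, rng: random.Random) -> List[float]:
    """Softmax over (logit + Gumbel(0, 1) noise) / tau for each token."""
    noisy = []
    for x in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        g = -math.log(-math.log(u))  # inverse-CDF sample from Gumbel(0, 1)
        noisy.append((x + g) / tau)
    m = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def accumulate(acc: Sequence[float], step_scores: Sequence[float]) -> List[float]:
    """Add one step's scores to a layer's running totals (per-layer accumulation)."""
    return [a + s for a, s in zip(acc, step_scores)]


rng = random.Random(0)
logits = [2.0, 0.5, 1.0, -1.0]  # toy attention logits for one query
acc = [0.0] * len(logits)
steps = 3
for t in range(steps):
    acc = accumulate(acc, gumbel_scores(logits, tau_at(t, steps), rng))
```

Because each step's scores form a probability distribution, the accumulated totals sum to the number of steps; tokens that repeatedly attract attention accumulate the largest totals and become key tokens.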
Key Findings
- Attention concentrates on a small subset of tokens ("key tokens").
- Reducing the KV cache yields large runtime gains (a 2.1× latency speedup and up to 2.4× higher throughput).
- Keyformer preserves or slightly improves accuracy relative to full attention on the evaluated tasks.
- Gumbel logit regularization selects key tokens better than Gaussian, constant, or no logit adjustment.
- Reduced KV cache data movement is the main source of the speedup.
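One way to check the "key tokens" finding on your own model is to count how many tokens carry most of a query's attention mass. `tokens_for_mass` and the 0.85 threshold are hypothetical choices for illustration, not the paper's measurement protocol.

```python
from typing import Sequence


def tokens_for_mass(attn: Sequence[float], mass: float = 0.9) -> int:
    """Smallest number of tokens whose attention weights sum to at least `mass`.

    `attn` is one query's attention distribution (non-negative, sums to ~1).
    A result much smaller than len(attn) means attention is concentrated.
    """
    total = 0.0
    for count, w in enumerate(sorted(attn, reverse=True), start=1):
        total += w
        if total >= mass:
            return count
    return len(attn)


attn = [0.5, 0.02, 0.3, 0.03, 0.1, 0.05]  # toy attention row over 6 tokens
print(tokens_for_mass(attn, 0.85))  # 3
```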
Results
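- Inference latency speedup: 2.1×
- Token generation throughput: up to 2.4×
- KV cache data movement reduction: ~2.9×
- Scaled dot-product compute improvement
- Accuracy: matches or slightly exceeds full-attention ROUGE on summarization; holds accuracy on few-shot tasks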
Who Should Care
What To Try In 7 Days
- Clone the Keyformer repo and run the provided example on an MPT or GPT-J checkpoint.
- Measure tokens/sec and ROUGE on your own summarization prompts with KV budgets of 50%, 70%, and 100% (full cache).
- Set the recent window w to 20–30% of the KV budget, use a τ schedule that rises from 1 to 2, and compare per-layer against shared scoring.
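A small harness for the budget sweep above might look like this. `generate` is a hypothetical hook you would connect to your Keyformer checkout, and `rouge1_f1` is a simplified unigram-overlap score (whitespace tokens, no stemming), so use an established ROUGE implementation for numbers you intend to compare against the paper.

```python
import time
from collections import Counter
from typing import Callable, Dict, List, Tuple


def rouge1_f1(candidate: str, reference: str) -> float:
    """Simplified ROUGE-1 F1: clipped unigram overlap on lowercased whitespace tokens."""
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((cand & ref).values())  # & keeps the min count per token
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


def sweep(generate: Callable[[str, float], List[str]], prompt: str, reference: str,
          budgets: Tuple[float, ...] = (0.5, 0.7, 1.0)) -> Dict[float, Dict[str, float]]:
    """Time generation and score the output at each KV budget (fraction of the full cache).

    `generate(prompt, kv_budget)` is a hypothetical hook returning generated tokens.
    """
    results = {}
    for budget in budgets:
        start = time.perf_counter()
        tokens = generate(prompt, budget)
        elapsed = time.perf_counter() - start
        results[budget] = {
            "tokens_per_sec": len(tokens) / elapsed if elapsed > 0 else float("inf"),
            "rouge1_f1": rouge1_f1(" ".join(tokens), reference),
        }
    return results


def fake_generate(prompt: str, kv_budget: float) -> List[str]:
    return "the cat".split()  # stand-in for a real Keyformer run


print(sweep(fake_generate, "Summarize: ...", "the cat sat on the mat"))
```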
Optimization Features
Token Efficiency
- Increases tokens/sec by up to 2.4×
- Reduces per-token KV transfer
Infra Optimization
- Works on existing GPU stacks (A100) without model retraining
System Optimization
- Reduces off-chip KV data movement by ~2.9×
- Enables larger batch sizes within the same GPU memory
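The batch-size benefit follows from simple capacity arithmetic. This sketch assumes 40 GiB of GPU memory left free for the KV cache (a hypothetical figure) and GPT-J-6B in fp16, whose KV cache costs 458,752 bytes per token (2 for K and V × 28 layers × 4096 hidden × 2 bytes); `max_batch_size` is a hypothetical helper.

```python
def max_batch_size(free_bytes: int, kv_bytes_per_token: int, context_len: int,
                   kv_fraction: float = 1.0) -> int:
    """Sequences that fit in `free_bytes` when each keeps `kv_fraction` of its KV cache.

    Ignores weights, activations, and allocator fragmentation: treat as an upper bound.
    """
    kept_tokens = max(1, int(context_len * kv_fraction))
    per_sequence = kv_bytes_per_token * kept_tokens
    return free_bytes // per_sequence


free = 40 * 2**30  # assumed 40 GiB available for the KV cache
print(max_batch_size(free, 458752, 2048))                   # 45 (full cache)
print(max_batch_size(free, 458752, 2048, kv_fraction=0.5))  # 91 (50% budget)
```

Halving the per-sequence cache roughly doubles the number of concurrent sequences, which is where the batch-size headroom comes from.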
Inference Optimization
- KV cache reduction via key-token selection
- Mixed recent-window + key-token attention
- Gumbel-softmax logit regularization for scoring
- Per-layer score accumulation
- Static KV cache budget to control memory
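During generation, a static budget means each new token is admitted and, once the cache is over budget, one older token is evicted. The toy class below sketches that policy under stated assumptions: it tracks only positions and accumulated scores (no key/value tensors), and it evicts the lowest-scoring token outside the recent window. `StaticBudgetCache` is a hypothetical name, not a class from the Keyformer repository.

```python
from typing import List


class StaticBudgetCache:
    """Hypothetical toy KV cache with a fixed budget: recent window + key tokens.

    Stores token positions and accumulated scores only. When an append pushes the
    cache over `budget`, the lowest-scoring token outside the recent window is evicted.
    """

    def __init__(self, budget: int, recent: int):
        if not 0 < recent < budget:
            raise ValueError("need 0 < recent < budget")
        self.budget = budget
        self.recent = recent
        self.positions: List[int] = []  # original token positions, in order
        self.scores: List[float] = []   # accumulated score per cached token

    def update_scores(self, step_scores: List[float]) -> None:
        """Add one decoding step's per-token scores (aligned with self.positions)."""
        if len(step_scores) != len(self.scores):
            raise ValueError("step_scores must match the cached tokens")
        self.scores = [a + s for a, s in zip(self.scores, step_scores)]

    def append(self, position: int, score: float = 0.0) -> None:
        """Admit a new token, then evict one older token if over budget."""
        self.positions.append(position)
        self.scores.append(score)
        if len(self.positions) > self.budget:
            evictable = len(self.positions) - self.recent  # indices before the recent window
            victim = min(range(evictable), key=lambda i: self.scores[i])
            del self.positions[victim]
            del self.scores[victim]


cache = StaticBudgetCache(budget=4, recent=2)
for pos, score in enumerate([0.9, 0.1, 0.5, 0.3, 0.8]):
    cache.append(pos, score)
print(cache.positions)  # [0, 2, 3, 4]: token 1 had the lowest score outside the window
```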
Reproducibility
Data Urls
Code Available
Data Available
Open Source Status
- partial
Risks & Boundaries
Limitations
- Quality depends on model, task, and positional encoding; tune per model.
- Gumbel-softmax scoring adds overhead that must be balanced against the KV savings.
- Does not eliminate all off-chip memory needs; extreme long contexts still need additional system-level solutions.
When Not To Use
- Short-context generation where KV cache is small and savings are minimal.
- Workflows that cannot modify the inference pipeline or scoring step.
- Use cases that rely on exact, unmodified attention for correctness and cannot validate Keyformer's outputs first.
Failure Modes
- Misidentifying key tokens reduces accuracy if the score function or τ schedule is poorly tuned.
- Performance gains shrink on compute-bound workloads where KV movement is not the bottleneck.
- Differences in positional encoding can change token-importance patterns and cause accuracy drops.
Core Entities
Models
- GPT-J-6B
- Cerebras-GPT-6.7B
- MPT-7B
- MPT-7B-storywriter
Metrics
- ROUGE-1
- ROUGE-2
- ROUGE-L
- tokens/sec
- inference latency
- KV cache data movement
Datasets
- CNN/DailyMail
- GovReport
- SODA
- lm-eval-harness (PIQA, Winogrande, OpenBookQA, COPA)
Benchmarks
- Accuracy

