Keyformer halves KV cache by keeping only 'key' tokens, doubling token throughput with no fine-tuning

March 14, 2024 · 8 min

Overview

Production Readiness

0.8

Novelty Score

0.65

Cost Impact Score

0.8

Citation Count

6

Authors

Muhammad Adnan, Akhil Arunkumar, Gaurav Jain, Prashant J. Nair, Ilya Soloveychik, Purushotham Kamath

Why It Matters For Business

Keyformer cuts memory traffic and latency for long-context generation without retraining, lowering inference cost and enabling higher throughput on existing GPU servers.

Summary TLDR

Keyformer is an inference-only technique that shrinks the KV cache by keeping a window of recent tokens plus a small set of scored "key" tokens. It picks key tokens with a score function that uses Gumbel-based logit regularization, with scores accumulated per layer. On the GPT-J, Cerebras-GPT, and MPT families, Keyformer reduces the KV cache by up to 50%, cuts KV data movement by ~2.9×, lowers latency by 2.1×, and raises token throughput by up to 2.4×, while matching or slightly exceeding baseline ROUGE scores on summarization and holding accuracy on few-shot tasks. The method requires no retraining, and the code is released.
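The selection rule described above (recent window plus top-scoring older tokens) can be sketched in a few lines. This is a minimal illustration in plain Python, not the released implementation: the function names `select_kv_indices` and `accumulate` are hypothetical, and it assumes one running score list is kept per layer.

```python
def accumulate(accumulated_scores, step_scores):
    """Add this decoding step's scores to the running per-token totals (one layer)."""
    return [a + s for a, s in zip(accumulated_scores, step_scores)]


def select_kv_indices(accumulated_scores, budget, recent_window):
    """Keep the last `recent_window` tokens plus the top-scoring older tokens.

    Returns the kept positions in their original order, so the cache
    entries stay aligned with the tokens' original positions.
    """
    n = len(accumulated_scores)
    if n <= budget:
        return list(range(n))  # nothing to evict yet
    recent_window = min(recent_window, budget)
    recent = list(range(n - recent_window, n))
    older = range(n - recent_window)
    n_key = budget - recent_window
    # Rank older tokens by accumulated score and keep the best n_key of them.
    key = sorted(older, key=lambda i: accumulated_scores[i], reverse=True)[:n_key]
    return sorted(key) + recent


# Toy example: 6 cached tokens, budget of 4, recent window of 2.
scores = [0.1, 0.5, 0.05, 0.3, 0.02, 0.03]
print(select_kv_indices(scores, budget=4, recent_window=2))  # [1, 3, 4, 5]
```

In a real model the same selection would run independently for each layer, and the kept indices would be used to gather rows from that layer's key and value tensors.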

Problem Statement

Autoregressive generation stores the keys and values of past tokens in a KV cache. For long contexts, reading this cache dominates GPU memory bandwidth use and latency. Existing system-level optimizations speed up compute but do not address the growing KV cache, and many KV-reduction methods require retraining. What is needed is an inference-time way to shrink the KV cache without hurting accuracy.
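To make the growth concrete, here is a rough size estimate. It assumes standard multi-head attention (one key and one value vector of `hidden_size` elements per token per layer) and 2-byte elements. The 32-layer, 4096-wide configuration and the 8192-token context are illustrative assumptions for a 7B-class model, not figures from the source.

```python
def kv_cache_bytes(n_layers, hidden_size, seq_len, batch_size=1, bytes_per_elem=2):
    """Approximate KV cache size: keys + values for every token in every layer."""
    return 2 * n_layers * hidden_size * seq_len * batch_size * bytes_per_elem


# Illustrative 7B-class configuration (assumed, not taken from the source).
full = kv_cache_bytes(n_layers=32, hidden_size=4096, seq_len=8192)
half = kv_cache_bytes(n_layers=32, hidden_size=4096, seq_len=8192 // 2)
print(f"full cache: {full / 2**30:.1f} GiB, 50% budget: {half / 2**30:.1f} GiB")
# full cache: 4.0 GiB, 50% budget: 2.0 GiB
```

The size is linear in both sequence length and batch size, which is why a fixed token budget translates directly into memory and bandwidth savings.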

Main Contribution

A practical inference-time algorithm (Keyformer) that keeps a recent window plus scored key tokens to maintain a fixed KV cache budget without fine-tuning.

A new score function that adds Gumbel logit regularization and a temperature schedule to correct distribution shifts after token removal.

Per-layer accumulation of token scores and design choices (original positions, per-layer scoring, recent-window 20–30%) that preserve accuracy.

Empirical results on GPT-J, Cerebras-GPT, and MPT showing up to 50% KV reduction with 2.1× latency speedup and 2.4× token throughput improvement.
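The score function in the second contribution can be sketched as a softmax over noise-perturbed logits divided by a temperature. The sketch below is an assumption-laden illustration: it uses standard Gumbel(0, 1) noise and a linear temperature ramp, and the function names are hypothetical. The source states only that the schedule runs from 1 to 2, so the linear shape is a guess.

```python
import math
import random


def gumbel_noise(rng):
    """Sample standard Gumbel(0, 1) noise by inverse transform."""
    u = rng.random()
    u = min(max(u, 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))


def temperature(step, total_steps, tau_init=1.0, tau_end=2.0):
    """Linear ramp from tau_init to tau_end across decoding steps (assumed shape)."""
    if total_steps <= 1:
        return tau_init
    return tau_init + (tau_end - tau_init) * step / (total_steps - 1)


def gumbel_regularized_scores(logits, tau, rng):
    """Softmax over (logit + Gumbel noise) / tau: one score per cached token."""
    noisy = [(x + gumbel_noise(rng)) / tau for x in logits]
    m = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
tau = temperature(step=5, total_steps=10)
scores = gumbel_regularized_scores([1.0, 2.0, 0.5, -1.0], tau, rng)
print(round(sum(scores), 6))  # 1.0
```

A higher temperature flattens the score distribution, which is one plausible reading of why the schedule increases as more tokens are discarded.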

Key Findings

Attention concentrates on a small subset of tokens ("key tokens").

Numbers: ≈90% of attention mass on ~40% of tokens (Fig. 3b)
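One way to check this concentration on your own model is to take a row of attention weights and count how many of the largest entries are needed to reach a target share of the total. The helper below is a hypothetical sketch, and the weights in the example are made-up toy values, not measurements.

```python
def fraction_of_tokens_for_mass(attn_weights, target_mass=0.9):
    """Smallest fraction of tokens whose weights sum to target_mass of the total."""
    total = sum(attn_weights)
    covered = 0.0
    for count, w in enumerate(sorted(attn_weights, reverse=True), start=1):
        covered += w
        if covered >= target_mass * total:
            return count / len(attn_weights)
    return 1.0


# Toy weights (unnormalized): three of five tokens carry at least 90% of the mass.
print(fraction_of_tokens_for_mass([60, 25, 10, 3, 2]))  # 0.6
```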

KV cache reduction yields large runtime wins.

Numbers: latency 2.1× faster, throughput up to 2.4× (50% KV reduction)

Keyformer preserves or slightly improves accuracy versus full attention on evaluated tasks.

Numbers: matches baseline ROUGE at 70% KV; maintains 99% of baseline ROUGE on long-context tasks with 50% KV; up to +1.73% ROUGE-2 on some models

Gumbel logit regularization is better than Gaussian/constant/no adjustment for key selection.

Numbers: ROUGE-2 (GPT-J-6B): Gumbel 19.44 vs Gaussian 14.53 vs none 18.87 (Table 4)

KV cache data movement is the main source of speedup.

Numbers: KV data movement reduced ~2.9×; scaled dot product improved ~1.3× (MPT-7B, 50% KV)

Results

Inference latency speedup

Value: 2.1×

Baseline: Full Attention (no KV reduction)

Token generation throughput

Value: up to 2.4×

Baseline: Full Attention

KV cache data movement reduction

Value: ≈2.9×

Baseline: Full Attention KV movement

Scaled dot-product compute improvement

Value: ≈1.3×

Baseline: Full Attention

Accuracy

Value: Meets MLPerf 99% target with 70% KV; maintains ~99% at 50% for long-context MPT

Baseline: Full Attention ROUGE

Who Should Care

What To Try In 7 Days

Clone the Keyformer repo and run the provided example on an MPT/GPT-J checkpoint.

Measure tokens/sec and ROUGE on your summarization prompt with KV budgets at 50%, 70%, and full.

Set the recent window w to 20–30%, use a τ schedule from 1 to 2, and compare per-layer against shared scoring.
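A small timing harness makes the budget comparison above repeatable. The sketch below is hypothetical scaffolding: `generate_fn` stands in for whatever entry point your inference stack exposes, and its `kv_fraction` argument is an assumed parameter name, not the repo's actual interface. The stub only simulates work so the harness runs on its own.

```python
import time


def tokens_per_second(generate_fn, prompt, max_new_tokens, kv_fraction):
    """Time one generation call and return generated tokens per second."""
    start = time.perf_counter()
    n_generated = generate_fn(prompt, max_new_tokens=max_new_tokens, kv_fraction=kv_fraction)
    elapsed = time.perf_counter() - start
    return n_generated / elapsed if elapsed > 0 else float("inf")


def sweep(generate_fn, prompt, max_new_tokens=128, budgets=(1.0, 0.7, 0.5)):
    """Return {budget: (tokens/sec, speedup vs the largest budget)}."""
    results = {b: tokens_per_second(generate_fn, prompt, max_new_tokens, b) for b in budgets}
    baseline = results[max(budgets)]
    return {b: (tps, tps / baseline) for b, tps in results.items()}


def stub_generate(prompt, max_new_tokens, kv_fraction):
    """Placeholder model: sleeps in proportion to the KV budget, then reports the token count."""
    time.sleep(0.002 * kv_fraction)
    return max_new_tokens


for budget, (tps, speedup) in sweep(stub_generate, "Summarize: ...").items():
    print(f"KV budget {budget:.0%}: {tps:,.0f} tok/s ({speedup:.2f}x vs full)")
```

Replace `stub_generate` with a wrapper around your real generation call, run several warm-up iterations first, and record ROUGE on the same outputs so speed and quality are compared at each budget.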

Optimization Features

Token Efficiency

  • Increases tokens/sec up to 2.4×
  • Reduces per-token KV transfer

Infra Optimization

  • Works on existing GPU stacks (A100) without model retraining

System Optimization

  • Reduces off-chip KV data movement ~2.9×
  • Enables larger batch sizes under same GPU memory

Inference Optimization

  • KV cache reduction via key-token selection
  • Mixed recent-window + key-token attention
  • Gumbel-softmax logit regularization for scoring
  • Per-layer score accumulation
  • Static KV cache budget to control memory
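The static budget in the list above splits into two parts: slots for the recent window and slots for key tokens. The helper below is a hypothetical sketch that assumes the budget is a fraction of the prompt length and the recent window is a fraction of that budget. The source gives the 20–30% range but does not spell out its base in this summary, so treat that as an assumption.

```python
def kv_budget(prompt_len, kv_fraction=0.5, recent_ratio=0.3):
    """Split a fixed KV budget into recent-window slots and key-token slots."""
    budget = max(1, int(prompt_len * kv_fraction))
    recent = max(1, int(budget * recent_ratio))
    return {"budget": budget, "recent_window": recent, "key_tokens": budget - recent}


print(kv_budget(2048, kv_fraction=0.5, recent_ratio=0.3))
# {'budget': 1024, 'recent_window': 307, 'key_tokens': 717}
```

Because the budget is fixed once computed, cache memory stays constant during generation instead of growing with every new token.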

Reproducibility

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Quality depends on model, task, and positional encoding; tune per model.
  • Gumbel softmax scoring adds overhead and must be balanced against KV savings.
  • Does not eliminate all off-chip memory needs; extreme long contexts still need additional system-level solutions.

When Not To Use

  • Short-context generation where KV cache is small and savings are minimal.
  • Workflows that cannot modify the inference pipeline or scoring step.
  • Use cases that depend on exact, unmodified attention for correctness guarantees, unless the reduced cache has been validated.

Failure Modes

  • Mis-identifying key tokens reduces accuracy if score function or τ schedule is poorly tuned.
  • Performance gains shrink on compute-bound workloads where KV movement is not the bottleneck.
  • Positional-encoding differences can change token importance patterns and cause drops.
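The second failure mode above follows from an Amdahl's-law style estimate: a component speedup only helps in proportion to the share of time that component takes. The sketch below uses the component speedups reported in this summary (~2.9× for KV data movement, ~1.3× for scaled dot product), but the time fractions are hypothetical values chosen only to contrast a memory-bound and a compute-bound profile.

```python
def overall_speedup(fractions, speedups):
    """Amdahl-style estimate from baseline time fractions and per-component speedups."""
    if abs(sum(fractions.values()) - 1.0) > 1e-9:
        raise ValueError("time fractions must sum to 1")
    new_time = sum(frac / speedups.get(name, 1.0) for name, frac in fractions.items())
    return 1.0 / new_time


# Component speedups as reported in the summary (MPT-7B, 50% KV).
speedups = {"kv_movement": 2.9, "sdp": 1.3}

# Hypothetical time breakdowns, for illustration only.
memory_bound = {"kv_movement": 0.6, "sdp": 0.1, "other": 0.3}
compute_bound = {"kv_movement": 0.1, "sdp": 0.1, "other": 0.8}

print(f"memory-bound:  {overall_speedup(memory_bound, speedups):.2f}x")
print(f"compute-bound: {overall_speedup(compute_bound, speedups):.2f}x")
```

Profiling your own workload to obtain the real fractions is the step that determines whether the headline gains are reachable.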

Core Entities

Models

  • GPT-J-6B
  • Cerebras-GPT-6.7B
  • MPT-7B
  • MPT-7B-storywriter

Metrics

  • ROUGE-1
  • ROUGE-2
  • ROUGE-L
  • tokens/sec
  • inference latency
  • KV cache data movement

Datasets

  • CNN/DailyMail
  • GovReport
  • SODA
  • lm-eval-harness (PIQA, Winogrande, OpenBookQA, COPA)

Benchmarks

  • Accuracy