Compress KV cache to sub-4-bit with <0.1 PPL loss and enable million‑to‑10M token inference

January 31, 2024 · 8 min

Overview

Production Readiness

0.7

Novelty Score

0.7

Cost Impact Score

0.8

Citation Count

9

Authors

Coleman Hooper, Sehoon Kim, Hiva Mohammadzadeh, Michael W. Mahoney, Yakun Sophia Shao, Kurt Keutzer, Amir Gholami

Links

Abstract / PDF

Why It Matters For Business

Cut KV cache memory 3–7× and preserve accuracy so you can serve much longer contexts on existing GPUs, reducing infrastructure cost or enabling new long-document features.

Summary TLDR

KVQuant is a practical recipe for quantizing the KV cache (the stored Keys and Values) so that LLMs can run with very long contexts. It combines per-channel Key quantization applied before RoPE (rotary position embeddings), a sensitivity-weighted non-uniform datatype (nuqX), per-vector dense-and-sparse outlier handling, attention-sink-aware retention of the first token in fp16, and custom CUDA kernels. On LLaMA-family and Mistral models, KVQuant at 3 bits with 1% outliers (nuq3-1%) stays within about 0.07 perplexity of fp16 on Wikitext-2 while shrinking the KV cache roughly 4.8×; the custom kernels are reported to run up to about 1.7× faster than fp16 matrix-vector multiplications. The method enables LLaMA-7B inference with 1M tokens of context on a single A100 and 10M tokens on an 8-GPU system.

Problem Statement

For long-context inference, the KV cache (stored Key/Value activations) dominates GPU memory and bandwidth. Existing activation quantization methods break down below 4 bits because outliers, channel structure, and RoPE rotations skew the quantization ranges. The field needs a practical low-bit KV cache quantization method that preserves accuracy while reducing memory and bandwidth.
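To make the outlier problem concrete, here is a minimal sketch, assuming plain min-max uniform quantization (a generic scheme chosen for illustration, not necessarily the baseline the paper evaluates). The helper names and the toy values are hypothetical. One large value stretches the 3-bit grid so far that every ordinary value collapses onto the same level.

```python
def quantize_uniform(values, bits):
    """Round each value to the nearest of 2**bits evenly spaced levels in [min, max]."""
    lo, hi = min(values), max(values)
    steps = 2 ** bits - 1
    scale = (hi - lo) / steps if hi > lo else 1.0
    return [lo + round((v - lo) / scale) * scale for v in values]


def mean_abs_error(original, approx):
    return sum(abs(x - y) for x, y in zip(original, approx)) / len(original)


# Toy activations (made up for illustration): seven ordinary values.
typical = [-1.0, -0.6, -0.2, 0.1, 0.4, 0.7, 1.0]
# The same values plus a single outlier that widens the range.
with_outlier = typical + [40.0]

err_clean = mean_abs_error(typical, quantize_uniform(typical, 3))
# Error on the ordinary values only, after quantizing together with the outlier.
err_outlier = mean_abs_error(typical, quantize_uniform(with_outlier, 3)[:-1])

print(f"3-bit error without outlier: {err_clean:.3f}")
print(f"3-bit error with outlier:    {err_outlier:.3f}")
```

The second error is several times larger than the first, which is the effect the outlier handling and grouping choices described in this summary are meant to avoid.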

Main Contribution

Per-channel Key quantization applied before RoPE to align with Key outlier channels and avoid RoPE-induced mixing.

nuqX: per-layer sensitivity-weighted non-uniform datatypes computed offline to place quantization signposts where they matter.

Per-vector dense-and-sparse quantization: detect per-channel (Keys) or per-token (Values) outliers and store them sparsely (e.g., 1% outliers).

Attention-sink-aware rule: keep the first token in fp16 to reduce the strong quantization sensitivity seen in early layers.

Mix of offline calibration (Keys) and efficient online computations (Value scaling & outlier thresholds).

Custom CUDA kernels (LUT-based 4-bit, balanced sparse matvec) and CSR/CSC storage to speed dequantize-and-multiply.
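The nuqX item in the list above can be pictured as a small codebook whose entries are pulled toward the values that matter most. The sketch below is a toy, pure-Python version of weighted 1-D k-means, one common way to fit such a codebook. The function name, the quantile-based initialization, and the weights are assumptions made for illustration; the weights stand in for whatever per-element sensitivity the calibration step produces, and the paper's actual procedure may differ.

```python
def weighted_kmeans_1d(values, weights, k, iters=25):
    """Fit k codebook entries to scalar values, weighting each value's pull on its centroid."""
    ordered = sorted(values)
    n = len(ordered)
    # Initialize centroids at evenly spaced quantiles (an arbitrary but deterministic choice).
    centroids = [ordered[(2 * i + 1) * n // (2 * k)] for i in range(k)]
    for _ in range(iters):
        weighted_sum = [0.0] * k
        mass = [0.0] * k
        for v, w in zip(values, weights):
            nearest = min(range(k), key=lambda c: abs(v - centroids[c]))
            weighted_sum[nearest] += w * v
            mass[nearest] += w
        # Keep a centroid unchanged if nothing was assigned to it.
        centroids = [
            weighted_sum[j] / mass[j] if mass[j] > 0 else centroids[j]
            for j in range(k)
        ]
    return sorted(centroids)


values = [0.0, 0.0, 1.0, 1.0, 10.0, 10.0]
uniform = weighted_kmeans_1d(values, [1, 1, 1, 1, 1, 1], k=2)
# Hypothetical sensitivities: the two values at 1.0 matter three times as much.
weighted = weighted_kmeans_1d(values, [1, 1, 3, 3, 1, 1], k=2)

print("unweighted codebook:", uniform)
print("weighted codebook:  ", weighted)
```

With equal weights the lower entry sits midway between 0 and 1; with the heavier weights it moves toward 1.0, so the more sensitive values are represented with less error.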

Key Findings

3-bit KV cache with 1% sparse outliers keeps perplexity near fp16 on Wikitext-2

Numbers: LLaMA-7B PPL 5.75 vs fp16 5.68 (+0.07)

KV cache memory reduced roughly 4.8× at 3-bit

Numbers: LLaMA-7B KV cache 64.0 GB → 13.3 GB (≈4.8×)

Tiny fraction of outliers (≈1%) recovers most accuracy

Numbers: Storing 1% of values as sparse outliers lowers PPL by a further 0.19 (3-bit LLaMA-7B)

Enables extreme context lengths: 1M (single A100) and 10M (8 GPUs)

Numbers: LLaMA-7B with nuq2 → 1M tokens on 1×A100, 10M tokens on an 8-GPU system

Custom kernels speed up matvecs relative to fp16

Numbers: Up to ≈1.7× latency improvement on A6000 for dense+1% sparse kernels

Results

Perplexity (LLaMA-7B, Wikitext-2)

Value: 5.75 (nuq3-1%)

Baseline: 5.68 (fp16)

KV cache memory (LLaMA-7B, seqlen 128K)

Value: 13.3 GB (nuq3-1%)

Baseline: 64.0 GB (fp16)
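The memory figures can be roughly reproduced with back-of-the-envelope arithmetic. The sketch below assumes the standard LLaMA-7B shape (32 layers, hidden size 4096), batch size 1, and sizes measured in GiB (2^30 bytes). The cost of 32 bits per outlier (an fp16 value plus a 16-bit index) is an assumption made here, and scaling factors, zero points, and sparse row pointers are ignored, so treat the result as an estimate and not the paper's exact accounting.

```python
def kv_cache_gib(seq_len, n_layers=32, hidden=4096, bits_per_elem=16.0):
    """Estimate KV cache size in GiB for one sequence (Keys plus Values)."""
    elements = 2 * n_layers * hidden * seq_len  # factor 2: one Key and one Value per position
    return elements * bits_per_elem / 8 / 2**30


SEQ_LEN = 128 * 1024

fp16_gib = kv_cache_gib(SEQ_LEN)

# 3 dense bits per element, plus 1% of elements stored again as sparse outliers.
OUTLIER_FRACTION = 0.01
BITS_PER_OUTLIER = 32  # assumed: fp16 value + 16-bit index
nuq3_gib = kv_cache_gib(SEQ_LEN, bits_per_elem=3 + OUTLIER_FRACTION * BITS_PER_OUTLIER)

print(f"fp16:    {fp16_gib:.1f} GiB")
print(f"nuq3-1%: {nuq3_gib:.2f} GiB")
print(f"ratio:   {fp16_gib / nuq3_gib:.2f}x")
```

Under these assumptions the estimate comes out at 64.0 GiB for fp16 and about 13.3 GiB for the quantized cache, a ratio of about 4.8×, in line with the reported numbers.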

Kernel speedup (matvec latency)

Value: up to ≈1.7× faster

Baseline: fp16 matvec

Passkey retrieval success

Value: success rate near 1.0 at many context lengths (nuq3-1%)

Baseline: fp16

Who Should Care

What To Try In 7 Days

Run nuq3-1% KVQuant calibration on your model with 16 calibration samples and measure perplexity on Wikitext-2 or a similar corpus.

Keep the first token in fp16 (attention-sink-aware) and extract ~1% outliers per-vector to observe big accuracy gains.

Integrate the provided CUDA kernels (or implement LUT-based dequantization + CSR/CSC outlier storage) to measure latency vs fp16 on your GPU.
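Before touching any kernels, the effect of per-vector outlier extraction can be checked offline. The following is a minimal sketch, assuming magnitude-based top-k selection and min-max uniform quantization of the dense remainder. Both are simplifications chosen for illustration, all function names are hypothetical, and the synthetic vector is not real model data.

```python
import random


def quantize_uniform(values, bits):
    """Round each value to the nearest of 2**bits evenly spaced levels in [min, max]."""
    lo, hi = min(values), max(values)
    steps = 2 ** bits - 1
    scale = (hi - lo) / steps if hi > lo else 1.0
    return [lo + round((v - lo) / scale) * scale for v in values]


def split_outliers(vector, fraction=0.01):
    """Return (dense, sparse): the largest-magnitude entries move to a sparse index -> value map."""
    n_out = max(1, round(len(vector) * fraction))
    by_magnitude = sorted(range(len(vector)), key=lambda i: abs(vector[i]), reverse=True)
    sparse = {i: vector[i] for i in by_magnitude[:n_out]}
    dense = [0.0 if i in sparse else v for i, v in enumerate(vector)]
    return dense, sparse


def reconstruct(dense_quantized, sparse):
    """Overlay the full-precision outliers on the dequantized dense part."""
    return [sparse.get(i, v) for i, v in enumerate(dense_quantized)]


def mean_abs_error(original, approx):
    return sum(abs(x - y) for x, y in zip(original, approx)) / len(original)


random.seed(0)
vector = [random.gauss(0.0, 1.0) for _ in range(200)]
vector[17], vector[111] = 60.0, -45.0  # two synthetic outliers (1% of 200 entries)

dense, sparse = split_outliers(vector, fraction=0.01)
err_plain = mean_abs_error(vector, quantize_uniform(vector, 3))
err_ds = mean_abs_error(vector, reconstruct(quantize_uniform(dense, 3), sparse))

print(f"3-bit, no outlier handling: {err_plain:.3f}")
print(f"3-bit, 1% outliers sparse:  {err_ds:.3f}")
```

Running the same comparison on real Key and Value tensors from your model gives a quick signal of how much the 1% budget helps before you invest in kernel integration.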

Optimization Features

Token Efficiency

  • enables longer token windows with same GPUs

Infra Optimization

  • reduce multi-GPU memory needs for long contexts
  • optionally run top-k outlier detection on CPU in parallel

Model Optimization

  • non-uniform datatypes (nuqX)
  • per-channel Key quantization

System Optimization

  • custom CUDA kernels with LUT dequantize
  • balanced sparse matvec kernel
  • CSR/CSC sparse layout for outliers
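The LUT-based dequantization in the list above boils down to reading a packed code, looking it up in a small table, and multiplying. The sketch below shows only that data flow in plain Python; the real kernels are CUDA and are organized very differently. The nibble order, the function names, and the lookup table values are all assumptions made for illustration.

```python
def pack_4bit(codes):
    """Pack 4-bit codes (ints in 0..15) two per byte, low nibble first."""
    padded = list(codes) + [0] * (len(codes) % 2)
    return bytes(padded[i] | (padded[i + 1] << 4) for i in range(0, len(padded), 2))


def dequant_dot(packed, lut, query):
    """Dot product of `query` with the vector whose entries are lut[code]."""
    total = 0.0
    for i, q in enumerate(query):
        byte = packed[i // 2]
        code = (byte >> 4) if i % 2 else (byte & 0x0F)
        total += lut[code] * q  # dequantize through the lookup table, then accumulate
    return total


# Hypothetical 16-entry table; a non-uniform datatype would supply its own values.
lut = [i * 0.5 - 4.0 for i in range(16)]
codes = [0, 15, 8, 3]
query = [1.0, 2.0, -1.0, 0.5]

packed = pack_4bit(codes)
result = dequant_dot(packed, lut, query)
reference = sum(lut[c] * q for c, q in zip(codes, query))

print(len(packed), "bytes for", len(codes), "codes")
print("packed result:", result, "reference:", reference)
```

Because the table holds arbitrary values, the same loop serves uniform and non-uniform datatypes alike; only the table contents change.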

Inference Optimization

  • per-token Value quantization
  • pre-RoPE Key quantization
  • per-vector dense-and-sparse outlier extraction
  • attention-sink-aware retention of first token
  • offline per-layer calibration for Keys
  • online per-token outlier/scale computation for Values
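The grouping choice in the list above (per-channel for Keys, per-token for Values) matters when a few channels carry consistently large magnitudes. Here is a minimal sketch on synthetic data, assuming min-max uniform quantization and a single made-up outlier channel. It does not model RoPE, so it illustrates only the grouping effect and not the pre-RoPE placement; all names are hypothetical.

```python
import random


def quantize_uniform(values, bits):
    """Round each value to the nearest of 2**bits evenly spaced levels in [min, max]."""
    lo, hi = min(values), max(values)
    steps = 2 ** bits - 1
    scale = (hi - lo) / steps if hi > lo else 1.0
    return [lo + round((v - lo) / scale) * scale for v in values]


def quantize_per_token(matrix, bits):
    """One quantization range per row (token)."""
    return [quantize_uniform(row, bits) for row in matrix]


def quantize_per_channel(matrix, bits):
    """One quantization range per column (channel)."""
    columns = [quantize_uniform(list(col), bits) for col in zip(*matrix)]
    return [list(row) for row in zip(*columns)]


def mean_abs_error(a, b):
    pairs = [(x, y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return sum(abs(x - y) for x, y in pairs) / len(pairs)


random.seed(0)
TOKENS, CHANNELS, OUTLIER_CHANNEL = 64, 16, 5
keys = [[random.gauss(0.0, 1.0) for _ in range(CHANNELS)] for _ in range(TOKENS)]
for row in keys:
    row[OUTLIER_CHANNEL] += 30.0  # synthetic channel with a consistently large magnitude

err_token = mean_abs_error(keys, quantize_per_token(keys, 3))
err_channel = mean_abs_error(keys, quantize_per_channel(keys, 3))

print(f"per-token grouping error:   {err_token:.3f}")
print(f"per-channel grouping error: {err_channel:.3f}")
```

With per-token ranges, the outlier channel stretches every row's grid; with per-channel ranges, it affects only its own column, so the error on this toy matrix drops sharply.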

Reproducibility

Code Available

Data Available

Open Source Status

  • yes

Risks & Boundaries

Limitations

  • The work targets inference only; it does not address long-context (>100K-token) training.
  • Latency benchmarks focus on memory-bandwidth-bound generation, not batched prefill/prompt compression.
  • Current end-to-end implementation has inefficiencies in sparse matrix memory allocation and appends.

When Not To Use

  • When you require full-precision activations for downstream tasks (e.g., precise numeric outputs).
  • If you cannot modify inference kernels or add custom CUDA implementations.
  • If you need an out-of-the-box solution for models without available calibration data or gradients.

Failure Modes

  • Very low-bit (2-bit) without proper outlier extraction can cause large perplexity regressions.
  • Incorrect offline calibration for Keys may hurt accuracy if outliers are not handled.
  • Sparse storage and append overheads can negate memory/latency gains if not implemented carefully.

Core Entities

Models

  • LLaMA
  • Llama-2
  • Llama-3
  • Mistral

Metrics

  • perplexity
  • passkey retrieval success rate
  • latency (microseconds)
  • KV cache size (GB)

Datasets

  • Wikitext-2
  • C4
  • LongBench
  • RULER
  • passkey retrieval benchmark

Benchmarks

  • Wikitext-2 perplexity
  • C4 perplexity
  • LongBench
  • RULER
  • passkey retrieval