Cut KV cache by >54% and double throughput with layer-wise 'pyramid' selection

May 21, 2024 · 7 min read

Overview

Decision Snapshot: Needs Validation

The paper reports consistent memory and throughput gains across models and tasks on A100 80GB GPUs. Benefits grow with batch size and model scale, while the extra per-layer selection cost shrinks the gains at very small batch sizes.

Citations: 1

Evidence Strength: 0.70

Confidence: 0.80

Risk Signals: 9

Trust Signals

Findings with numeric evidence: 3/3

Findings with evidence refs: 3/3

Results with explicit delta: 4/4

Reproducibility

Status: Partial assets available

Open source: Partial

At A Glance

Cost impact: 80%

Production readiness: 70%

Novelty: 60%

Authors

Dongjie Yang, XiaoDong Han, Yan Gao, Yao Hu, Shilin Zhang, Hai Zhao


Why It Matters For Business

PyramidInfer lowers the GPU memory needed for KV caches and raises throughput. This lets you serve larger batches or use fewer GPUs for chat workloads, which cuts infrastructure cost per token.


Summary (TL;DR)

PyramidInfer reduces the GPU memory used by transformer KV caches by computing and keeping only a layer-wise subset of 'pivotal' keys/values (PvCs). It relies on the observation that recent tokens attend consistently to the same context positions, and uses that attention to pick PvCs in both the prefill and generation phases. In A100 80GB experiments, PyramidInfer cuts the KV cache by about 54% and delivers roughly 2.2x to nearly 3x the throughput of standard full-cache serving, with larger benefits for larger batches and larger models. Because the method adds per-layer selection work, speedups are smaller at very small batch sizes.

Problem Statement

KV caches (the keys and values stored from attention) can grow to several times the model's own memory during LLM inference, which blocks large-batch, low-cost serving. Prior KV-compression methods only trim the cache after it has been fully computed, so they cannot avoid the large memory cost of the prefill phase. The paper asks: can we compute and store fewer keys/values up front while preserving generation quality and improving throughput?
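A back-of-the-envelope estimate makes this memory pressure concrete. The sketch below assumes standard multi-head attention (no grouped-query sharing), fp16 storage at 2 bytes per element, the published LLaMA 2-13B shape of 40 layers with hidden size 5120, and reads "512+256" in the Results as prompt plus generated tokens. It counts only the raw K/V tensors, so measured figures such as the 24.2 GB baseline in Table 1 can differ because of framework buffers and allocator overhead.

```python
def kv_cache_bytes(num_layers: int, hidden_size: int, seq_len: int,
                   batch_size: int, bytes_per_elem: int = 2) -> int:
    """Raw bytes held by a full (uncompressed) KV cache.

    The leading 2 counts one key and one value tensor per layer; with
    multi-head attention (assumed here) each is hidden_size wide per token.
    """
    return 2 * num_layers * hidden_size * seq_len * batch_size * bytes_per_elem


# Assumed LLaMA 2-13B shape and fp16 storage.
LAYERS, HIDDEN = 40, 5120
weights_gb = 13e9 * 2 / 1e9  # ~13B parameters at 2 bytes each

for batch, seq_len in [(32, 512 + 256), (64, 4096)]:
    kv_gb = kv_cache_bytes(LAYERS, HIDDEN, seq_len, batch) / 1e9
    print(f"batch {batch:>3}, {seq_len:>4} tokens: KV ~{kv_gb:6.1f} GB "
          f"({kv_gb / weights_gb:.1f}x the ~{weights_gb:.0f} GB of weights)")
```

At the Table 1 setting the raw cache is already comparable to the weights; at a 4k context and batch 64 it is several times larger, which is the regime where trimming the cache pays off most.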

Main Contribution

Empirical finding that the number of keys/values crucial for future tokens decreases layer-by-layer; redundancy grows in deeper layers.

Observation that recent tokens attend to a largely shared set of context keys/values (PvCs); ensembling the attention of several recent tokens makes PvC selection more reliable.
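The selection idea can be sketched for a single layer as follows. This is one plausible reading of "ensemble of recent attention" plus "top-p selection", written for illustration and not taken from the PyramidInfer code. The last `recent_window` tokens are always kept. Every older position is scored by the average attention it receives from those recent tokens, and the highest-scoring positions are kept until they cover a `top_p` share of the total score. `select_pvcs`, `recent_window`, and `top_p` are hypothetical names; in the full method each layer would get its own, shrinking budget.

```python
from typing import List, Sequence


def select_pvcs(attn: Sequence[Sequence[float]], recent_window: int,
                top_p: float) -> List[int]:
    """Return the positions whose keys/values one layer keeps.

    attn[i][j] is the (causal) attention weight that query token i puts on key j.
    """
    n = len(attn)
    if not 1 <= recent_window <= n:
        raise ValueError("recent_window must be between 1 and the sequence length")
    ctx_end = n - recent_window            # positions [0, ctx_end) are PvC candidates
    recent = range(ctx_end, n)             # the recent window is always kept
    # Ensemble score: mean attention each older position receives from the recent tokens.
    scores = [sum(attn[i][j] for i in recent) / recent_window for j in range(ctx_end)]
    budget = top_p * sum(scores)
    kept: List[int] = []
    mass = 0.0
    # Greedy top-p: take positions in descending score order until the budget is covered.
    for j in sorted(range(ctx_end), key=lambda j: scores[j], reverse=True):
        if mass >= budget:
            break
        kept.append(j)
        mass += scores[j]
    return sorted(kept) + list(recent)


# Toy 6-token layer; only the last two rows (the recent window) drive the scores.
attn = [
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.6, 0.4, 0.0, 0.0, 0.0, 0.0],
    [0.5, 0.2, 0.3, 0.0, 0.0, 0.0],
    [0.4, 0.2, 0.2, 0.2, 0.0, 0.0],
    [0.5, 0.1, 0.05, 0.05, 0.3, 0.0],
    [0.4, 0.05, 0.3, 0.05, 0.1, 0.1],
]
print(select_pvcs(attn, recent_window=2, top_p=0.8))  # [0, 2, 4, 5]
```

Averaging over the whole recent window rather than trusting only the newest token reflects the "ensemble" point above: a position that several recent tokens agree on is a safer keep than one a single token happens to spike on.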

Key Findings

PyramidInfer cuts the KV cache by more than half and more than doubles throughput on LLaMA 2-13B.

Numbers: 2.24x throughput; 54.6% KV cache reduction (LLaMA2-13B, A100 80GB).

Practical Use: You can roughly double token throughput and cut KV memory by about half on 13B-class models, enabling larger batches or fewer GPUs for the same load.

Evidence Ref: Table 1

PyramidInfer raises maximum throughput and batch capacity on an 80GB A100.

Numbers: Max throughput 1678 t/s vs 581 t/s for the baseline (about 2.9x) and max batch size 88 vs 42 on LLaMA2-13B.

Practical Use: If you want to maximize tokens/s on a fixed GPU, PyramidInfer lets you roughly double the batch size and reach 2–3x higher throughput.

Evidence Ref: Table 2

Results

Metric | Value | Baseline | Delta | Split / Dataset | Evidence | Evidence Ref
KV memory (LLaMA2-13B, batch 32, 512+256 tokens) | 11.0 GB (PyramidInfer) | 24.2 GB (Accelerate) | 54.6% reduction | Table 1 setup | Measured KV cache on A100 80GB | Table 1
Throughput (tokens/s, LLaMA2-13B, batch 32, 512+256 tokens) | 1389 t/s (PyramidInfer) | 621 t/s (Accelerate) | 2.24x | Table 1 setup | Measured token throughput on A100 80GB | Table 1
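The Delta column, and the ratios quoted under Key Findings, can be re-derived from the raw values as a quick sanity check. From the rounded GB figures the KV reduction comes out at about 54.5%, which is consistent with the reported 54.6% given that 11.0 and 24.2 are themselves rounded. The Table 2 maximum-throughput ratio works out to about 2.9x.

```python
# Raw values as reported in the Results table and Key Findings (LLaMA2-13B, A100 80GB).
kv_pyramid_gb, kv_baseline_gb = 11.0, 24.2         # Table 1 setting
tput_pyramid, tput_baseline = 1389, 621            # tokens/s, Table 1 setting
max_tput_pyramid, max_tput_baseline = 1678, 581    # tokens/s, Table 2
max_batch_pyramid, max_batch_baseline = 88, 42     # Table 2

kv_reduction = 1 - kv_pyramid_gb / kv_baseline_gb  # share of KV memory saved
speedup = tput_pyramid / tput_baseline             # same batch, same lengths
max_speedup = max_tput_pyramid / max_tput_baseline # each at its own max batch
batch_gain = max_batch_pyramid / max_batch_baseline

print(f"KV reduction:        {kv_reduction:.1%}")
print(f"Throughput speedup:  {speedup:.2f}x")
print(f"Max-throughput gain: {max_speedup:.2f}x")
print(f"Max-batch gain:      {batch_gain:.2f}x")
```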

What To Try In 7 Days

Add PvC selection to your inference stack for one 13B model and compare KV memory and throughput against your current setup.

Start with a 40% recent window and a power-law decay of PvC lengths across layers; the paper found this configuration to give a good memory/quality trade-off.

Run a max-batch memory test on an A100-equivalent GPU to quantify cost savings for your workload.
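The starting configuration and the max-batch test can be planned on paper before touching the serving stack. This sketch makes several assumptions that are not from the paper: the power-law decay is modeled as keeping a (layer + 1)^(-alpha) share of the older (non-recent) tokens in each layer with an arbitrary alpha = 0.5, the recent window is read as 40% of the sequence, and a fixed 4 GB reserve stands in for activations and allocator overhead. `PyramidConfig`, `kv_lengths_per_layer`, and `max_batch_size` are hypothetical names. With these defaults the schedule keeps somewhat under 60% of a full cache, so it is less aggressive than the reduction reported in Table 1. The batch ceilings it prints are idealized; only the real test on your GPU gives usable numbers.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class PyramidConfig:
    """Hypothetical knobs for a pyramid-shaped KV budget (illustrative, not the paper's API)."""
    num_layers: int = 40         # LLaMA 2-13B depth
    hidden_size: int = 5120      # LLaMA 2-13B width (multi-head attention, no GQA)
    recent_ratio: float = 0.4    # share of the sequence always kept as the recent window
    decay_exponent: float = 0.5  # assumed exponent of the power-law decay across layers


def kv_lengths_per_layer(cfg: PyramidConfig, seq_len: int) -> List[int]:
    """KV entries kept in each layer: the recent window plus a decaying PvC budget."""
    recent = round(cfg.recent_ratio * seq_len)
    older = seq_len - recent
    # Layer 0 keeps every older token; layer l keeps a (l + 1) ** -alpha share of them.
    return [recent + round(older * (layer + 1) ** -cfg.decay_exponent)
            for layer in range(cfg.num_layers)]


def max_batch_size(cfg: PyramidConfig, seq_len: int, kept_fraction: float,
                   gpu_gb: float = 80.0, weights_gb: float = 26.0,
                   reserve_gb: float = 4.0, bytes_per_elem: int = 2) -> int:
    """Largest batch whose (possibly pruned) KV cache fits beside the fp16 weights.

    kept_fraction is the average share of KV entries retained (1.0 = full cache);
    reserve_gb is a hypothetical allowance for activations and allocator overhead.
    """
    full_seq_bytes = 2 * cfg.num_layers * cfg.hidden_size * seq_len * bytes_per_elem
    free_bytes = (gpu_gb - weights_gb - reserve_gb) * 1e9
    return max(0, math.floor(free_bytes / (full_seq_bytes * kept_fraction)))


cfg = PyramidConfig()
seq_len = 512 + 256
lengths = kv_lengths_per_layer(cfg, seq_len)
kept = sum(lengths) / (cfg.num_layers * seq_len)
print(f"layer 0 keeps {lengths[0]}, layer {cfg.num_layers - 1} keeps {lengths[-1]}")
print(f"average kept fraction: {kept:.0%}")
print("max batch, full cache:", max_batch_size(cfg, seq_len, 1.0))
print("max batch, pyramid:   ", max_batch_size(cfg, seq_len, kept))
```

Raising `decay_exponent` or lowering `recent_ratio` shrinks deep-layer budgets further; the paper's warning about over-aggressive pruning applies, so check perplexity and task accuracy alongside the memory numbers.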

Optimization Features

Token Efficiency
Reduces stored context tokens per layer
Infra Optimization
Enables fitting larger prompts/models on the same GPU
Reduces OOM incidence for very large models (70B)
System Optimization
Lower KV GPU memory enables larger batch sizes
Orthogonal to DeepSpeed, so the two can be combined to further boost throughput
Inference Optimization
Layer-wise PvC selection (compute fewer keys/values)
Prefill-phase compression (avoid full prompt compute)
Sliding recent window and ensemble attention for selection
Per-layer top-p selection with decay (pyramid decay)

Reproducibility

Code Available: Yes
Data Available: No
Open Source Status: Partial
License: Unknown

Risks & Boundaries

Limitations

Introduces extra per-layer computation (top-p sorting), so speedups are limited at small batch sizes.

Not lossless: prefill compression can drop context information if configured too aggressively.

When Not To Use

When single-query, low-batch, ultra-low-latency requests dominate (overhead may outweigh gains).

When you need guaranteed lossless KV retention for exact reproduction.

Failure Modes

Over-aggressive PvC pruning increases perplexity and downstream task errors.

Sort and selection overhead can negate throughput wins with tiny batches.

Core Entities

Models

LLaMA 2-7B, LLaMA 2-13B, LLaMA 2-34B, LLaMA 2-70B, Vicuna 1.5-16k, CodeLLaMA

Metrics

throughput (tokens/s), KV memory (GB), perplexity, latency (ms/token)

Datasets

wikitext-v2, MMLU, BBH, GSM8K, HumanEval, MT-Bench, LEval

Benchmarks

OpenCompass, MT-Bench, LEval