Learned token-dropping that prunes up to 80% of context to speed up and shrink autoregressive Transformers

May 25, 2023 · 7 min

Overview

Decision Snapshot: Ready For Pilot

Methods were tested on public GPT-2 variants with throughput, memory and accuracy metrics; results are convincing on long-context workloads but limited to fine-tuning experiments and implementation specifics.

Citations: 8

Evidence Strength: 0.80

Confidence: 0.80

Risk Signals: 10

Trust Signals

Findings with numeric evidence: 5/5

Findings with evidence refs: 5/5

Results with explicit delta: 3/4

Reproducibility

Status: Partial assets available

Open source: Partial

At A Glance

Cost impact: 75%

Production readiness: 70%

Novelty: 60%

Authors

Sotiris Anagnostidis, Dario Pavllo, Luca Biggio, Lorenzo Noci, Aurelien Lucchi, Thomas Hofmann


Why It Matters For Business

You can cut inference memory and often double throughput on long prompts by fine-tuning a small module, reducing costs for batched long-context services.

Summary TLDR

The paper introduces Adaptively Sparse Attention: a small fine-tuned module that learns, layer-by-layer, to drop past tokens from the autoregressive context. Applied to GPT-2 models, it prunes up to ~80% of the cached tokens with little or no drop in quality, yields large memory savings, and can more than double inference throughput for long contexts. The method uses a differentiable 'α-sigmoid' to train binary drop decisions and an efficient batched key-value cache to erase dropped tokens during generation.
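One way to picture the dropping is as a causal attention mask in which a token, once dropped, stays dropped for every later position. The sketch below is illustrative only: the `keep` matrix of per-step binary decisions, the function name `cumulative_keep_mask`, and the choice to always let a token see itself are assumptions made for this example, not the paper's code.

```python
import numpy as np

def cumulative_keep_mask(keep):
    """keep[i, j] in {0, 1}: decision taken at step i about past token j (j <= i).

    Returns mask[i, j] = 1 only if token j survived every decision
    from step j + 1 up to step i. Entries with j > i are always 0 (causal).
    """
    keep = np.asarray(keep, dtype=np.float64)
    n = keep.shape[0]
    causal = np.tril(np.ones((n, n)))
    # Entries above the diagonal are not real decisions; set them to 1
    # so they do not affect the running product.
    decisions = np.where(causal == 1, keep, 1.0)
    # Assumption: a token can always attend to itself.
    np.fill_diagonal(decisions, 1.0)
    # A cumulative product down the step axis makes every drop irreversible.
    survived = np.cumprod(decisions, axis=0)
    return survived * causal

keep = [[1, 1, 1, 1],
        [1, 1, 1, 1],
        [0, 1, 1, 1],   # step 2 drops token 0
        [1, 0, 1, 1]]   # step 3 drops token 1; token 0 stays dropped
mask = cumulative_keep_mask(keep)
print(mask)
```

In this toy run, token 0 is dropped at step 2 and remains masked at step 3 even though the raw decision at step 3 says "keep", which is the property that allows its cache entry to be erased.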

Problem Statement

Standard decoder Transformers pay a quadratic cost in sequence length and keep all past tokens in the key-value cache. This makes long-context inference memory-bound and slow. The paper asks whether we can remove irrelevant past tokens dynamically to save memory and time while preserving model accuracy.
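To make the memory pressure concrete, here is a back-of-the-envelope estimate of key-value cache size. It assumes the standard GPT-2 small shape (12 layers, hidden size 768), 16-bit storage, and, as a simplification, the same fraction of tokens kept in every layer; the function name `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(n_layers, d_model, seq_len, batch_size=1,
                   bytes_per_elem=2, keep_fraction=1.0):
    """Approximate KV cache size: one key and one value vector of width
    d_model per kept token, per layer, per sequence in the batch."""
    kept_tokens = int(round(seq_len * keep_fraction))
    return 2 * n_layers * batch_size * kept_tokens * d_model * bytes_per_elem

# GPT-2 small at a 1024-token context, fp16, batch size 1.
dense = kv_cache_bytes(12, 768, 1024)
# Same setting if roughly 80% of cached tokens were dropped (uniformly, by assumption).
pruned = kv_cache_bytes(12, 768, 1024, keep_fraction=0.2)

print(f"dense:  {dense / 2**20:.1f} MiB")   # dense:  36.0 MiB
print(f"pruned: {pruned / 2**20:.1f} MiB")  # pruned: 7.2 MiB
```

The saving scales linearly with batch size, which is why a smaller cache translates into room for larger batches.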

Main Contribution

Adaptively Sparse Attention: a learnable, layerwise mechanism that predicts which past tokens to drop during autoregressive generation.

An α-sigmoid training trick that moves decisions toward binary (keep/drop) while keeping gradients for fine-tuning.
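A minimal sketch of one plausible reading of the α-sigmoid, assuming it is the first component of the α-entmax transform applied to the pair [x, 0]. Under that assumption, α = 1 reduces to the ordinary sigmoid, α = 2 gives a clipped linear ramp, and larger α approaches a hard step while still having a nonzero slope near zero. The function name, the bisection solver, and the iteration count are illustrative choices, not taken from the paper.

```python
import numpy as np

def alpha_sigmoid(x, alpha, n_iter=60):
    """First component of alpha-entmax([x, 0]), solved by bisection (assumed definition)."""
    x = np.asarray(x, dtype=np.float64)
    if alpha == 1.0:
        return 1.0 / (1.0 + np.exp(-x))      # alpha = 1 is the softmax case
    z0 = (alpha - 1.0) * x                   # scaled logit of "keep"
    z1 = np.zeros_like(z0)                   # scaled logit of "drop"
    expo = 1.0 / (alpha - 1.0)
    zmax = np.maximum(z0, z1)
    lo, hi = zmax - 1.0, zmax                # the threshold tau lies in [lo, hi]
    for _ in range(n_iter):
        tau = 0.5 * (lo + hi)
        total = (np.clip(z0 - tau, 0.0, None) ** expo
                 + np.clip(z1 - tau, 0.0, None) ** expo)
        too_small = total >= 1.0             # probabilities sum to >= 1: raise tau
        lo = np.where(too_small, tau, lo)
        hi = np.where(too_small, hi, tau)
    tau = 0.5 * (lo + hi)
    return np.clip(z0 - tau, 0.0, None) ** expo

xs = np.array([-2.0, 0.0, 0.5, 2.0])
soft = alpha_sigmoid(xs, 1.0)   # smooth, never exactly 0 or 1
ramp = alpha_sigmoid(xs, 2.0)   # equals clip((x + 1) / 2, 0, 1)
sharp = alpha_sigmoid(xs, 5.0)  # saturates to exactly 0 or 1 once |x| >= 0.25
```

Because the output reaches exactly 0 or 1 for α > 1, a trained model can make true keep/drop decisions at inference time rather than merely down-weighting tokens.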

Key Findings

The fine-tuned model can remove ~80% of prior tokens with almost no perplexity loss.

Numbers: 80.35% sparsity → −0.085 avg perplexity change (context=1000) vs dense

Practical Use: You can drop most cached tokens for long prompts and keep language modeling quality in practice.

Evidence Ref: Fig. 4; Sec. 4.1

Large practical speedups: throughput and latency improve for long contexts.

Numbers: GPT-2-small +98% throughput (ΔPPL 0.316); GPT-2-medium +189% throughput (ΔPPL 0.084)

Practical Use: On long-context workloads you can roughly double tokens/second by fine-tuning with this module.

Evidence Ref: Fig. 7; Sec. 4.1

Results

Metric | Value | Baseline | Delta | Split / Dataset | Evidence | Evidence Ref
perplexity | Comparable; −0.085 avg change at 80.35% sparsity (context=1000) vs dense | dense GPT-2 (no pruning) | −0.085 | validation set over Wikipedia/bookcorpus; averaged across tokens | Fig. 4; Sec. 4.1 | Fig. 4
throughput (tokens/sec) | GPT-2-small +98% (ΔPPL 0.316); GPT-2-medium +189% (ΔPPL 0.084) | dense GPT-2 counterparts | +98% and +189% throughput gains reported | context sizes up to 1024; measured on NVIDIA RTX A5000 | Fig. 7; Sec. 4.1 | Fig. 7

What To Try In 7 Days

Fine-tune a small GPT-2 model with the adaptively sparse module at a fixed γ (the knob that trades sparsity against quality) for 25k steps on your domain subset.

Measure KV cache memory and tokens/sec at your target context length, then sweep γ to map the sparsity vs quality trade-off.

Implement the simple batched remove/insert cache to test real-world throughput with identical batch prefixes.
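For the cache experiment, a toy version of the remove/insert idea can look like the following. This is a single-sequence, single-head sketch under stated assumptions: the class name `SlotRecyclingKVCache`, its methods, and the free-list design are hypothetical and are not the paper's batched implementation. Slots of erased tokens are handed back out to new tokens, so the backing arrays never grow.

```python
import numpy as np

class SlotRecyclingKVCache:
    """Fixed-capacity key/value store that reuses the slots of erased tokens."""

    def __init__(self, capacity, head_dim):
        self.keys = np.zeros((capacity, head_dim), dtype=np.float32)
        self.values = np.zeros((capacity, head_dim), dtype=np.float32)
        self.in_use = np.zeros(capacity, dtype=bool)
        # Stack of free slot indices; pop() hands out slot 0 first.
        self.free_slots = list(range(capacity - 1, -1, -1))

    def insert(self, key, value):
        """Store one token's key/value and return the slot it occupies."""
        if not self.free_slots:
            raise RuntimeError("cache is full; drop a token before inserting")
        slot = self.free_slots.pop()
        self.keys[slot] = key
        self.values[slot] = value
        self.in_use[slot] = True
        return slot

    def remove(self, slot):
        """Erase a dropped token; its slot becomes available for reuse."""
        if not self.in_use[slot]:
            raise KeyError(f"slot {slot} is not in use")
        self.in_use[slot] = False
        self.free_slots.append(slot)

    def live(self):
        """Keys and values of the tokens that are still kept."""
        return self.keys[self.in_use], self.values[self.in_use]

cache = SlotRecyclingKVCache(capacity=3, head_dim=2)
s0 = cache.insert([1.0, 0.0], [0.1, 0.1])
s1 = cache.insert([0.0, 1.0], [0.2, 0.2])
s2 = cache.insert([1.0, 1.0], [0.3, 0.3])
cache.remove(s1)                            # token pruned by the drop module
s3 = cache.insert([2.0, 2.0], [0.4, 0.4])   # reuses the freed slot
live_keys, live_values = cache.live()
```

Recycled slots mean tokens are no longer stored in sequence order. That is acceptable for attention as long as each key stays paired with its value, since the softmax-weighted sum over keys does not depend on their storage order.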

Optimization Features

Token Efficiency: learned keep/drop decisions per token
Infra Optimization: enables larger batch sizes by lowering KV memory
Model Optimization: context pruning
System Optimization: data structure to recycle erased token slots
Training Optimization: α-sigmoid schedule for binarizing decisions; fine-tuning pretrained weights with added interaction params
Inference Optimization: erase entries from key-value cache; batched contiguous memory reuse for KV

Reproducibility

Code Available: No
Data Available: Yes
Open Source Status: Partial
License: Unknown

Risks & Boundaries

Limitations

Evaluated mainly on GPT-2 family; generalization to other decoders is claimed but not shown.

Requires fine-tuning; freezing base weights gives large perplexity penalties.

When Not To Use

Short-context workloads where pruning overhead outweighs gains.

High-stakes tasks that require access to every past token, unless pruning has been validated on them.

Failure Modes

A poorly chosen α schedule, or increasing α too quickly, leads to poor binary decisions and worse performance.

Tying drop decisions across layers harmed results and made learning unstable for deep models.
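Since the pace of the α increase is a listed failure mode, it can help to make the schedule an explicit, tunable function. The sketch below assumes a simple linear ramp from soft decisions (low α) to hard ones (high α); the function name `alpha_schedule`, the linear shape, and the example values are placeholders to sweep, not settings reported in the paper.

```python
def alpha_schedule(step, ramp_steps, alpha_start, alpha_end):
    """Linearly ramp alpha from alpha_start to alpha_end over ramp_steps, then hold."""
    if ramp_steps <= 0:
        raise ValueError("ramp_steps must be positive")
    frac = min(max(step / ramp_steps, 0.0), 1.0)
    return alpha_start + frac * (alpha_end - alpha_start)

# Illustrative values only: ramp over the first 10k of a longer fine-tuning run.
for step in (0, 5_000, 10_000, 20_000):
    print(step, alpha_schedule(step, ramp_steps=10_000, alpha_start=1.0, alpha_end=4.0))
```

A longer `ramp_steps` gives the model more updates with smooth gradients before decisions harden, which is the direction to try first if quality degrades.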

Core Entities

Models

GPT-2-small, GPT-2-medium, GPT-2-large, GPT-2-xl

Metrics

perplexity, accuracy, throughput (tokens/sec), FLOPs, KV cache memory, wall-time per generation step

Datasets

English Wikipedia 20220301.en, bookcorpus

Benchmarks

WinoGrande, HellaSwag, PIQA, LAMBADA