Learned token-dropping that prunes up to 80% of context to speed and shrink autoregressive Transformers

May 25, 2023 · 7 min

Overview

Production Readiness

0.7

Novelty Score

0.6

Cost Impact Score

0.75

Citation Count

8

Authors

Sotiris Anagnostidis, Dario Pavllo, Luca Biggio, Lorenzo Noci, Aurelien Lucchi, Thomas Hofmann

Links

Abstract / PDF

Why It Matters For Business

You can cut inference memory and often double throughput on long prompts by fine-tuning a small module, reducing costs for batched long-context services.

Summary TLDR

The paper introduces Adaptively Sparse Attention: a small fine-tuned module that learns, layer-by-layer, to drop past tokens from the autoregressive context. Applied to GPT-2 models, it prunes up to ~80% of the cached tokens with little or no drop in quality, yields large memory savings, and can more than double inference throughput for long contexts. The method uses a differentiable 'α-sigmoid' to train binary drop decisions and an efficient batched key-value cache to erase dropped tokens during generation.
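
To make the idea of training toward binary drop decisions concrete, here is a minimal sketch of a gate that sharpens as α grows. It uses a plain temperature-scaled sigmoid, sigmoid(α·x), as a stand-in; this is not the paper's α-sigmoid definition, and the names `sharp_sigmoid` and `alpha_schedule`, the linear ramp, and the end value of 32 are all illustrative assumptions.

```python
import math


def sharp_sigmoid(x: float, alpha: float) -> float:
    """Stand-in gate: sigmoid(alpha * x). NOT the paper's alpha-sigmoid.

    Small alpha gives a soft gate with useful gradients; large alpha
    pushes the output toward a hard 0/1 keep-or-drop decision.
    """
    z = alpha * x
    # Numerically stable evaluation for large |z|.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def alpha_schedule(step: int, total_steps: int,
                   alpha_start: float = 1.0, alpha_end: float = 32.0) -> float:
    """Hypothetical linear ramp of alpha over fine-tuning (assumed, not from the paper)."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return alpha_start + frac * (alpha_end - alpha_start)


if __name__ == "__main__":
    score = 0.3  # an illustrative pre-gate score for one past token
    for step in (0, 12_500, 25_000):
        alpha = alpha_schedule(step, 25_000)
        print(step, round(alpha, 2), round(sharp_sigmoid(score, alpha), 4))
```

The same score moves from an ambiguous gate value toward 1 as α increases, which is the behaviour a schedule on α is meant to produce; raising α too quickly removes the soft region where gradients are informative.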

Problem Statement

Standard decoder Transformers pay a quadratic cost in sequence length and keep all past tokens in the key-value cache. This makes long-context inference memory-bound and slow. The paper asks whether we can remove irrelevant past tokens dynamically to save memory and time while preserving model accuracy.

Main Contribution

Adaptively Sparse Attention: a learnable, layerwise mechanism that predicts which past tokens to drop during autoregressive generation.

An α-sigmoid training trick that moves decisions toward binary (keep/drop) while keeping gradients for fine-tuning.

A batched key-value cache data structure that supports efficient deletion and contiguous memory reuse for faster inference.

Empirical evaluation on GPT-2 family showing up to ~80% context pruning with minimal quality loss and large throughput/memory gains.
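
The cache contribution above can be illustrated with a small slot-recycling structure. This is a sketch for a single sequence only, written in plain Python rather than with tensors; the class name `SlotRecyclingKVCache` and its methods are hypothetical and do not reproduce the paper's batched implementation. It shows the core idea: dropped tokens free their slot, and new tokens reuse freed slots instead of growing the buffer.

```python
class SlotRecyclingKVCache:
    """Fixed-capacity KV store that reuses the slots of dropped tokens (illustrative)."""

    def __init__(self, capacity: int):
        self.keys = [None] * capacity
        self.values = [None] * capacity
        self.slot_of = {}  # token position -> slot index
        # Stack of free slots; popping yields the lowest free index first.
        self.free = list(range(capacity - 1, -1, -1))

    def insert(self, pos: int, key, value) -> int:
        """Store the key/value of token `pos` in a free slot and return that slot."""
        if not self.free:
            raise RuntimeError("cache full: drop a token before inserting")
        slot = self.free.pop()
        self.keys[slot] = key
        self.values[slot] = value
        self.slot_of[pos] = slot
        return slot

    def drop(self, pos: int) -> None:
        """Erase token `pos` and return its slot to the free stack."""
        slot = self.slot_of.pop(pos)
        self.keys[slot] = None
        self.values[slot] = None
        self.free.append(slot)

    def live_positions(self) -> list:
        """Token positions still visible to attention, in order."""
        return sorted(self.slot_of)

    def __len__(self) -> int:
        return len(self.slot_of)


if __name__ == "__main__":
    cache = SlotRecyclingKVCache(capacity=3)
    for pos in range(3):
        cache.insert(pos, key=f"k{pos}", value=f"v{pos}")
    cache.drop(1)                      # the model decides token 1 is no longer needed
    slot = cache.insert(3, "k3", "v3")  # token 3 reuses the freed slot
    print(slot, cache.live_positions())
```

In a batched setting each sequence drops different positions, so a real implementation also has to keep the per-sample live slots aligned or compacted.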

Key Findings

The fine-tuned model can remove ~80% of prior tokens with almost no perplexity loss.

Numbers: 80.35% sparsity → −0.085 average perplexity change vs. dense (context=1000)

Large practical speedups: throughput and latency improve for long contexts.

Numbers: up to 2× throughput; GPT-2-small +98% throughput (ΔPPL 0.316); GPT-2-medium +189% (ΔPPL 0.084)

Cached-memory (KV) requirements fall roughly linearly with sparsity, reducing per-step latency.

Numbers: up to 50% wall-time latency reduction per generation step for large contexts
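
A back-of-the-envelope calculation shows why KV memory scales roughly linearly with sparsity. The sketch below assumes the standard KV cache size formula (keys plus values, per layer, per token), fp16 storage, and the same sparsity in every layer; it ignores the extra interaction keys the method adds. The function name `kv_cache_bytes` is illustrative, and the GPT-2-small shape (12 layers, hidden size 768) is the standard published configuration.

```python
def kv_cache_bytes(n_layers: int, d_model: int, n_tokens: int,
                   batch_size: int = 1, bytes_per_elem: int = 2,
                   sparsity: float = 0.0) -> float:
    """Approximate KV cache size in bytes.

    Assumes uniform sparsity across layers (a simplification) and
    ignores any additional per-token state kept by the pruning module.
    """
    kept_tokens = n_tokens * (1.0 - sparsity)
    # Factor 2: one key vector and one value vector per token per layer.
    return 2 * n_layers * d_model * kept_tokens * batch_size * bytes_per_elem


if __name__ == "__main__":
    dense = kv_cache_bytes(n_layers=12, d_model=768, n_tokens=1000)
    pruned = kv_cache_bytes(n_layers=12, d_model=768, n_tokens=1000, sparsity=0.8035)
    print(f"dense:  {dense / 1e6:.2f} MB per sequence")
    print(f"pruned: {pruned / 1e6:.2f} MB per sequence")
```

Because the kept-token count enters the formula as a single multiplicative factor, the memory freed per sequence can be spent on a larger batch at the same context length.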

Training only the new interaction parameters is not enough.

Numbers: freezing original weights and training only interaction params → +9.285 perplexity

Dropped tokens are often punctuation and local words.

Numbers: high drop rates on punctuation and stop-word POS categories

Results

perplexity

Value: comparable; −0.085 avg change at 80.35% sparsity (context=1000) vs dense

Baseline: dense GPT-2 (no pruning)

throughput (tokens/sec)

Value: up to 2× increase; GPT-2-small +98% (ΔPPL 0.316); GPT-2-medium +189% (ΔPPL 0.084)

Baseline: dense GPT-2 counterparts

wall-time per generation step

Value: up to ~50% reduction for large contexts

Baseline: dense GPT-2 generation

interpretability (pruned token types)

Value: most pruning triggered by punctuation and local function words

Baseline: n/a

Who Should Care

What To Try In 7 Days

Fine-tune a small GPT-2 model with the adaptively sparse module and a fixed sparsity setting γ for 25k steps on a subset of your domain data.

Measure KV cache memory and tokens/sec at your target context length, then sweep γ to map the trade-off between sparsity and quality.

Implement the simple batched remove/insert cache to test real-world throughput with identical batch prefixes.

Optimization Features

Token Efficiency

  • learned keep/drop decisions per token

Infra Optimization

  • enables larger batch sizes by lowering KV memory

Model Optimization

  • context pruning

System Optimization

  • data structure to recycle erased token slots

Training Optimization

  • α-sigmoid schedule for binarizing decisions
  • fine-tuning pretrained weights with added interaction params

Inference Optimization

  • erase entries from key-value cache
  • batched contiguous memory reuse for KV

Reproducibility

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Evaluated mainly on GPT-2 family; generalization to other decoders is claimed but not shown.
  • Requires fine-tuning; freezing base weights gives large perplexity penalties.
  • Extra compute and memory needed for interaction keys K_int; trade-off depends on r.
  • Batched-inference gains shrink when dropped-token patterns create 'holes' across different samples in a batch.

When Not To Use

  • Short-context workloads where pruning overhead outweighs gains.
  • High-stakes tasks that require access to every past token unless validated.
  • Scenarios where you cannot afford to fine-tune base model weights.

Failure Modes

  • Bad α schedule or too-fast α increase leads to poor binary decisions and worse performance.
  • Tying drop decisions across layers harmed results and made learning unstable for deep models.
  • Freezing original parameters and training only the drop module causes large perplexity increases.

Core Entities

Models

  • GPT-2-small
  • GPT-2-medium
  • GPT-2-large
  • GPT-2-xl

Metrics

  • perplexity
  • accuracy
  • throughput (tokens/sec)
  • FLOPs
  • KV cache memory
  • wall-time per generation step

Datasets

  • English Wikipedia 20220301.en
  • bookcorpus

Benchmarks

  • WinoGrande
  • HellaSwag
  • PIQA
  • LAMBADA