Train a tiny 'judge' on top of target embeddings to accept many more draft tokens and speed up large-model generation by up to ~9× while largely preserving benchmark accuracy

January 31, 2025 · 9 min

Overview

Decision Snapshot: Needs Validation

Results show large speedups in optimized runtimes and preserved benchmark accuracy, but the guarantee to match target output is lost and OOD performance drops unless the judge is trained on similar data.

Citations: 0

Evidence Strength: 0.70

Confidence: 0.80

Risk Signals: 11

Trust Signals

Findings with numeric evidence: 5/5

Findings with evidence refs: 5/5

Results with explicit delta: 5/5

Reproducibility

Status: No open assets linked

Open source: No

At A Glance

Cost impact: 80%

Production readiness: 70%

Novelty: 60%

Authors

Gregor Bachmann, Sotiris Anagnostidis, Albert Pumarola, Markos Georgopoulos, Artsiom Sanakoyeu, Yuming Du, Edgar Schönfeld, Ali Thabet, Jonas Kohler

Links

Abstract / PDF

Why It Matters For Business

You can reduce latency and cost of serving very large LLMs by accepting more draft tokens using a tiny judge head. It's cheap to train and deploy and gives large throughput gains in optimized runtimes without retraining the big model.

Who Should Care

Summary TLDR

Standard speculative decoding (SD) rejects many objectively correct draft tokens because it enforces alignment with the target model. The authors train a tiny linear 'judge' on top of target-model token embeddings to predict whether a draft token is contextually correct. Judge decoding accepts ~3× more tokens than standard SD (e.g., mean accepted tokens rise from ~6.3 to ~19.7 for an 8B draft with a 405B target), giving up to 9.7× speedup in common frameworks and 129 tokens/s in optimized inference, while largely preserving benchmark accuracy. The judge is cheap (≈16.4k parameters, trainable in <1.5 hr on 30k tokens). Gains shrink on OOD tasks and with small draft/target models, and both safety guarantees and the guarantee of matching the target's output are lost.

Problem Statement

Speculative decoding speeds up autoregressive generation by having a fast draft model propose tokens that the target model then verifies in parallel. Standard verification rejects many valid draft tokens because it requires close agreement with the target's own token probabilities. This caps the number of draft tokens accepted per target call, and with it the achievable speedup, even when drafts are high quality.
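To make the rejection behavior concrete, here is a minimal sketch of the commonly used probability-ratio acceptance rule for speculative decoding: each draft token is kept with probability min(1, p/q), where q is the draft model's probability for the token and p is the target's. The function name and inputs are illustrative, and the sketch omits the resampling step that the full algorithm performs after a rejection.

```python
import random


def standard_verify(draft_probs, target_probs, rng=random):
    """Count how many leading draft tokens survive standard verification.

    draft_probs[i]  = q(x_i): probability the draft model assigned to its token.
    target_probs[i] = p(x_i): probability the target model assigns to that token.
    (Illustrative helper; resampling after a rejection is omitted.)
    """
    accepted = 0
    for q, p in zip(draft_probs, target_probs):
        # Keep the token with probability min(1, p / q).
        if rng.random() < min(1.0, p / q):
            accepted += 1
        else:
            # The first rejection discards this token and everything after it.
            break
    return accepted
```

A token that is perfectly reasonable in context can still be dropped here simply because the target would have phrased things differently and so assigns it a low p.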

Main Contribution

Showed that logits-based verification in standard speculative decoding rejects many correct tokens, limiting speedups even with high-quality drafts (e.g., text from GPT-4o or humans).

Proposed 'judge decoding': a small linear classifier on top of target model token embeddings that predicts token correctness and augments standard verification.
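A rough sketch of how such a judge head could plug into verification is shown below. It assumes the judge is a single linear layer followed by a sigmoid, and that a token is kept when either standard verification or the judge accepts it; that OR-combination is one plausible reading of "augments", not a confirmed detail. The 0.5 threshold and all function names are illustrative.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def judge_score(embedding, weights, bias):
    """Estimated probability that the token behind this embedding is correct."""
    return sigmoid(sum(w * e for w, e in zip(weights, embedding)) + bias)


def judge_verify(embeddings, standard_accepts, weights, bias, threshold=0.5):
    """Count leading draft tokens accepted by standard verification OR the judge.

    embeddings[i]       : target-model embedding for draft token i
    standard_accepts[i] : whether standard verification accepted token i
    (The OR-combination and the threshold are assumptions of this sketch.)
    """
    accepted = 0
    for emb, std_ok in zip(embeddings, standard_accepts):
        if std_ok or judge_score(emb, weights, bias) >= threshold:
            accepted += 1
        else:
            break  # first rejected token ends the accepted prefix
    return accepted


def linear_head_params(hidden_size):
    # One weight per embedding dimension plus a single bias term.
    return hidden_size + 1
```

Assuming a target hidden size of 16,384 (an assumption, not stated in this summary), `linear_head_params(16384)` gives 16,385 parameters, which is in line with the ≈16.4k figure quoted for the judge.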

Key Findings

Judge decoding increases average accepted tokens from ~6.3 to ~19.7 for Llama-8B draft → Llama-405B target.

Numbers: m* 8B/405B: 6.3 → 19.7 (Table 1)

Practical Use: You can accept ~3× more draft tokens per target call, so draft models can be used to produce longer chunks and reduce target calls.

Evidence Ref: Table 1, Sec. 3.1–3.2
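As a rough sanity check on these figures, the snippet below computes the acceptance ratio and estimates how many target calls a response would need. It assumes, simplistically, that every target call yields m* tokens on average; the 512-token response length is a hypothetical value chosen for illustration.

```python
import math


def target_calls(total_tokens, mean_accepted):
    """Estimated target-model calls if each call yields `mean_accepted` tokens."""
    return math.ceil(total_tokens / mean_accepted)


standard_m, judge_m = 6.3, 19.7       # m* values reported for 8B/405B
ratio = judge_m / standard_m          # how many times more tokens per call

response_length = 512                 # hypothetical response length
calls_standard = target_calls(response_length, standard_m)
calls_judge = target_calls(response_length, judge_m)

print(f"ratio: {ratio:.2f}x, calls: {calls_standard} vs {calls_judge}")
```

Under these assumptions the ratio is about 3.1×, and a 512-token response needs 82 target calls with standard verification versus 26 with the judge.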

Judge decoding yields large end-to-end speedups and throughput in practice.

Numbers: 8B/405B: 9.7× (HuggingFace), 129.3 tokens/s (gpt-fast); 8B/70B: / and 141.8 tok/s (Table 1)

Practical Use: Deploy judge decoding to get major latency wins on large targets, especially in optimized runtimes and large-model regimes.

Evidence Ref: Table 1, Sec. 5.2

Results

| Metric | Value | Baseline | Delta | Split / Dataset | Evidence | Evidence Ref |
|---|---|---|---|---|---|---|
| mean accepted tokens (m*) | 8B/405B: 19.7 (judge) vs 6.3 (standard) | standard speculative decoding | +13.4 tokens | mixed benchmarks (MT-Bench, GSM8K, HumanEval) | Table 1 reports m* values for standard and judge verification | Table 1 |
| end-to-end speedup (HuggingFace) | 8B/405B-JUDGE: 9.7× | standard decoding (no-speculation baseline) | ≈+8.7× | batch size 1 | Table 1 HuggingFace column | Table 1 |
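The metrics in this table are simple to compute from your own decoding logs. The sketch below shows one way to do it; the helper names and the sample numbers are hypothetical and are not taken from the paper.

```python
def mean_accepted_tokens(accepted_per_call):
    """m*: average number of draft tokens accepted per target verification call."""
    if not accepted_per_call:
        raise ValueError("need at least one verification call")
    return sum(accepted_per_call) / len(accepted_per_call)


def tokens_per_second(total_tokens, wall_seconds):
    """Throughput over a whole generation run."""
    return total_tokens / wall_seconds


def speedup(baseline_seconds, accelerated_seconds):
    """End-to-end speedup relative to decoding without speculation."""
    return baseline_seconds / accelerated_seconds


# Hypothetical log: tokens accepted in each of four verification calls.
accepted_log = [18, 22, 19, 20]
m_star = mean_accepted_tokens(accepted_log)

# Hypothetical timings for generating the same 256 tokens.
throughput = tokens_per_second(256, 2.0)
gain = speedup(20.0, 2.0)
```

Measuring the baseline and the accelerated run at the same batch size and on the same prompts keeps the comparison like for like.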

What To Try In 7 Days

Train a linear judge head on 500 curated (prompt, correct, wrong) examples from your domain and test acceptance rates.

Measure m* and tokens/s on your deployment stack (gpt-fast or Triton) and compare to current decoding.

Run safety spot checks: measure whether the judge accepts problematic draft outputs, and add guarded negative examples.
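The first and third steps above can be prototyped before touching a real model. The sketch below trains a linear judge with plain logistic regression and reports its false-positive rate on held-out data. The embeddings are synthetic stand-ins (real ones would come from the target model on your curated correct/wrong examples), and the dimension, learning rate, and epoch count are arbitrary choices for illustration.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def logit(embedding, weights, bias):
    return sum(w * e for w, e in zip(weights, embedding)) + bias


def train_judge(examples, dim, epochs=20, lr=0.5):
    """Fit a linear judge by SGD on the log loss.

    examples: list of (embedding, label), label 1 = correct token, 0 = wrong.
    """
    weights, bias = [0.0] * dim, 0.0
    for _ in range(epochs):
        for emb, label in examples:
            err = sigmoid(logit(emb, weights, bias)) - label  # dLoss/dLogit
            weights = [w - lr * err * e for w, e in zip(weights, emb)]
            bias -= lr * err
    return weights, bias


def false_positive_rate(examples, weights, bias, threshold=0.5):
    """Share of wrong tokens that the judge would accept."""
    negatives = [emb for emb, label in examples if label == 0]
    wrong_accepts = sum(
        sigmoid(logit(emb, weights, bias)) >= threshold for emb in negatives
    )
    return wrong_accepts / len(negatives)


# Synthetic stand-in data: 500 examples, alternating wrong/correct labels.
rng = random.Random(0)
DIM = 8  # toy dimension; real embeddings are far larger


def fake_embedding(label):
    center = 1.0 if label else -1.0
    return [rng.gauss(center, 0.5) for _ in range(DIM)]


data = [(fake_embedding(i % 2), i % 2) for i in range(500)]
train, held_out = data[:400], data[400:]

weights, bias = train_judge(train, dim=DIM)
fpr = false_positive_rate(held_out, weights, bias)
print(f"held-out false-positive rate: {fpr:.3f}")
```

For the safety spot check, the same `false_positive_rate` helper can be run on a separate set of problematic draft outputs labeled as wrong, which shows how often the judge would let them through.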

Optimization Features

Token Efficiency
More tokens accepted per target call (~3× increase)
Infra Optimization
Multi-GPU H100 deployment
System Optimization
Optimized runtime evaluation (gpt-fast); use with 8-bit quantization
Inference Optimization
Speculative decoding with judge verification; parallel token verification

Reproducibility

Code Available: No
Data Available: No
Open Source Status: No
License: Unknown

Risks & Boundaries

Limitations

Loses the formal guarantee of exactly matching target-model output whenever the judge accepts a token.

Requires a reasonably high-quality draft model; poor drafters reduce benefit.

When Not To Use

When you need provable parity with target outputs (lossless guarantees).

When the draft model is weak and produces many incorrect tokens.

Failure Modes

Judge false positives: accepting mistaken tokens and degrading output.

Reduced accuracy on tasks far from judge training data (OOD).

Core Entities

Models

Llama-3.1-8B, Llama-3.1-70B, Llama-3.1-405B, GPT-4o, GPT-4o-mini, Phi-3-mini, Eagle-2, Medusa

Metrics

accepted tokens (m*), tokens per second, accuracy, speedup factor

Datasets

MT-Bench, GSM8K, HumanEval, ARC, MMLU, Alpaca (inputs only), wikipedia-summary (subset)

Benchmarks

MT-Bench, GSM8K, HumanEval, ARC, MMLU

Context Entities

Models

Mistral-Large-2, Llama-8B (drafter in some experiments)

Datasets

Stanford Alpaca (inputs filtered), ARC (inputs filtered)