Train a tiny 'judge' on top of target embeddings to accept many more draft tokens and speed up large-model generation by up to ~9.7× with little accuracy loss

January 31, 2025 · 9 min

Overview

Production Readiness

0.7

Novelty Score

0.6

Cost Impact Score

0.8

Citation Count

0

Authors

Gregor Bachmann, Sotiris Anagnostidis, Albert Pumarola, Markos Georgopoulos, Artsiom Sanakoyeu, Yuming Du, Edgar Schönfeld, Ali Thabet, Jonas Kohler

Links

Abstract / PDF

Why It Matters For Business

You can reduce latency and cost of serving very large LLMs by accepting more draft tokens using a tiny judge head. It's cheap to train and deploy and gives large throughput gains in optimized runtimes without retraining the big model.

Summary TLDR

Standard speculative decoding (SD) rejects many objectively correct draft tokens because it enforces alignment with the target model. The authors train a tiny linear 'judge' on top of target-model token embeddings to predict whether a draft token is contextually correct. Judge decoding accepts ~3× more tokens than standard SD (e.g., mean accepted tokens rise from ~6.3 to ~19.7 with an 8B draft and a 405B target), giving up to 9.7× speedup in HuggingFace and 129 tokens/s in optimized inference, while largely preserving benchmark accuracy. The judge is cheap (≈16.4k params, trainable in <1.5 hr on 30k tokens). Gains shrink on out-of-distribution (OOD) tasks and with small draft or target models. The method also gives up the guarantee of matching the target model's output, which introduces a safety risk.

Problem Statement

Speculative decoding speeds up autoregressive generation by using a fast draft model and verifying proposed tokens with the target model. Current verification rejects many valid draft tokens because it requires high alignment with the target's own probabilities. This limits the number of draft tokens and the achievable speedup, even when drafts are high quality.
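To make the rejection problem concrete, here is a minimal sketch of the usual rejection-sampling rule used in standard speculative decoding: each draft token is accepted with probability min(1, p/q), where p is the target's probability for that token and q is the draft's. The function name and the toy probabilities are illustrative assumptions, and the resampling step that follows a rejection is omitted.

```python
import random


def standard_accept_length(p_target, q_draft, rng=None):
    """Count how many leading draft tokens pass standard verification.

    p_target[i]: target-model probability of the i-th drafted token.
    q_draft[i]:  draft-model probability of the same token (must be > 0).
    """
    rng = rng or random.Random(0)
    accepted = 0
    for p, q in zip(p_target, q_draft):
        # Accept with probability min(1, p / q); stop at the first rejection.
        if rng.random() < min(1.0, p / q):
            accepted += 1
        else:
            break
    return accepted


# A token the target considers unlikely is rejected even if it is a
# perfectly valid continuation, and every token after it is discarded.
print(standard_accept_length([0.9, 0.0, 0.9], [0.5, 0.5, 0.5]))
```

The rule only compares probabilities, so a correct but differently phrased draft token can end the accepted run early.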

Main Contribution

Showed that logits-based verification in standard speculative decoding rejects many correct tokens, limiting speedups even when the drafts are high quality (written by GPT-4o or by humans).

Proposed 'judge decoding': a small linear classifier on top of target model token embeddings that predicts token correctness and augments standard verification.
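A minimal sketch of what such a judge head could look like: one linear layer over a token embedding, followed by a sigmoid and a threshold. The function names, the threshold of 0.5, and the toy two-dimensional embeddings are assumptions for illustration only; the summary does not specify these details.

```python
import math


def judge_score(embedding, weights, bias=0.0):
    """Return the judge's probability that a draft token is correct."""
    logit = sum(w * x for w, x in zip(weights, embedding)) + bias
    # Numerically stable sigmoid.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)


def judge_accept_length(embeddings, weights, bias=0.0, threshold=0.5):
    """Count leading draft tokens whose judge score clears the threshold."""
    accepted = 0
    for emb in embeddings:
        if judge_score(emb, weights, bias) >= threshold:
            accepted += 1
        else:
            break  # everything after the first rejected token is discarded
    return accepted


# Toy example: hypothetical 2-d "embeddings" for four drafted tokens.
toy_weights = [1.0, -1.0]
toy_embeddings = [[2.0, 0.0], [1.0, 0.0], [0.0, 3.0], [5.0, 0.0]]
print(judge_accept_length(toy_embeddings, toy_weights))
```

A single-output linear layer over a d-dimensional embedding has d + 1 parameters. If the embedding width were 16,384 (an assumption, not stated in this summary), that would give 16,385 parameters, which is consistent with the ≈16.4k figure reported here.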

Demonstrated large practical speedups (up to 9.7× in HuggingFace and 129 tokens/s in gpt-fast) while largely preserving accuracy on common benchmarks.

Released a carefully curated 500-example dataset of (prompt, correct answer, wrong answer) tuples used to train the judge head (no public URL provided).

Key Findings

Judge decoding increases average accepted tokens from ~6.3 to ~19.7 for Llama-8B draft → Llama-405B target.

Numbers: m* 8B/405B: 6.3 → 19.7 (Table 1)

Judge decoding yields large end-to-end speedups and throughput in practice.

Numbers: 8B/405B: 9.7× (HuggingFace), 129.3 tokens/s (gpt-fast); 8B/70B: 2×/3× and 141.8 tok/s (Table 1)

The judge head is extremely small and cheap to train.

Numbers: 16.4k parameters, trained on ~30k tokens in <1.5 hours

Standard verification does not reward higher-quality drafts: under standard SD, only ~2 GPT-4o draft tokens are accepted on average.

Numbers: Average acceptance length ≈ 2 tokens for GPT-4o → Llama-405B (Fig. 3, Sec. 3.2)

Judge decoding largely preserves target accuracy on evaluated benchmarks, but OOD gaps exist.

Numbers: Out-of-distribution: HumanEval accuracy fell from 86.6% to 80.4% (6.2 points) when the judge was trained without coding examples; this is still above the draft model's accuracy

Results

mean accepted tokens (m*)

Value: 8B/405B: 19.7 (judge) vs 6.3 (standard)

Baseline: standard speculative decoding

end-to-end speedup (HuggingFace)

Value: 8B/405B-JUDGE: 9.7×

Baseline: standard decoding (no-speculation baseline)

throughput (tokens/s, optimized runtime)

Value: 8B/405B-JUDGE: 129.3 tok/s; 8B/70B-JUDGE: 141.8 tok/s

Baseline: standard decoding rows in Table 1

training cost of judge head

Value: 16.4k parameters; trained on ~30k tokens in <1.5 hours

Baseline: no judge

OOD robustness (coding removed from training)

Value: HumanEval accuracy: 86.6% → 80.4% after removing coding examples from judge training

Baseline: judge trained on full data (86.6%)
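The metrics above are simple ratios, so they are easy to reproduce from your own decode logs. The helpers below are a small sketch under that assumption; the function names and the log format (a list of accepted-token counts, one per target call) are hypothetical.

```python
def mean_accepted_tokens(accepted_per_call):
    """m*: average number of draft tokens accepted per target-model call."""
    return sum(accepted_per_call) / len(accepted_per_call)


def tokens_per_second(num_tokens, elapsed_seconds):
    """Throughput over a timed generation run."""
    return num_tokens / elapsed_seconds


def speedup(baseline_tps, new_tps):
    """End-to-end speedup relative to a baseline throughput."""
    return new_tps / baseline_tps


# The "~3x more accepted tokens" figure follows from the reported m* values.
ratio = 19.7 / 6.3
print(round(ratio, 2))
```

Note that the ratio of accepted tokens is not the same as the end-to-end speedup, which also depends on the cost of running the draft model and the runtime.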

Who Should Care

What To Try In 7 Days

Train a linear judge head on 500 curated (prompt, correct, wrong) examples from your domain and test acceptance rates.

Measure m* and tokens/s on your deployment stack (gpt-fast or Triton) and compare to current decoding.

Run safety spot checks: measure whether the judge accepts problematic draft outputs, and add such outputs to the training set as negative examples.
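The steps above can be prototyped without any ML framework. The sketch below fits a linear judge with plain logistic regression and then measures how often it accepts drafts that should be rejected. The loss, optimizer, hyperparameters, and toy data are all assumptions; in practice the inputs would be target-model embeddings of tokens from your correct and wrong answers.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_judge(embeddings, labels, lr=0.5, epochs=200):
    """Fit weights and bias by full-batch gradient descent on log loss."""
    dim = len(embeddings[0])
    n = len(embeddings)
    weights, bias = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(embeddings, labels):
            logit = sum(w * xi for w, xi in zip(weights, x)) + bias
            err = sigmoid(logit) - y  # gradient of log loss w.r.t. the logit
            for j, xj in enumerate(x):
                grad_w[j] += err * xj
            grad_b += err
        weights = [w - lr * g / n for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / n
    return weights, bias


def false_accept_rate(scores, threshold=0.5):
    """Fraction of known-bad draft tokens that the judge would accept."""
    return sum(s >= threshold for s in scores) / len(scores)


# Toy stand-ins for embeddings: label 1 = correct token, 0 = wrong token.
X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
y = [1, 1, 0, 0]
w, b = train_judge(X, y)
preds = [sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) >= 0.5 for x in X]
print(preds)
```

For the safety check, score a held-out set of drafts you have flagged as problematic and track `false_accept_rate` as you add negative examples.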

Optimization Features

Token Efficiency

  • More tokens accepted per target call (~3× increase)

Infra Optimization

  • Multi-GPU H100 deployment

System Optimization

  • Optimized runtime evaluation (gpt-fast)
  • Use with 8-bit quantization

Inference Optimization

  • Speculative decoding with judge verification
  • Parallel token verification
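One way the judge could augment standard verification is a per-token OR: a draft token is kept if either check accepts it, and the run stops at the first token both reject. This combination rule is an assumption made for illustration, not a detail confirmed by this summary, and the function name is hypothetical. Both verdict lists can be produced from a single parallel target pass over the drafted tokens.

```python
def combined_accept_length(standard_ok, judge_ok):
    """Length of the accepted prefix when either verifier may accept a token.

    standard_ok[i]: True if standard verification accepts draft token i.
    judge_ok[i]:    True if the judge head accepts draft token i.
    """
    accepted = 0
    for s, j in zip(standard_ok, judge_ok):
        if s or j:
            accepted += 1
        else:
            break  # first token rejected by both ends the accepted run
    return accepted


# Standard verification alone would stop after 1 token; the judge extends it.
print(combined_accept_length([True, False, False, True],
                             [True, True, False, True]))
```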

Reproducibility

Open Source Status

  • no

Risks & Boundaries

Limitations

  • Loses the formal guarantee of exactly matching the target model's output whenever the judge accepts tokens.
  • Requires a reasonably high-quality draft model; poor drafters reduce benefit.
  • Target must be large enough to exhibit corrective embedding signals; gains are smaller for small targets.
  • New task families require curated annotations for the judge to avoid accuracy loss.
  • Potential safety risk: the judge could accept unsafe content if the drafts contain it.

When Not To Use

  • When you need provable parity with target outputs (lossless guarantees).
  • When the draft model is weak and produces many incorrect tokens.
  • For very small target models where embedding corrections are weak.

Failure Modes

  • Judge false positives: accepting mistaken tokens and degrading output.
  • Reduced accuracy on tasks far from judge training data (OOD).
  • Safety failures if drafts contain unsafe content and the judge accepts it.

Core Entities

Models

  • Llama-3.1-8B
  • Llama-3.1-70B
  • Llama-3.1-405B
  • GPT-4o
  • GPT-4o-mini
  • Phi-3-mini
  • Eagle-2
  • Medusa

Metrics

  • accepted tokens (m*)
  • tokens per second
  • Accuracy
  • speedup factor

Datasets

  • MT-Bench
  • GSM8K
  • HumanEval
  • ARC
  • MMLU
  • Alpaca (inputs only)
  • wikipedia-summary (subset)

Benchmarks

  • MT-Bench
  • GSM8K
  • HumanEval
  • ARC
  • MMLU

Context Entities

Models

  • Mistral-Large-2
  • Llama-8B (drafter in some experiments)

Datasets

  • Stanford Alpaca (inputs filtered)
  • ARC (inputs filtered)