Predict problem difficulty from LLM mid-layer embeddings and route each query to the smallest model likely to solve it, cutting compute with minimal accuracy loss

November 5, 2025 · 6 min read

Overview

Decision Snapshot: Needs Validation

Shows practical routing gains on multiple public math datasets. Results are promising but reported mainly on reasoning/math tasks and summarized by figures rather than exhaustive numeric tables.

Citations: 0

Evidence Strength: 0.60

Confidence: 0.70

Risk Signals: 9

Trust Signals

Findings with numeric evidence: 1/3

Findings with evidence refs: 3/3

Results with explicit delta: 3/3

Reproducibility

Status: Partial assets available

Open source: Partial

At A Glance

Cost impact: 70%

Production readiness: 60%

Novelty: 50%

Authors

Bo Zhao, Berkcan Kapusuzoglu, Kartik Balasubramaniam, Sambit Sahu, Supriyo Chakraborty, Genta Indra Winata


Why It Matters For Business

Routing saves inference cost by sending easy queries to cheaper models. That lowers cloud bills and lets you scale reasoning services while keeping top-model accuracy.


Summary TLDR

Train small classifiers on intermediate LLM embeddings to predict problem difficulty or a model's chance of success. Use those predictions to route each problem to the smallest model likely to solve it. On mixed math benchmarks the router matches the big model's accuracy while using about two-thirds of its inference compute.

Problem Statement

Large reasoning models are expensive. Many problems need less compute. Can we predict which problems are easy and route them to cheaper models without losing accuracy?

Main Contribution

Train lightweight classifiers on intermediate embeddings of s1.1-32B to predict problem difficulty (1–5) and per-model correctness (binary).

Design threshold-based routers that send each problem to the smallest model predicted to succeed.
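A minimal sketch of the threshold rule described above, assuming each candidate model comes with a predicted success probability from its own correctness classifier and a known per-query cost. The function name `route`, the cost values, and the fallback to the most expensive model when no model clears the threshold are illustrative assumptions, not details confirmed by the paper.

```python
from typing import Sequence


def route(success_probs: Sequence[float], costs: Sequence[float], threshold: float) -> int:
    """Return the index of the cheapest model whose predicted success
    probability meets the threshold (hypothetical helper).

    success_probs[i]: predicted P(model i solves the query).
    costs[i]: per-query cost of model i (any consistent unit).
    If no model clears the threshold, fall back to the most expensive model.
    """
    if not costs or len(success_probs) != len(costs):
        raise ValueError("need exactly one probability per model")
    # Visit models from cheapest to most expensive.
    order = sorted(range(len(costs)), key=lambda i: costs[i])
    for i in order:
        if success_probs[i] >= threshold:
            return i
    # No model is confidently predicted to succeed: use the largest one.
    return order[-1]
```

For example, with hypothetical costs `[8, 14, 32]` (roughly the parameter counts of Llama-3.1-8B-Instruct, phi-4, and s1.1-32B), predicted probabilities `[0.4, 0.8, 0.9]`, and a threshold of 0.7, `route` returns index 1, the phi-4 slot. Raising the threshold pushes more queries toward larger models, trading compute for accuracy.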

Key Findings

Middle layers of a strong reasoning model carry the most signal for difficulty and correctness prediction.

Practical Use: Use mid-layer embeddings (not final-token logits) as inputs to lightweight predictors for routing.

Evidence Ref: Figure 2; layer 45 of s1.1-32B was used.
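The sketch below shows one way to turn a middle layer's token states into a fixed-size predictor input. It assumes the hidden states come from a Hugging Face style forward pass with `output_hidden_states=True`, where index 0 is the embedding output, so index 45 is the output of block 45. It also assumes mean pooling over non-padding tokens; the summary does not say how the paper pools tokens, and last-token pooling is an equally plausible choice. `pool_layer` is a hypothetical helper.

```python
from typing import Sequence

import numpy as np


def pool_layer(hidden_states: Sequence[np.ndarray], layer: int,
               attention_mask: np.ndarray) -> np.ndarray:
    """Mean-pool one layer's token states into one vector per sequence.

    hidden_states: per-layer arrays of shape (batch, seq_len, dim); index 0 is
        assumed to be the embedding output, so index k is transformer block k.
    attention_mask: (batch, seq_len) array with 1 for real tokens, 0 for padding.
    Returns an array of shape (batch, dim).
    """
    states = np.asarray(hidden_states[layer], dtype=np.float64)      # (B, T, D)
    mask = np.asarray(attention_mask, dtype=np.float64)[..., None]   # (B, T, 1)
    summed = (states * mask).sum(axis=1)                             # zero out padding, (B, D)
    counts = np.clip(mask.sum(axis=1), 1.0, None)                    # (B, 1), avoid divide-by-zero
    return summed / counts
```

With a real model, the per-layer tuple would be the forward-pass output and `layer=45` would select the layer the paper reports using; the pooled vectors then become the rows of the classifier's training matrix.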

Accuracy-based routing can match or slightly beat the large model's accuracy on evaluated math tasks while using less compute.

Numbers: inference compute ≈ two-thirds of that used by s1.1-32B

Practical Use: Routing each problem to the smallest model a predictor expects to solve it can lower inference cost by roughly 33% relative to always using s1.1-32B, with minimal accuracy loss.

Evidence Ref: Figure 4; Section 4
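The savings figure is an expected value over where queries end up. A minimal sketch, assuming per-query costs are known for each model (for example, measured inference time) and that routing overhead is charged as a fixed per-query term. Such overhead could include the partial prefill through s1.1-32B needed to reach the chosen layer, plus the classifier itself. The function name and all numbers are hypothetical and are not the paper's measurements.

```python
from typing import Sequence


def relative_compute(route_fractions: Sequence[float], model_costs: Sequence[float],
                     baseline_cost: float, overhead_per_query: float = 0.0) -> float:
    """Expected per-query compute of a router as a fraction of always using
    the baseline model (hypothetical helper).

    route_fractions[i]: share of queries sent to model i (must sum to 1).
    model_costs[i]: per-query cost of model i, in the same unit as baseline_cost.
    overhead_per_query: cost of embedding extraction plus the predictor.
    """
    if len(route_fractions) != len(model_costs):
        raise ValueError("need one fraction per model")
    if abs(sum(route_fractions) - 1.0) > 1e-9:
        raise ValueError("route fractions must sum to 1")
    routed = sum(f * c for f, c in zip(route_fractions, model_costs))
    return (routed + overhead_per_query) / baseline_cost
```

Sending half the queries to a model costing 0.2 of the baseline and half to the baseline itself, `relative_compute([0.5, 0.5], [0.2, 1.0], 1.0)` gives about 0.6, a 40% saving before overhead. Any nonzero `overhead_per_query` eats directly into that margin.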

Results

Metric | Value | Baseline | Delta | Split / Dataset | Evidence | Evidence Ref
Inference compute | ≈ 2/3 of s1.1-32B compute | s1.1-32B | ≈ -33% | MathCombined (evaluation split) | Accuracy matched or slightly exceeded while using two-thirds of the compute | Figure 4; Section 4
Accuracy | Comparable to or slightly higher than s1.1-32B on evaluated benchmarks | s1.1-32B | Small improvement reported (no exact % given) | MathCombined (evaluation split) | Router can achieve comparable and even slightly better performance than s1.1-32B | Section 4; Figure 4

What To Try In 7 Days

Collect a small labeled subset of your tasks and extract mid-layer embeddings from your strongest model.

Train a simple MLP to predict task difficulty or per-model correctness using those embeddings.

Implement a threshold-based router that forwards inputs to the smallest model with predicted success probability above the threshold; measure cost and accuracy trade-offs.
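For the second step, here is a dependency-light sketch of a one-hidden-layer MLP that predicts from embeddings whether a given model solves a problem (binary correctness). It uses NumPy only, with feature standardization and full-batch gradient descent on binary cross-entropy. The layer width, learning rate, and epoch count are arbitrary assumptions, and `train_mlp` is a hypothetical name. In practice, a framework such as PyTorch or scikit-learn's `MLPClassifier` would be the usual choice.

```python
import numpy as np


def _sigmoid(z: np.ndarray) -> np.ndarray:
    # Clip logits so np.exp never overflows.
    return 1.0 / (1.0 + np.exp(-np.clip(z, -30.0, 30.0)))


def train_mlp(X, y, hidden: int = 32, lr: float = 0.1, epochs: int = 500, seed: int = 0):
    """Train a ReLU -> sigmoid MLP on binary labels; return predict_proba(X_new)."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64).reshape(-1, 1)
    n, d = X.shape
    # Standardize features with training-set statistics.
    mu, sd = X.mean(axis=0), X.std(axis=0) + 1e-8
    Xs = (X - mu) / sd
    W1 = rng.normal(0.0, np.sqrt(2.0 / d), size=(d, hidden))
    b1 = np.zeros(hidden)
    W2 = rng.normal(0.0, np.sqrt(1.0 / hidden), size=(hidden, 1))
    b2 = np.zeros(1)
    for _ in range(epochs):
        h_pre = Xs @ W1 + b1
        h = np.maximum(h_pre, 0.0)
        p = _sigmoid(h @ W2 + b2)
        # Gradient of mean BCE with respect to the logits is (p - y) / n.
        g_logits = (p - y) / n
        gW2, gb2 = h.T @ g_logits, g_logits.sum(axis=0)
        g_hpre = (g_logits @ W2.T) * (h_pre > 0)  # backprop through ReLU
        gW1, gb1 = Xs.T @ g_hpre, g_hpre.sum(axis=0)
        W1 -= lr * gW1
        b1 -= lr * gb1
        W2 -= lr * gW2
        b2 -= lr * gb2

    def predict_proba(X_new) -> np.ndarray:
        Z = (np.asarray(X_new, dtype=np.float64) - mu) / sd
        return _sigmoid(np.maximum(Z @ W1 + b1, 0.0) @ W2 + b2).ravel()

    return predict_proba
```

One such predictor would be trained per candidate model on that model's own correctness labels. The returned probabilities then feed the threshold router.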

Optimization Features

Infra Optimization: reported ≈33% inference compute savings on evaluated math workloads
System Optimization: uses a single representative model's embeddings for routing to reduce overhead
Inference Optimization: per-query routing to smaller models; threshold-based gating on predicted difficulty or correctness; mid-layer embeddings as lightweight routing signals
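The difficulty-based variant of threshold gating can be sketched as a lookup from a predicted difficulty level (1–5, as in the MATH labels) to a model tier. This sketch assumes the difficulty classifier outputs one probability per level and routes on the expected level. The tier boundaries and model names are hypothetical placeholders, not cutoffs taken from the paper.

```python
from typing import Sequence, Tuple


def expected_level(level_probs: Sequence[float]) -> float:
    """Expected difficulty given probabilities for levels 1..K."""
    total = sum(level_probs)
    if total <= 0:
        raise ValueError("probabilities must sum to a positive value")
    return sum((k + 1) * p for k, p in enumerate(level_probs)) / total


def route_by_difficulty(level: float, tiers: Sequence[Tuple[float, str]]) -> str:
    """tiers: (max_level, model) pairs sorted by max_level, cheapest model first."""
    for max_level, model in tiers:
        if level <= max_level:
            return model
    return tiers[-1][1]  # harder than every cutoff: use the largest model
```

With hypothetical tiers `[(2.0, "small-model"), (4.0, "mid-model"), (5.0, "large-model")]`, a classifier output of `[0.1, 0.6, 0.3, 0.0, 0.0]` gives an expected level of about 2.2 and lands in the middle tier. Routing on argmax instead would send the same query to the small model, so the choice of summary statistic is itself a tunable knob.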

Reproducibility

Code Available: No
Data Available: Yes
Open Source Status: Partial
License: Unknown

Risks & Boundaries

Limitations

Evaluation is focused on math reasoning datasets; generalization to other domains is untested.

Router uses embeddings from a single representative model (s1.1-32B); this may not generalize to heterogeneous model pools.

When Not To Use

If your workload is non-mathematical or differs strongly from the evaluated benchmarks.

If per-query latency is critical and routing overhead could negate savings.

Failure Modes

Predictor misclassification routes hard problems to weak models, causing accuracy drops.
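One way to quantify this failure mode before deployment is to replay a labeled validation set through the router at several thresholds, recording accuracy and compute together. The sketch assumes models are listed cheapest first, that offline correctness labels exist for every model on every validation query, and that the last model is the fallback. `sweep_thresholds` and its inputs are hypothetical.

```python
from typing import List, Sequence, Tuple


def sweep_thresholds(probs: Sequence[Sequence[float]], correct: Sequence[Sequence[int]],
                     costs: Sequence[float],
                     thresholds: Sequence[float]) -> List[Tuple[float, float, float]]:
    """Replay a validation set through a threshold router.

    probs[q][m]: predicted success probability of model m on query q.
    correct[q][m]: 1 if model m actually solved query q, else 0.
    costs[m]: per-query cost of model m; models are ordered cheapest first.
    Returns (threshold, accuracy, compute relative to always using the last model).
    """
    n = len(probs)
    results = []
    for t in thresholds:
        solved, spent = 0, 0.0
        for p_row, c_row in zip(probs, correct):
            # Cheapest model meeting the threshold, else the last (largest) model.
            m = next((i for i, p in enumerate(p_row) if p >= t), len(costs) - 1)
            solved += c_row[m]
            spent += costs[m]
        results.append((t, solved / n, spent / (n * costs[-1])))
    return results
```

With two models costing 1 and 4 units, a query the small model fails but is predicted to solve with probability 0.6 is misrouted at threshold 0.5 and correctly escalated at 0.7. The sweep exposes that trade-off as higher accuracy bought with more compute.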

Domain shift: embeddings from the representative model may not reflect other models' failure modes.

Core Entities

Models

s1.1-32B, Llama-3.3-70B-Instruct, Llama-3.3-Nemotron-Super-49B, phi-4, Llama-3.1-8B-Instruct, Llama-3.1-Nemotron-Nano-8B, Mixtral-8x7B-Instruct, OLMo-2-1124-7B-Instruct

Metrics

Accuracy, inference time / compute

Datasets

MATH, MathCombined, GSM8K, Minerva, AIME24, AMC23, OlympiadBench, TheoremQA

Benchmarks

MathCombined evaluation split, MATH (difficulty-labeled)