Predict problem difficulty from LLM mid-layer embeddings and route each query to the smallest model likely to solve it, cutting compute with

November 5, 2025 · 6 min

Overview

Production Readiness: 0.6

Novelty Score: 0.5

Cost Impact Score: 0.7

Citation Count: 0

Authors

Bo Zhao, Berkcan Kapusuzoglu, Kartik Balasubramaniam, Sambit Sahu, Supriyo Chakraborty, Genta Indra Winata


Why It Matters For Business

Routing saves inference cost by sending easy queries to cheaper models. That lowers cloud bills and lets you scale reasoning services while keeping top-model accuracy.

Summary TLDR

Train small classifiers on intermediate LLM embeddings to predict problem difficulty or a model's chance of success. Use those predictions to route each problem to the smallest model likely to solve it. On mixed math benchmarks the router matches the big model's accuracy while using about two-thirds of its inference compute.

Problem Statement

Large reasoning models are expensive to run, yet many problems do not need that much compute. Can we predict which problems are easy and route them to cheaper models without losing accuracy?

Main Contribution

Train lightweight classifiers on intermediate embeddings of s1.1-32B to predict problem difficulty (1–5) and per-model correctness (binary).

Design threshold-based routers that send each problem to the smallest model predicted to succeed.

Show that routing cuts inference compute by about one-third while matching or slightly exceeding top-model accuracy on mixed math benchmarks.
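The routing rule described above can be sketched as follows. This is a minimal illustration, not the authors' code: the model names, relative costs, and success probabilities are hypothetical (in practice the probabilities would come from the trained correctness classifiers), and falling back to the most expensive model when no candidate clears the threshold is an assumption of the sketch.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    name: str          # hypothetical model identifier
    cost: float        # relative inference cost (assumed known per model)
    p_success: float   # predicted probability that this model solves the query


def route(candidates: list[Candidate], threshold: float) -> str:
    """Return the cheapest model whose predicted success clears the threshold.

    Falls back to the most expensive (assumed strongest) model when no
    candidate qualifies; this fallback is an assumption, not from the paper.
    """
    ordered = sorted(candidates, key=lambda c: c.cost)
    for cand in ordered:
        if cand.p_success >= threshold:
            return cand.name
    return ordered[-1].name


# Hypothetical pool and predictions for a single query.
pool = [
    Candidate("large-32b", cost=1.0, p_success=0.92),
    Candidate("small-8b", cost=0.25, p_success=0.55),
    Candidate("mid-14b", cost=0.45, p_success=0.81),
]
print(route(pool, threshold=0.7))  # mid-14b
```

Raising the threshold pushes more queries to larger models (higher accuracy, more compute); lowering it does the opposite.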

Key Findings

Middle layers of a strong reasoning model carry the most signal for difficulty and correctness prediction.

Accuracy-based routing can match or slightly beat the large model's accuracy on evaluated math tasks while using less compute.

Numbers: inference compute ≈ two-thirds of s1.1-32B

Difficulty-based routing outperforms random assignment between two models.
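The "two-thirds of s1.1-32B" figure is a ratio: total compute under routing divided by total compute when every query goes to the large model. A minimal sketch of that accounting follows, assuming a fixed relative cost per model (the paper may measure compute differently, for example by generated tokens or wall-clock time). It also gives the expected accuracy of a random-assignment baseline between two models; the costs, decisions, and function names are made up for illustration.

```python
def compute_fraction(routed_models: list[str], cost: dict[str, float], baseline: str) -> float:
    """Total compute of a routing policy relative to sending every query to `baseline`."""
    routed_total = sum(cost[m] for m in routed_models)
    return routed_total / (cost[baseline] * len(routed_models))


def random_assignment_accuracy(frac_small: float, acc_small: float, acc_large: float) -> float:
    """Expected accuracy when a random `frac_small` share of queries goes to the small model."""
    return frac_small * acc_small + (1 - frac_small) * acc_large


# Hypothetical relative costs, normalized so the large model costs 1.0.
costs = {"large": 1.0, "small": 0.25}
decisions = ["large", "small", "large", "small", "large", "small"]
print(compute_fraction(decisions, costs, "large"))  # 0.625
```

One natural way to set up the comparison is at a matched small-model share: a difficulty-based router earns its keep only if its accuracy beats `random_assignment_accuracy` for the same `frac_small`.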

Results

  • Inference compute: ≈ 2/3 of s1.1-32B compute (baseline: s1.1-32B)
  • Accuracy: comparable to or slightly higher than s1.1-32B on evaluated benchmarks (baseline: s1.1-32B)
  • Difficulty/correctness prediction quality by layer: middle layers best, layer 45 used for s1.1-32B (baseline: final-layer outputs)
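The layer result above suggests a simple selection procedure: train one probe per candidate layer and keep the layer that scores best on held-out data. The sketch below is an illustration under stated assumptions, not the paper's setup. It swaps in a pure-Python logistic-regression probe (the paper uses small classifiers such as MLPs), and the layer names and toy features are hypothetical.

```python
import math


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_logistic_probe(xs, ys, lr=0.5, epochs=300):
    """Fit a binary logistic-regression probe with full-batch gradient descent."""
    dim, n = len(xs[0]), len(xs)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(xs, ys):
            err = _sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
            for j in range(dim):
                grad_w[j] += err * x[j]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def probe_accuracy(w, b, xs, ys):
    """Fraction of examples where the 0.5-thresholded prediction matches the label."""
    hits = sum(
        (_sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) >= 0.5) == bool(y)
        for x, y in zip(xs, ys)
    )
    return hits / len(xs)


def best_layer(train_feats, y_train, val_feats, y_val):
    """Return (layer, validation accuracy) for the best-scoring layer's probe."""
    scores = {}
    for layer in train_feats:
        w, b = train_logistic_probe(train_feats[layer], y_train)
        scores[layer] = probe_accuracy(w, b, val_feats[layer], y_val)
    best = max(scores, key=scores.get)
    return best, scores[best]
```

For difficulty prediction (five levels) the same loop applies with a multi-class probe in place of the binary one.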

Who Should Care

What To Try In 7 Days

Collect a small labeled subset of your tasks and extract mid-layer embeddings from your strongest model.

Train a simple MLP to predict task difficulty or per-model correctness using those embeddings.

Implement a threshold-based router that forwards inputs to the smallest model with predicted success probability above the threshold; measure cost and accuracy trade-offs.
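For the third step, a validation sweep makes the cost/accuracy trade-off explicit. This minimal sketch assumes a two-model setup (one small, one large), per-query predicted success for the small model, and recorded outcomes for both models on a labeled validation set; the data, costs, and function names are hypothetical.

```python
from typing import NamedTuple


class ValQuery(NamedTuple):
    p_small: float        # predicted probability the small model solves it
    small_correct: bool   # observed outcome for the small model
    large_correct: bool   # observed outcome for the large model


def evaluate_threshold(data, threshold, cost_small, cost_large):
    """(accuracy, compute fraction) when queries with p_small >= threshold go small."""
    correct, compute = 0, 0.0
    for q in data:
        if q.p_small >= threshold:
            correct += q.small_correct
            compute += cost_small
        else:
            correct += q.large_correct
            compute += cost_large
    return correct / len(data), compute / (cost_large * len(data))


def pick_threshold(data, thresholds, cost_small, cost_large, min_accuracy):
    """Cheapest threshold whose validation accuracy meets `min_accuracy`, else None."""
    ok = []
    for t in thresholds:
        acc, frac = evaluate_threshold(data, t, cost_small, cost_large)
        if acc >= min_accuracy:
            ok.append((frac, t))
    return min(ok)[1] if ok else None


val = [
    ValQuery(0.9, True, True),
    ValQuery(0.8, True, True),
    ValQuery(0.6, False, True),
    ValQuery(0.2, False, False),
]
print(pick_threshold(val, [0.5, 0.7, 1.0], 0.25, 1.0, min_accuracy=0.75))  # 0.7
```

Setting `min_accuracy` to the large model's own validation accuracy mirrors the goal of matching top-model accuracy at lower cost.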

Optimization Features

Infra Optimization

  • reported ≈33% inference compute savings on evaluated math workloads

System Optimization

  • use a single representative model's embeddings for routing to reduce overhead

Inference Optimization

  • per-query model routing to smaller models
  • threshold-based gating on predicted difficulty/correctness
  • use mid-layer embeddings as lightweight signals
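With Hugging Face `transformers`, calling a model with `output_hidden_states=True` returns a `hidden_states` tuple holding the embedding output followed by one tensor per layer, so a middle layer can be read off directly. The sketch below shows only the pooling step on plain lists, assuming masked mean pooling over tokens (the pooling used in the paper is not stated here); `masked_mean_pool` and the toy input are hypothetical.

```python
def masked_mean_pool(hidden: list[list[float]], mask: list[int]) -> list[float]:
    """Average one layer's token vectors, skipping padding positions (mask == 0)."""
    kept = [vec for vec, m in zip(hidden, mask) if m]
    if not kept:
        raise ValueError("mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[j] for vec in kept) / len(kept) for j in range(dim)]


# One query, three token positions (the last is padding), hidden size 2.
layer_states = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
attention_mask = [1, 1, 0]
print(masked_mean_pool(layer_states, attention_mask))  # [2.0, 3.0]
```

The pooled vector is the fixed-size input the difficulty or correctness classifier consumes; last-token pooling is a common alternative.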

Reproducibility

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Evaluation is focused on math reasoning datasets; generalization to other domains is untested.
  • Router uses embeddings from a single representative model (s1.1-32B); this may not generalize to heterogeneous model pools.
  • Reported compute savings are aggregate; per-query latency and routing overhead are not fully quantified.

When Not To Use

  • If your workload is non-mathematical or differs strongly from the evaluated benchmarks.
  • If per-query latency is critical and routing overhead could negate savings.
  • If you cannot obtain mid-layer embeddings from a strong representative model.
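Whether routing overhead negates the savings is a short calculation: routing helps on average only if the overhead is smaller than the latency saved on queries sent to the small model. A minimal sketch, assuming a fixed per-query overhead and fixed mean latencies per model; all numbers and names are hypothetical.

```python
def expected_latency(overhead_ms, frac_small, t_small_ms, t_large_ms):
    """Mean per-query latency when a router adds fixed overhead before dispatch."""
    return overhead_ms + frac_small * t_small_ms + (1 - frac_small) * t_large_ms


def routing_pays_off(overhead_ms, frac_small, t_small_ms, t_large_ms):
    """True when routed mean latency beats always using the large model."""
    return expected_latency(overhead_ms, frac_small, t_small_ms, t_large_ms) < t_large_ms


# Half of queries go small; break-even overhead is 0.5 * (1000 - 400) = 300 ms.
print(routing_pays_off(50, 0.5, 400, 1000))   # True
print(routing_pays_off(400, 0.5, 400, 1000))  # False
```

Queries that still go to the large model pay the overhead on top of its full latency, so tail latency can worsen even when the average improves.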

Failure Modes

  • Predictor misclassification routes hard problems to weak models, causing accuracy drops.
  • Domain shift: embeddings from the representative model may not reflect other models' failure modes.
  • Calibration drift: prediction thresholds may need frequent retuning as models or data change.
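One simple guard against calibration drift is a rolling comparison of mean predicted success against observed success on queries routed to a small model. This sketch assumes ground-truth outcomes become available (for math, by checking answers against references); the class name, window size, and tolerance are hypothetical choices, not from the paper.

```python
from collections import deque


class CalibrationMonitor:
    """Track predicted vs observed success for queries sent to a small model.

    Flags drift when the rolling mean prediction exceeds the observed success
    rate by more than `tolerance` (overconfidence routes hard problems to weak
    models).
    """

    def __init__(self, max_records: int = 500, tolerance: float = 0.05):
        self.records = deque(maxlen=max_records)  # oldest entries drop out
        self.tolerance = tolerance

    def record(self, p_pred: float, solved: bool) -> None:
        self.records.append((p_pred, solved))

    def gap(self) -> float:
        """Mean predicted success minus observed success rate (0.0 if empty)."""
        if not self.records:
            return 0.0
        n = len(self.records)
        mean_pred = sum(p for p, _ in self.records) / n
        observed = sum(1 for _, s in self.records if s) / n
        return mean_pred - observed

    def needs_retuning(self) -> bool:
        return self.gap() > self.tolerance
```

When `needs_retuning()` fires, the routing threshold can be re-swept on fresh labeled data or the classifier retrained.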

Core Entities

Models

  • s1.1-32B
  • Llama-3.3-70B-Instruct
  • Llama-3.3-Nemotron-Super-49B
  • phi-4
  • Llama-3.1-8B-Instruct
  • Llama-3.1-Nemotron-Nano-8B
  • Mixtral-8x7B-Instruct
  • OLMo-2-1124-7B-Instruct

Metrics

  • Accuracy
  • inference time / compute

Datasets

  • MATH
  • MathCombined
  • GSM8K
  • Minerva
  • AIME24
  • AMC23
  • OlympiadBench
  • TheoremQA

Benchmarks

  • MathCombined evaluation split
  • MATH (difficulty-labeled)