Pick fine‑tuning data by clustering loss curves of a small proxy model

March 12, 2024 · 7 min

Overview

Production Readiness: 0.7

Novelty Score: 0.6

Cost Impact Score: 0.8

Citation Count: 1

Authors

Yu Yang, Siddhartha Mishra, Jeffrey N Chiang, Baharan Mirzasoleiman

Why It Matters For Business

S2L can cut fine‑tuning data by up to ~89% on the evaluated math tasks and roughly halve the data and training time in clinical summarization, lowering compute, storage, and labeling costs while keeping or improving accuracy.

Summary TLDR

S2L records each example's loss at intervals during training of a small, cheap proxy model, clusters those loss trajectories, then draws balanced samples from all clusters to form a fine‑tuning subset for a large model. The paper proves that examples with clustered trajectories have similar gradients and gives a convergence bound for training on the subset. Empirically, S2L matches full-data performance with just 11% of MathInstruct, improves average accuracy over SOTA selection methods by ~4.7% across six math datasets, reaches 32.7% exact match on the hard MATH benchmark with 50K examples (+16.6% over pretrained Phi-2), and improves clinical summarization while cutting the data in half. Code is public.
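To make "loss trajectory" concrete, here is a minimal sketch under loud assumptions: the "proxy" is a one-parameter linear model trained with plain SGD, not a language model, and the names `record_trajectories` and `checkpoint_every` are hypothetical rather than taken from the S2L code. It only illustrates the data structure that gets clustered: one vector of losses per example, with one entry per checkpoint.

```python
import random

def record_trajectories(data, steps=200, checkpoint_every=50, lr=0.05, seed=0):
    """Train a toy model y = w * x with SGD and, at every checkpoint,
    record the squared-error loss of EVERY example (not just the sampled one).
    Returns one loss trajectory (list of floats) per example."""
    rng = random.Random(seed)
    w = 0.0
    trajectories = [[] for _ in data]
    for step in range(1, steps + 1):
        # One SGD step on a randomly drawn example.
        x, y = data[rng.randrange(len(data))]
        w -= lr * 2.0 * (w * x - y) * x
        # Periodically evaluate the current weights on the whole dataset.
        if step % checkpoint_every == 0:
            for i, (xi, yi) in enumerate(data):
                trajectories[i].append((w * xi - yi) ** 2)
    return trajectories

# Toy data (assumed): three examples follow y = 2x, one follows y = -2x.
data = [(1.0, 2.0), (0.5, 1.0), (1.5, 3.0), (1.0, -2.0)]
trajs = record_trajectories(data)
print(len(trajs), "examples,", len(trajs[0]), "checkpoints each")
```

With a real proxy, the inner evaluation loop would be a forward pass over the training set at each saved checkpoint; the output shape (examples × checkpoints) stays the same.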

Problem Statement

Supervised fine‑tuning (SFT) for specialized domains is expensive and data‑hungry. Existing selection methods rely on embeddings or confidence from large reference models, which (1) can fail when fine‑tuning data differs from pretraining data and (2) are costly to compute for large models. The paper asks: can a small proxy model's training dynamics identify a small, high‑quality subset that trains large models almost as well?

Main Contribution

S2L algorithm: cluster per-example loss trajectories from a small proxy and uniformly sample across clusters to build a training subset.

Theory: the paper proves that examples in the same loss-trajectory cluster have similar gradients, and gives a convergence guarantee under bounded gradient error for incremental gradient training on the subset.

Empirical: across math and clinical tasks, S2L reduces needed training data and often matches or improves full-data training while using a proxy up to 100× smaller.
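The theory contribution above can be written schematically as follows. The notation is assumed here for illustration and is not taken from the paper: $\ell_i$ is the loss of example $i$, $\theta_1, \dots, \theta_T$ are the proxy checkpoints, and $\epsilon$, $\delta$ are placeholder constants whose actual form and conditions are given only in the paper.

```latex
\[
L_i = \bigl(\ell_i(\theta_1), \dots, \ell_i(\theta_T)\bigr), \qquad
\|L_i - L_j\| \le \epsilon
\;\Longrightarrow\;
\|\nabla \ell_i(\theta) - \nabla \ell_j(\theta)\| \le \delta
\]
```

Read this as the shape of the claim only: examples whose trajectories fall in the same cluster can stand in for one another in a gradient step, up to a bounded error, which is what the convergence guarantee for training on the subset relies on.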

Key Findings

S2L matches full MathInstruct performance using only ~11% of the data.

Numbers: 11% of MathInstruct (~30K of 262K)

S2L outperforms other open-source data selection methods on average.

Numbers: avg +4.7% accuracy vs SOTA across 6 datasets

On the MATH benchmark, S2L with 50K examples achieves 32.7% exact match.

Numbers: 32.7% exact match; +16.6% vs pretrained Phi‑2

For clinical summarization on MIMIC‑III, S2L with half the data outperforms full‑data training on several metrics.

Numbers: 30K vs 61.5K examples; higher ROUGE‑L and BERTScore

S2L can use a very small proxy to scale selection cheaply.

Numbers: proxy as small as 70M vs target up to 7B (≈100× smaller)
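The rounded figures in these findings can be cross-checked with a few lines of arithmetic. This is only a sanity check on the numbers quoted above; the variable names are illustrative.

```python
# Figures quoted in the findings above.
full_math, subset_math = 262_000, 30_000   # MathInstruct examples
full_clin, subset_clin = 61_500, 30_000    # MIMIC-III examples
proxy_params, target_params = 70e6, 7e9    # 70M proxy vs 7B target

math_kept = subset_math / full_math        # fraction of MathInstruct kept
clin_kept = subset_clin / full_clin        # fraction of MIMIC-III kept
size_ratio = target_params / proxy_params  # how much larger the target is

print(f"MathInstruct kept: {math_kept:.1%}, removed: {1 - math_kept:.1%}")
print(f"MIMIC-III kept: {clin_kept:.1%}")
print(f"Target/proxy parameter ratio: {size_ratio:.0f}x")
```

So 30K of 262K is about 11.5% kept and 88.5% removed, which is where the rounded "~11%" and "up to ~89%" figures come from, and 30K of 61.5K is just under half.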

Results

Accuracy

Value: S2L with 30K matches/exceeds full-data training

Baseline: Full MathInstruct (262K)

Accuracy

Value: 32.7% with 50K S2L-selected examples

Baseline: pretrained Phi-2

ROUGE-L and BERTScore for clinical summarization

Value: S2L (30K) > Full (61.5K) on ROUGE-L & BERTScore

Baseline: Full MIMIC-III (61.5K)

Who Should Care

What To Try In 7 Days

Train a small proxy (≈70–160M) on your domain for a few epochs and record per-example loss every few hundred steps.

Cluster loss trajectories (K≈100 using Faiss KMeans) and sample uniformly from each cluster to build a subset for your data budget.

Fine‑tune your production model on the selected subset and compare exact match / ROUGE / BERTScore against random and full-data baselines.
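The clustering and sampling step above can be sketched as follows. This is a hedged illustration, not the paper's implementation: a tiny pure-Python k-means stands in for Faiss KMeans so the example runs without dependencies, the toy uses K=3 instead of K≈100, and `balanced_sample` encodes one plausible reading of "sample uniformly from each cluster" (equal share per cluster, smallest cluster first, unused budget passed on to larger clusters). All function names are hypothetical.

```python
import random

def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def kmeans(points, k, iters=20):
    """Tiny Lloyd's k-means; returns a cluster id for each point."""
    # Deterministic farthest-point initialization (an assumption of this sketch).
    centers = [list(points[0])]
    while len(centers) < k:
        far = max(points, key=lambda p: min(sq_dist(p, c) for c in centers))
        centers.append(list(far))
    assign = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest center by squared Euclidean distance.
        for i, p in enumerate(points):
            assign[i] = min(range(k), key=lambda c: sq_dist(p, centers[c]))
        # Update step: move each center to the mean of its members.
        for c in range(k):
            members = [points[i] for i in range(len(points)) if assign[i] == c]
            if members:  # keep the old center if a cluster went empty
                centers[c] = [sum(col) / len(members) for col in zip(*members)]
    return assign

def balanced_sample(assign, budget, seed=0):
    """Spread `budget` evenly over clusters, smallest cluster first.
    A cluster smaller than its share is taken whole and the unused
    budget is passed on to the remaining, larger clusters."""
    rng = random.Random(seed)
    clusters = {}
    for idx, c in enumerate(assign):
        clusters.setdefault(c, []).append(idx)
    ordered = sorted(clusters.values(), key=len)
    chosen, remaining = [], budget
    for pos, members in enumerate(ordered):
        share = remaining // (len(ordered) - pos)
        take = min(share, len(members))
        chosen.extend(rng.sample(members, take))
        remaining -= take
    return sorted(chosen)

# Toy loss trajectories (4 checkpoints each): 8 "easy", 3 "slow", 1 "stuck".
trajs = ([[2.0, 1.0, 0.4, 0.1]] * 8
         + [[2.0, 1.8, 1.2, 0.6]] * 3
         + [[2.0, 2.1, 2.0, 2.1]])
assign = kmeans(trajs, k=3)
subset = balanced_sample(assign, budget=6)
print("selected example indices:", subset)
```

With a budget of 6, the single "stuck" example is always kept, two of the three "slow" examples are drawn, and the remaining three slots go to the "easy" cluster. That is the intended effect of balancing: small clusters are not drowned out by the largest one.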

Optimization Features

Infra Optimization

  • reduces selection cost (proxy ~100× smaller)

System Optimization

  • use small proxy to shrink selection compute

Training Optimization

  • data-efficient training
  • subset sampling to reduce epochs and storage

Reproducibility

License

  • MathInstruct: MIT; MIMIC-III: DUA required

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Tested on two domains (mathematics, clinical summarization) only.
  • Experiments limited to models up to 7B parameters; larger targets untested here.
  • Uniform training schedule used for comparisons; per-method hyperparameter tuning was not performed.

When Not To Use

  • When domain data lacks consistent training dynamics across scales (proxy may not reflect target).
  • If you require selection sensitive to rare but critical examples that clustering may under-represent.
  • When you cannot train even a small proxy due to privacy or compute constraints.

Failure Modes

  • Proxy model fails to capture target dynamics, producing poor clusters and a low‑quality subset.
  • Clusters emphasize easier or common topics and under-sample rare but important cases.
  • Choice of clustering K or sparse trajectories may reduce effectiveness.

Core Entities

Models

  • Pythia-70M
  • Pythia-160M
  • Pythia-410M
  • Pythia-1B
  • Pythia-2.8B
  • Phi-2 (2.7B)
  • Phi-3-mini (3.8B)
  • LLaMA-2-7B
  • GPT-2 (124M)

Metrics

  • Exact match
  • BLEU
  • ROUGE-L
  • BERTScore
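For reference, two of the listed metrics are simple enough to sketch directly. This is a simplified illustration, not the paper's evaluation code: `exact_match` assumes answers are compared after trimming and lowercasing, and `rouge_l_f1` assumes whitespace tokenization with an F1 combination and no stemming. BLEU and BERTScore are omitted because they need more machinery or a pretrained model.

```python
def exact_match(prediction, reference):
    """1 if the answers agree after trimming and lowercasing, else 0."""
    return int(prediction.strip().lower() == reference.strip().lower())

def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for tok_a in a:
        curr = [0]
        for j, tok_b in enumerate(b, start=1):
            if tok_a == tok_b:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(prediction, reference):
    """ROUGE-L as the F1 of LCS-based precision and recall."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    if not pred or not ref:
        return 0.0
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

print(exact_match(" 42 ", "42"))
print(rouge_l_f1("a b c d", "a c d e"))
```

In the second call the longest common subsequence is "a c d" (length 3), so precision and recall are both 3/4. For reportable numbers, an established evaluation package is a safer choice than this sketch.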

Datasets

  • MathInstruct
  • MATH
  • GSM8K
  • NumGLUE
  • SVAMP
  • SimulEq
  • MIMIC-III

Benchmarks

  • MATH
  • GSM8K
  • NumGLUE
  • SVAMP
  • SimulEq