Fine-tune on the model's own correct answers to avoid forgetting and keep generality

September 7, 2024 · 7 min read

Overview

Production Readiness: 0.6

Novelty Score: 0.6

Cost Impact Score: 0.4

Citation Count: 0

Authors

Sonam Gupta, Yatin Nandwani, Asaf Yehudai, Mayank Mishra, Gaurav Pandey, Dinesh Raghu, Sachindra Joshi


Why It Matters For Business

SSR lets you specialize a model for a task without erasing its existing skills, reducing risk when deploying fine-tuned LLMs across multiple use cases.

Summary TLDR

Selective Self-Rehearsal (SSR) is a simple fine-tuning recipe: run the base model on each training example, use an LLM judge to mark which model outputs are already acceptable, then fine-tune using the model's outputs for those examples and the gold labels for the rest. On content-grounded QA tasks, SSR matches or beats standard supervised fine-tuning (SFT) on the task while preserving the base model's general skills. For example, SFT caused average drops of up to 16.7% on standard benchmarks, while SSR kept the drop to about 2% on the same tests.
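The recipe above amounts to a per-example choice of training target. A minimal Python sketch, assuming you wrap your base-model inference and your judge in two callables; `Example`, `build_ssr_targets`, `generate` and `judge` are hypothetical names, not code released with the paper:

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Example:
    prompt: str  # input shown to the model (e.g. grounding document + question)
    gold: str    # reference answer from the training set


def build_ssr_targets(
    examples: List[Example],
    generate: Callable[[str], str],
    judge: Callable[[str, str, str], bool],
) -> List[Dict[str, str]]:
    """Pick a fine-tuning target for each example, SSR-style.

    generate(prompt) -> base-model output        (hypothetical hook)
    judge(prompt, output, gold) -> acceptable?   (hypothetical hook)
    """
    records = []
    for ex in examples:
        output = generate(ex.prompt)
        if judge(ex.prompt, output, ex.gold):
            # Self-rehearsal: train on the model's own acceptable answer.
            records.append({"prompt": ex.prompt, "target": output, "source": "self"})
        else:
            # Rejected output: fall back to the gold label, as in plain SFT.
            records.append({"prompt": ex.prompt, "target": ex.gold, "source": "gold"})
    return records


if __name__ == "__main__":
    # Toy stand-ins for the base model and the judge.
    data = [Example("Capital of France?", "Paris."), Example("2+2?", "4")]
    canned = {"Capital of France?": "It is Paris.", "2+2?": "5"}

    def substring_judge(prompt: str, output: str, gold: str) -> bool:
        return gold.rstrip(".").lower() in output.lower()

    for rec in build_ssr_targets(data, canned.__getitem__, substring_judge):
        print(rec["source"], "->", rec["target"])
```

With the toy hooks, the first example keeps the model's own answer ("self") and the second falls back to the gold label ("gold"). In practice the judge would be an LLM call rather than a substring match.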

Problem Statement

Fine-tuning on gold labels often overfits and erases useful skills learned earlier. Many inputs admit multiple valid outputs, yet standard SFT forces the gold label and shifts the model away from its prior output distribution. This hurts generalization on other datasets and benchmarks.

Main Contribution

Introduce Selective Self-Rehearsal (SSR): fine-tune on model-generated outputs when they are judged acceptable, otherwise use gold labels.

Operationalize correctness with an LLM-as-a-judge to pick which training examples use model outputs.

Empirically show SSR preserves base-model abilities (reasoning and general benchmarks) while learning the new task.

Evaluate on content-grounded QA using MD2D, NQ (augmented), and MuSiQue plus standard benchmarks (MMLU, TruthfulQA, GSM8k, Hellaswag).
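One way to operationalize the judge is to ask it for a one-word verdict and parse the reply conservatively. The sketch below is an assumption throughout: the prompt wording, `build_judge_prompt` and `parse_verdict` are hypothetical and do not reproduce the paper's judge prompt; sending the prompt to a model (the card lists Mixtral-8x7B as the judge) is left to your own client code.

```python
import re

# Hypothetical judge prompt; not the prompt used in the paper.
JUDGE_TEMPLATE = (
    "You are grading an answer to a question about a document.\n"
    "Question: {question}\n"
    "Reference answer: {gold}\n"
    "Candidate answer: {candidate}\n"
    "Is the candidate answer correct and consistent with the reference? "
    "Reply with exactly one word: YES or NO."
)


def build_judge_prompt(question: str, gold: str, candidate: str) -> str:
    """Fill the template with one training example and the base model's output."""
    return JUDGE_TEMPLATE.format(question=question, gold=gold, candidate=candidate)


def parse_verdict(judge_reply: str) -> bool:
    """Map the judge's free-text reply to accept (True) or reject (False).

    Anything that is not a clear YES counts as a rejection, so parsing failures
    fall back to the gold label, the conservative choice for SSR.
    """
    match = re.search(r"\b(YES|NO)\b", judge_reply.strip(), flags=re.IGNORECASE)
    return match is not None and match.group(1).upper() == "YES"
```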

Key Findings

SSR sharply reduces catastrophic forgetting on broad benchmarks compared to SFT.

Numbers: average change of SFT -16.7% vs. SSR -2.3% (trained on MD2D) on MMLU/TruthfulQA/GSM8k/Hellaswag

SSR preserves answer quality (token-level recall) better than SFT, both in-domain and out-of-domain.

Numbers: NQ modified recall SFT=71.2 vs. SSR=74.7; MD2D: SSR modified recall 65.6 (Table 3)
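Token-level recall measures how many gold-answer tokens appear in the response. The sketch below assumes SQuAD-style normalization (lowercasing, stripping punctuation and articles). The card describes modified recall only as combining quality and answerability, so `modified_recall` is a hypothetical rule: credit abstentions on unanswerable questions, penalize them on answerable ones, and otherwise fall back to token recall. The abstention phrases are placeholders.

```python
import re
import string
from collections import Counter
from typing import List

# Hypothetical abstention phrases; tune them to how your model declines to answer.
ABSTAIN_MARKERS = ("i don't know", "cannot be answered", "unanswerable")


def normalize_tokens(text: str) -> List[str]:
    """Lowercase, drop punctuation and English articles, then split on whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return text.split()


def token_recall(prediction: str, gold: str) -> float:
    """Fraction of gold tokens (with multiplicity) that also occur in the prediction."""
    gold_counts = Counter(normalize_tokens(gold))
    total = sum(gold_counts.values())
    if total == 0:
        return 0.0
    overlap = gold_counts & Counter(normalize_tokens(prediction))  # min counts
    return sum(overlap.values()) / total


def is_abstention(text: str) -> bool:
    lowered = text.lower()
    return any(marker in lowered for marker in ABSTAIN_MARKERS)


def modified_recall(prediction: str, gold: str, gold_is_unanswerable: bool) -> float:
    """Hypothetical answerability-aware recall (not the paper's exact definition)."""
    if gold_is_unanswerable:
        return 1.0 if is_abstention(prediction) else 0.0
    if is_abstention(prediction):
        return 0.0  # declined a question that has an answer
    return token_recall(prediction, gold)
```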

Human judges rate SSR better than SFT on out-of-domain data.

Numbers: MuSiQue human relevance SFT=1.83 vs. SSR=2.78 (Likert 0-4)

Results

Avg performance change vs base on standard benchmarks (MMLU, TruthfulQA, GSM8k, Hellaswag)

Value: SFT -16.7%; SSR -2.3% (trained on MD2D)

Baseline: prompted base model
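The card does not state whether the average change is computed from relative or absolute differences. Assuming it is the mean of per-benchmark relative changes against the base model, the arithmetic looks like this (the scores below are made up for illustration, and `avg_relative_change` is a hypothetical helper):

```python
from typing import Dict


def avg_relative_change(base: Dict[str, float], tuned: Dict[str, float]) -> float:
    """Mean of per-benchmark relative changes (in %) of a tuned model vs. its base.

    Assumes every benchmark in `base` also appears in `tuned` with a non-zero base score.
    """
    if not base:
        raise ValueError("need at least one benchmark")
    changes = []
    for name, base_score in base.items():
        if base_score == 0:
            raise ValueError(f"base score for {name} is zero")
        changes.append((tuned[name] - base_score) / base_score * 100.0)
    return sum(changes) / len(changes)


if __name__ == "__main__":
    # Illustrative numbers only, not results from the paper.
    base = {"bench_a": 50.0, "bench_b": 80.0}
    tuned = {"bench_a": 40.0, "bench_b": 80.0}
    print(f"{avg_relative_change(base, tuned):.1f}%")  # -20% and 0% average to -10.0%
```

If the paper instead compares the average score before and after fine-tuning, the result can differ when benchmarks have very different score ranges, so state which convention you use when reporting your own numbers.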

Modified recall (quality + answerability) on NQ in-domain

Value: SFT 71.2; SSR 74.7

Baseline: prompted base model (49.3 modified recall)

Human relevance (Likert 0-4) on MuSiQue (out-of-domain)

Value: Prompt 2.47; SFT 1.83; SSR 2.78

Baseline: prompted base model

Who Should Care

What To Try In 7 Days

Run your base model on a sample of about 1k examples from your fine-tuning data and collect its outputs.

Use a strong LLM (or a small rule set) to mark which outputs are acceptable.

Fine-tune with LoRA: use model outputs for accepted cases and gold labels otherwise; validate on an unrelated benchmark to check forgetting.
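Before a full run, the pilot in the first two steps can tell you how much SSR will actually change your data. A minimal sketch, assuming the judge's verdicts come back as booleans; `sample_pilot` and `acceptance_rate` are hypothetical helpers:

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_pilot(examples: Sequence[T], k: int = 1000, seed: int = 0) -> List[T]:
    """Draw a reproducible pilot subset (the whole set if it has fewer than k items)."""
    rng = random.Random(seed)
    return rng.sample(list(examples), min(k, len(examples)))


def acceptance_rate(verdicts: Sequence[bool]) -> float:
    """Share of pilot outputs the judge accepted (0.0 for an empty pilot)."""
    return sum(verdicts) / len(verdicts) if verdicts else 0.0
```

A rate near zero means almost every target falls back to the gold label, so expect SFT-like forgetting; a rate near one is worth a spot check against human labels before trusting the judge.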

Optimization Features

Token Efficiency

  • Reduces the need to replay instruction-tuning data

Infra Optimization

  • None; SSR adds pre-fine-tuning inference and judge costs

System Optimization

  • Avoids auxiliary generative models for rehearsal

Training Optimization

  • Selective Self-Rehearsal (data selection to reuse model outputs)
  • LoRA

Reproducibility

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Requires running the base model on the entire training dataset, which adds significant inference cost.
  • Accuracy depends on the LLM-as-a-judge; judge errors can mislabel correct/incorrect outputs.
  • Human evaluation is small-scale and was performed by two in-house annotators, limiting external validity.
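To size the extra inference cost before committing, a back-of-envelope token count helps. A sketch, assuming every example gets one base-model generation and one judge call, and that the judge sees the prompt, gold label and candidate; all per-example token averages are inputs you estimate from your own data, and `TokenBudget` and `estimate_ssr_overhead` are hypothetical names:

```python
from dataclasses import dataclass


@dataclass
class TokenBudget:
    generation_tokens: int  # base-model pass over the training set
    judge_tokens: int       # LLM-judge pass over the same examples

    @property
    def total(self) -> int:
        return self.generation_tokens + self.judge_tokens


def estimate_ssr_overhead(
    n_examples: int,
    prompt_tokens: int,
    output_tokens: int,
    gold_tokens: int,
    judge_template_tokens: int,
    judge_reply_tokens: int,
) -> TokenBudget:
    """Rough token count for SSR's extra pre-fine-tuning passes (per-example averages)."""
    # Base model reads each prompt once and writes one candidate answer.
    generation = n_examples * (prompt_tokens + output_tokens)
    # Judge reads its instructions, the prompt, the gold label and the candidate,
    # then writes a short verdict.
    judge = n_examples * (
        judge_template_tokens + prompt_tokens + gold_tokens + output_tokens + judge_reply_tokens
    )
    return TokenBudget(generation_tokens=generation, judge_tokens=judge)
```

Multiply each total by your own per-token price or throughput. For content-grounded QA the prompt includes the grounding document, so `prompt_tokens` usually dominates both passes.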

When Not To Use

  • When gold answers are unique and unambiguous (no valid alternative outputs).
  • When you cannot afford the extra inference and judge LLM costs.
  • When an unbiased, high-quality judge is unavailable.

Failure Modes

  • Lenient judge marks wrong model outputs as acceptable, causing the model to learn errors.
  • Strict judge marks valid model outputs as incorrect, reducing SSR to near-standard SFT and losing benefits.
  • The compute cost of scoring the whole dataset dominates the project budget or timeline.
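Both judge failure modes can be measured on a small human-labeled audit set before scaling up. A minimal sketch; `judge_error_rates` is a hypothetical helper, and the audit set is assumed to label the same model outputs the judge saw:

```python
from typing import Dict, Sequence


def judge_error_rates(judge_labels: Sequence[bool], human_labels: Sequence[bool]) -> Dict[str, float]:
    """Compare judge verdicts with human verdicts on the same outputs.

    false_accept_rate: share of human-rejected outputs the judge accepted (lenient judge).
    false_reject_rate: share of human-accepted outputs the judge rejected (strict judge).
    """
    if len(judge_labels) != len(human_labels):
        raise ValueError("label lists must have the same length")
    on_human_rejects = [j for j, h in zip(judge_labels, human_labels) if not h]
    on_human_accepts = [j for j, h in zip(judge_labels, human_labels) if h]
    far = sum(on_human_rejects) / len(on_human_rejects) if on_human_rejects else 0.0
    frr = sum(not j for j in on_human_accepts) / len(on_human_accepts) if on_human_accepts else 0.0
    return {"false_accept_rate": far, "false_reject_rate": frr}
```

A high false-accept rate points to a lenient judge that lets errors into training; a high false-reject rate points to a strict one that pushes SSR back toward plain SFT.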

Core Entities

Models

  • Mistral-7B-Instruct-v0.2 (also written Mistral-instruct-v2-7B)
  • Mixtral-8x7B (used as judge)

Metrics

  • token-level recall
  • modified recall
  • accuracy
  • human relevance (Likert 0-4)

Datasets

  • MultiDoc2Dial (MD2D)
  • Natural Questions (NQ) augmented
  • MuSiQue (augmented)

Benchmarks

  • MMLU
  • TruthfulQA
  • GSM8k
  • Hellaswag