Jointly train retriever and medical LLM to improve accuracy, reduce hallucinations, and cut training cost

February 27, 2024 · 7 min

Overview

Production Readiness: 0.6

Novelty Score: 0.6

Cost Impact Score: 0.8

Citation Count: 11

Authors

Junda Wang, Zhichao Yang, Zonghai Yao, Hong Yu

Why It Matters For Business

Jointly fine-tuning the retriever and the LLM yields better medical QA accuracy and explanations while using orders of magnitude less training compute than large-scale domain-specific pretraining, making domain-specialized models cheaper and faster to build.

Summary TLDR

This paper introduces JMLR, a training method that updates a retriever and an LLM together so that the retriever learns which documents actually help the LLM produce correct medical answers. JMLR-13B reaches an average accuracy of 70.5% across medical QA benchmarks (vs. 68.9% for Meditron-70B and 67.7% for RAG-13B) and improves the factuality and quality of its rationales. Joint training uses an LLM-driven rank loss and dynamic sampling of candidate documents (candidates are sampled from the retriever's top 30, and 7 are used per query), and it reduces training compute: about 148 GPU hours for JMLR-13B vs. about 42,630 GPU hours for Meditron-70B. Evaluation combines automated metrics (accuracy, UMLS-F, GPT-4 scoring) with a small-scale human expert comparison.
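The idea of an LLM-driven rank loss can be illustrated with a small sketch. It assumes a REPLUG-style formulation, swapped in here for illustration and not necessarily the paper's exact loss: the LLM's per-document answer loss is turned into a target distribution over the retrieved documents, and the retriever is penalized (via KL divergence) when its own score distribution disagrees. The function names, temperatures, and toy numbers are hypothetical.

```python
import math
from typing import List


def softmax(xs: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    scaled = [x / temperature for x in xs]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def llm_rank_loss(retriever_scores: List[float],
                  llm_answer_losses: List[float],
                  retriever_temp: float = 1.0,
                  llm_temp: float = 1.0) -> float:
    """Hypothetical LLM-driven rank loss: KL(target || retriever).

    retriever_scores[i]: retriever similarity of document i to the question.
    llm_answer_losses[i]: LLM negative log-likelihood of the gold answer when
        document i is in the prompt (lower means the document helped more).
    """
    # Documents that lower the LLM's answer loss receive more target mass.
    target = softmax([-loss for loss in llm_answer_losses], llm_temp)
    # The retriever's current belief about which documents are useful.
    predicted = softmax(retriever_scores, retriever_temp)
    # KL divergence: zero only when the retriever's distribution matches the LLM's preference.
    return sum(t * (math.log(t) - math.log(p)) for t, p in zip(target, predicted))


# Toy example: document 0 helps the LLM most, and the retriever already ranks it first.
print(llm_rank_loss([2.0, 0.5, 0.1], [0.3, 1.8, 2.5]))
```

In a training loop built on this sketch, the target distribution would be held fixed (no gradient through the LLM losses) and the KL term backpropagated into the retriever, while the LLM is updated with its ordinary answer-generation loss on the question plus retrieved documents.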

Problem Statement

Medical LLMs hallucinate and can miss or misapply domain knowledge. Conventional approaches either train the retriever separately from the LLM (standard RAG), which can leave retrieval misaligned with what the LLM actually needs, or continue pretraining the LLM on domain text, which is slow and compute-intensive. The paper asks: can training the retriever and the LLM simultaneously on QA pairs make retrieval more helpful and reduce hallucinations while saving compute?

Main Contribution

Propose JMLR, a joint training method that updates retriever and LLM together using an LLM-driven rank loss.

Show that JMLR-13B achieves higher average accuracy on multiple medical QA benchmarks than open-source baselines, including the much larger Meditron-70B.

Report large reductions in training compute compared to large-scale domain pretraining, and demonstrate improved rationale factuality and reduced hallucinations as measured by automated metrics and expert review.

Key Findings

JMLR-13B achieves the highest reported average accuracy among the open-source models evaluated on the medical QA sets.

Numbers: average accuracy 70.5% (JMLR-13B) vs. 68.9% (Meditron-70B)

Joint training on a 7B model gives large improvements over domain-specific pretraining (Meditron-7B).

Numbers: JMLR-7B average 62.3% vs. Meditron-7B 53.2% (≈ +9.1 pp)

JMLR greatly reduces training compute compared to large-scale medical pretraining.

Numbers: JMLR-13B 148 GPU hours vs. Meditron-70B 42,630 GPU hours
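As a quick check of the "orders of magnitude" framing, the reported GPU-hour figures give the following ratio (simple arithmetic on the numbers above, not a figure quoted from the paper):

```latex
\frac{42{,}630\ \text{GPU h}}{148\ \text{GPU h}} \approx 288,
\qquad \log_{10} 288 \approx 2.46
```

That is roughly two and a half orders of magnitude less training compute, with the comparison spanning different model sizes (13B vs. 70B).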

Rationale factuality and quality improve under JMLR.

Numbers: UMLS-F 0.2463 (JMLR) vs. 0.2187 (GPT-3.5); GPT-4 overall score 4.3036 vs. 4.0620

Joint training (JMLR) outperforms frozen retriever RAG variants.

Numbers: JMLR-13B average 70.5% vs. RAG-13B 67.7%

Results

Accuracy

Value: 70.5% (JMLR-13B)

Baseline: 68.9% (Meditron-70B)

Accuracy

Value: 62.3% (JMLR-7B)

Baseline: 53.2% (Meditron-7B)

Training compute (GPU hours)

Value: 148 GPU hours (JMLR-13B)

Baseline: 42,630 GPU hours (Meditron-70B)

UMLS factuality (F1)

Value: 0.2463 (JMLR-13B)

Baseline: 0.2187 (GPT-3.5)

GPT-4 overall explanation score (1-5)

Value: 4.3036 (JMLR-13B)

Baseline: 4.0620 (GPT-3.5)

What To Try In 7 Days

Jointly fine-tune a small open LLM and a ColBERT retriever with an LLM-driven rank loss on a narrow domain corpus.

Retrieve 7 documents per query (the paper found 7 optimal) and compare accuracy against settings from 1 to 10 documents.
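A minimal harness for the 1-to-10 document sweep might look like the following. `answer_with_k_docs` is a hypothetical stand-in for your own retrieve-then-generate pipeline; the toy pipeline and questions at the bottom are placeholder data that only exist so the loop runs.

```python
from typing import Callable, Dict, List, Tuple

# (question, gold_answer) pairs.
Example = Tuple[str, str]


def sweep_num_docs(examples: List[Example],
                   answer_with_k_docs: Callable[[str, int], str],
                   ks: range = range(1, 11)) -> Dict[int, float]:
    """Return accuracy for each number of retrieved documents k."""
    results: Dict[int, float] = {}
    for k in ks:
        correct = sum(1 for q, gold in examples if answer_with_k_docs(q, k) == gold)
        results[k] = correct / len(examples)
    return results


def best_k(results: Dict[int, float]) -> int:
    """Pick the k with the highest accuracy; ties go to the smaller k."""
    return max(sorted(results), key=lambda k: results[k])


# Toy stand-in pipeline (hypothetical): q1 is answered correctly once k >= 3,
# q2 only when exactly 7 documents are retrieved.
def toy_pipeline(question: str, k: int) -> str:
    if question == "q1":
        return "A" if k >= 3 else "B"
    return "B" if k == 7 else "A"


acc = sweep_num_docs([("q1", "A"), ("q2", "B")], toy_pipeline)
print(acc[7], best_k(acc))  # 1.0 7
```

Breaking ties toward the smaller k keeps prompts shorter when extra documents do not help.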

Run GPT-4 or domain-expert checks on generated rationales to track factuality (the paper uses UMLS-F and GPT-4 scoring).
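UMLS-F is described here as a factuality F1. One plausible way to compute such a score is a set-overlap F1 between the UMLS concepts in a generated rationale and those in a reference rationale. This sketch assumes the concept IDs have already been extracted by some entity linker (that step is not shown), and the IDs below are placeholders rather than real UMLS CUIs.

```python
from typing import Set


def concept_f1(predicted: Set[str], reference: Set[str]) -> float:
    """Set-overlap F1 between concepts in a generated and a reference rationale."""
    if not predicted or not reference:
        return 0.0
    overlap = len(predicted & reference)
    if overlap == 0:
        return 0.0
    precision = overlap / len(predicted)  # fraction of generated concepts that are supported
    recall = overlap / len(reference)     # fraction of reference concepts that were covered
    return 2 * precision * recall / (precision + recall)


# Placeholder concept IDs, for illustration only.
generated = {"C1", "C2", "C3"}
reference = {"C1", "C2", "C4", "C5"}
print(round(concept_f1(generated, reference), 4))  # 0.5714
```

A corpus-level score would typically average this per-example value over the evaluation set.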

Optimization Features

System Optimization

  • Use S2-Attn to handle long contexts efficiently
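To show what S2-Attn (shifted sparse attention, introduced in LongLoRA) restricts, here is a pure-Python sketch of its attention pattern: tokens attend only within fixed-size groups, and half of the heads use groups shifted by half a group so information can cross group boundaries. Building explicit boolean masks is a simplification assumed for clarity; efficient implementations instead reshape the sequence into groups, which is where the savings come from. The sequence length and group size are arbitrary, and causality here is enforced on original token positions.

```python
from typing import List


def s2_attn_mask(seq_len: int, group_size: int, shifted: bool) -> List[List[bool]]:
    """mask[i][j] is True when query token i may attend to key token j."""
    assert seq_len % group_size == 0, "sketch assumes seq_len divisible by group_size"
    shift = group_size // 2 if shifted else 0

    def group_of(pos: int) -> int:
        # Shifted heads offset positions by half a group before grouping,
        # so their groups straddle the plain boundaries (wrapping at the end).
        return ((pos - shift) % seq_len) // group_size

    # Causal (j <= i) and restricted to the same group.
    return [[j <= i and group_of(i) == group_of(j) for j in range(seq_len)]
            for i in range(seq_len)]


def head_masks(seq_len: int, group_size: int, num_heads: int) -> List[List[List[bool]]]:
    """First half of the heads use plain groups; the second half use shifted groups."""
    return [s2_attn_mask(seq_len, group_size, shifted=h >= num_heads // 2)
            for h in range(num_heads)]


# With 8 tokens and groups of 4, token 4 cannot see token 3 in a plain head,
# but can in a shifted head.
plain, shifted = head_masks(8, 4, 2)
print(plain[4][3], shifted[4][3])  # False True
```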

Training Optimization

  • Jointly update retriever and LLM to align retrieval with answer utility
  • Weighted sampling from top-30 retriever candidates during training
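The weighted sampling step can be sketched as drawing a handful of distinct documents from the top-30 retriever candidates, with probabilities given by a softmax over retriever scores. The softmax weighting, the temperature, and the default of 7 draws (matching the 7 documents per query mentioned earlier) are assumptions of this sketch rather than the paper's exact procedure; `sample_docs` is a hypothetical helper.

```python
import math
import random
from typing import List, Optional, Sequence


def sample_docs(doc_ids: Sequence[str], scores: Sequence[float],
                k: int = 7, top_n: int = 30, temperature: float = 1.0,
                rng: Optional[random.Random] = None) -> List[str]:
    """Sample k distinct docs from the top_n candidates, weighted by softmax(score / T)."""
    rng = rng or random.Random()
    # Keep only the top_n highest-scoring candidates.
    ranked = sorted(zip(doc_ids, scores), key=lambda pair: pair[1], reverse=True)[:top_n]
    pool = [doc for doc, _ in ranked]
    m = max(s for _, s in ranked)
    weights = [math.exp((s - m) / temperature) for _, s in ranked]
    chosen: List[str] = []
    for _ in range(min(k, len(pool))):
        if not any(weights):  # all remaining weights underflowed: fall back to uniform
            weights = [1.0] * len(weights)
        # Draw one doc in proportion to its weight, then remove it (no replacement).
        idx = rng.choices(range(len(pool)), weights=weights, k=1)[0]
        chosen.append(pool.pop(idx))
        weights.pop(idx)
    return chosen


ids = [f"doc{i}" for i in range(40)]
scores = [40 - i for i in range(40)]  # doc0 scores highest
picked = sample_docs(ids, scores, rng=random.Random(0))
print(picked)
```

Sampling instead of always taking the top 7 exposes the LLM to a wider set of candidates during training, which gives the rank loss a chance to promote useful documents the retriever initially under-ranks.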

Reproducibility

License

  • CC BY 4.0 (authors state intent to release under this license upon acceptance)

Code Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Focused only on medical QA; transfer to other domains untested.
  • Human evaluation limited to three doctors and a small sample, reducing statistical power.
  • Privacy and bias concerns remain because of the composition of the training document sources.

When Not To Use

  • If you lack a relevant domain corpus of documents to retrieve from.
  • For high-stakes clinical deployment without independent expert oversight and validation.
  • When legal/privacy rules prevent using or sharing needed retrieval documents.

Failure Modes

  • The retriever may select irrelevant or misleading documents, steering the LLM toward wrong answers.
  • Over-reliance on retrieved content can propagate biases present in guidelines or corpora.
  • Performance depends on retriever quality; poor initialization may limit gains.

Core Entities

Models

  • JMLR-7B
  • JMLR-13B
  • RAG-7B
  • RAG-13B
  • Meditron-7B
  • Meditron-70B
  • Llama-2-7B
  • GPT-3.5
  • GPT-4
  • Claude3-Opus
  • ColBERT

Metrics

  • Accuracy
  • UMLS-F (factuality F1)
  • GPT-4 score (1-5 Likert)
  • Cohen's Kappa

Datasets

  • MedQA
  • Amboss
  • MedMCQA
  • MMLU-Medical
  • PubMed
  • MIMIC-IV
  • medical textbooks

Benchmarks

  • USMLE-style MedQA
  • Amboss question bank
  • MedMCQA
  • MMLU-Medical