Bonsai: prune large language models using only forward passes to cut memory needs and keep accuracy

February 8, 2024 · 8 min read

Overview

Production Readiness: 0.75

Novelty Score: 0.7

Cost Impact Score: 0.85

Citation Count: 6

Authors

Steven Kolawole, Lucio Dery, Jean-François Kagy, Virginia Smith, Graham Neubig, Ameet Talwalkar


Why It Matters For Business

Bonsai makes structured LLM compression feasible on commodity GPUs. It cuts the memory needed for pruning and produces faster models, so teams can reduce inference cost and run post-pruning fine-tuning without enterprise hardware.

Summary TLDR

Bonsai is a structured pruning method that uses only forward passes: it evaluates perturbed sub-models and fits a regression to their scores to rank and remove modules (attention heads, MLP intermediate dimensions). It cuts pruning memory by more than 2× (it can run in ≈20GB), produces faster models (e.g., a 1.58× inference speedup), and reaches competitive accuracy after lightweight post-pruning adaptation (PPA). Bonsai beats other forward-only methods (FLAP and a Wanda variant) and matches or improves on some gradient-based methods, while enabling pruning of 7–8B models on a single A6000 GPU.

Problem Statement

Structured pruning for LLMs usually needs gradients and lots of GPU memory. That makes pruning impractical for many users. The paper asks: can we pick which modules to remove using only inference (forward passes) to save memory while keeping accuracy and speed?

Main Contribution

Bonsai: a forward-pass-only structured pruning algorithm that estimates global module importance via regression on perturbative sub-model evaluations.
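The regression step can be illustrated with a minimal sketch under simplifying assumptions: each sampled sub-model is a binary keep/drop mask over N modules, its score is regressed on the mask with closed-form ridge regression, and the fitted coefficients are read as module importances. `estimate_importance` and the additive `toy_eval` scorer are hypothetical stand-ins (a real run would score each sub-model with forward passes on calibration data, e.g. negative perplexity). This is not the authors' implementation, and their estimator and sampling details may differ.

```python
import numpy as np

def estimate_importance(evaluate, n_modules, n_samples, keep_prob=0.5, ridge=1e-2, seed=0):
    """Hypothetical helper: regress sub-model scores on keep/drop masks.

    `evaluate` maps a {0, 1} mask of length n_modules to a scalar score
    (higher is better); in practice it would run forward passes only.
    """
    rng = np.random.default_rng(seed)
    # Sample random sub-models: 1 = module kept, 0 = module removed.
    masks = (rng.random((n_samples, n_modules)) < keep_prob).astype(float)
    scores = np.array([evaluate(m) for m in masks])
    # Center inputs and targets so the intercept drops out.
    X = masks - masks.mean(axis=0)
    y = scores - scores.mean()
    # Closed-form ridge solution: (X^T X + ridge * I)^{-1} X^T y.
    beta = np.linalg.solve(X.T @ X + ridge * np.eye(n_modules), X.T @ y)
    return beta  # larger coefficient = keeping that module helps more

# Toy, additive ground truth (an assumption; real LLM scores are not additive).
true_w = np.array([3.0, 0.1, 2.0, 0.05, 1.0])
toy_eval = lambda mask: float(mask @ true_w)

beta = estimate_importance(toy_eval, n_modules=5, n_samples=200)
ranking = np.argsort(-beta)  # module indices, most to least important
print(ranking)
```

Modules with the smallest coefficients are the first candidates for removal. On this noiseless toy scorer the coefficients recover the true weights; with real models the score is noisy and non-additive, so more samples are needed for a stable ranking.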

Informative-prior sampling: bias sub-model sampling with cheap forward-pass signals (Wanda, activation magnitude, fluctuation) to reduce the number of sub-model evaluations needed.
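One way to bias sampling with a cheap prior, sketched under an assumed linear mapping: prior scores (such as activation magnitude) are rescaled into per-module keep probabilities, so modules the prior rates highly are dropped less often and fewer sampled sub-models are spent on obviously harmful removals. `prior_keep_probs`, `sample_masks`, and the `low`/`high` bounds are hypothetical; the paper's exact prior-to-probability mapping may differ.

```python
import numpy as np

def prior_keep_probs(prior_scores, low=0.2, high=0.9):
    """Hypothetical mapping of prior scores to keep probabilities in [low, high].

    Assumption for illustration: a linear min-max rescaling.
    """
    s = np.asarray(prior_scores, dtype=float)
    span = s.max() - s.min()
    norm = (s - s.min()) / span if span > 0 else np.full_like(s, 0.5)
    return low + (high - low) * norm

def sample_masks(keep_probs, n_samples, seed=0):
    """Draw Bernoulli keep (1) / drop (0) masks with per-module keep probabilities."""
    rng = np.random.default_rng(seed)
    return (rng.random((n_samples, len(keep_probs))) < keep_probs).astype(int)

# Hypothetical prior, e.g. mean activation magnitude per attention head.
prior = [0.8, 0.1, 0.5, 0.05]
probs = prior_keep_probs(prior)
masks = sample_masks(probs, n_samples=5000)
print(probs)
print(masks.mean(axis=0))  # empirical keep rate per module, close to probs
```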

Holistic global pruning (rather than layer-wise), an iterative pruning schedule, and Post-Pruning Adaptation (PPA) to recover accuracy on commodity hardware.
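An iterative schedule can be sketched as a list of per-round removal counts. Here `p_iter` is assumed to be the fraction of the original modules removed per round, with importance re-estimated between rounds; `pruning_schedule` is a hypothetical helper, not the paper's code.

```python
import math

def pruning_schedule(n_modules, target_sparsity, p_iter):
    """Hypothetical helper: number of modules to remove in each round.

    Assumption: each round removes at most p_iter * n_modules modules, and the
    final round removes only what is left to reach the target.
    """
    target = round(target_sparsity * n_modules)
    step = max(1, math.floor(p_iter * n_modules))
    schedule = []
    removed = 0
    while removed < target:
        k = min(step, target - removed)
        schedule.append(k)  # importance would be re-estimated before this round
        removed += k
    return schedule

print(pruning_schedule(n_modules=1000, target_sparsity=0.5, p_iter=0.05))
```

A larger `p_iter` means fewer, bigger jumps, the regime that the Failure Modes section below warns can damage the model irrecoverably.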

Key Findings

Bonsai cuts pruning memory requirements to inference-only levels, enabling pruning on ≈20GB devices instead of 80–160GB.

Numbers: pruning memory ≈20GB vs 80–160GB for gradient-based methods
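A back-of-envelope check on why ≈20GB is plausible for forward-only pruning of a 7B model: FP16 weights take 2 bytes per parameter. The sketch counts only weights and, for gradient-based methods, one extra same-size FP16 gradient buffer; activations retained for backpropagation and any optimizer state are deliberately left out, so real gradient-based footprints are higher. `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(n_params, bytes_per_param=2):
    """Memory for a tensor set of n_params values in decimal GB (FP16/BF16 = 2 bytes)."""
    return n_params * bytes_per_param / 1e9

n_params = 7e9  # nominal 7B model (LLaMA-2 7B has about 6.7B parameters)
forward_only = weight_memory_gb(n_params)                 # weights only
with_grads = forward_only + weight_memory_gb(n_params)    # plus one FP16 gradient per weight
print(f"forward-only weights: {forward_only:.0f} GB")
print(f"weights + gradients (no activations, no optimizer state): {with_grads:.0f} GB")
```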

At 50% sparsity Bonsai yields lower perplexity than FLAP on Wikitext-2.

Numbers: LLaMA-2 at 50% sparsity, Wikitext-2 PPL: Bonsai 12.38 vs FLAP 14.49

Bonsai-pruned models can be notably faster than semi-structured alternatives.

Numbers: Bonsai speedup 1.58× vs Wanda 2:4 speedup 1.14×

Post-pruning adaptation (PPA) recovers most lost accuracy.

Numbers: LLaMA-2 at 50% sparsity, PPL without PPA 19.47 → with PPA + distillation 8.89

Bonsai enables pruning of 7B–8B models to 50% sparsity on a single A6000 GPU.

Numbers: 7B and 8B models reported pruned to 50% sparsity on a single A6000

Bonsai outperforms some gradient-based structured pruning baselines after adaptation.

Numbers: average score Bonsai+PPA 50.63 vs LoRAPrune 48.3 and LLM-Pruner 47.18

Results

Pruning memory requirement

Value: ≈20GB (forward-only Bonsai)

Baseline: 80–160GB (gradient-based structured pruning)

Wikitext-2 perplexity at 50% sparsity (LLaMA-2 7B)

Value: 12.38 (Bonsai)

Baseline: 14.49 (FLAP); unpruned model (0% sparsity) = 5.11
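Perplexity is the exponential of the mean per-token negative log-likelihood, so 12.38 vs 14.49 corresponds to a gap of ln(14.49 / 12.38) ≈ 0.157 nats per token in average NLL. A minimal sketch computing perplexity from natural-log token probabilities (the uniform 1/4 input is made up for illustration):

```python
import math

def perplexity(token_logprobs):
    """exp(mean negative log-likelihood) over natural-log token probabilities."""
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)

# A model that assigns probability 1/4 to every token has perplexity 4.
print(perplexity([math.log(0.25)] * 10))

# Average-NLL gap (nats per token) implied by the reported perplexities.
print(math.log(14.49) - math.log(12.38))
```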

Inference speedup

Value: 1.58× (Bonsai-pruned)

Baseline: 1.14× (Wanda 2:4 semi-structured)

Effect of Post-Pruning Adaptation (PPA) on Wikitext-2 PPL

Value: 19.47 without PPA → 8.89 with PPA + distillation

Baseline: no PPA

Average downstream score (Eleuther harness) after PPA

Value: 50.63 (Bonsai+PPA average)

Baseline: 48.3 (LoRAPrune); 47.18 (LLM-Pruner)

Who Should Care

What To Try In 7 Days

Run Bonsai to prune a 7B model to ~50% sparsity on a single 48GB-class GPU to validate the memory and latency benefits.

After pruning, run lightweight fine-tuning (PPA) on the same GPU to recover accuracy cheaply.

Compare latency and perplexity against an off-the-shelf 3B model to decide whether pruning a larger model beats adopting an existing smaller one.

Optimization Features

Infra Optimization

  • Supports pruning on a single A6000 or a similar ≈20–48GB GPU
  • Avoids need for multi-A100-class setups

Model Optimization

  • Structured pruning of attention heads and MLP dims
  • Global module importance ranking via regression on perturbations

System Optimization

  • Forward-pass-only pruning to avoid backward memory overhead
  • Configurable runtime-quality tradeoff (from ≈15 min to ≈4 hr of pruning time)

Training Optimization

  • Enables post-pruning fine-tuning on same hardware (PPA)
  • Uses cached teacher logits for distillation, avoiding the memory cost of running the teacher during PPA
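Cached-logit distillation can be sketched as follows, assuming the unpruned teacher's logits were computed once and stored, so only the pruned student runs during PPA. The loss form (cross-entropy plus temperature-scaled KL) is the standard knowledge-distillation recipe, not necessarily the paper's exact objective; `distill_loss`, `alpha`, and `T` are hypothetical.

```python
import numpy as np

def log_softmax(z, T=1.0):
    z = z / T
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def distill_loss(student_logits, cached_teacher_logits, labels, alpha=0.5, T=2.0):
    """Hypothetical loss: alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(teacher || student).

    Shapes: logits (n_tokens, vocab), labels (n_tokens,). Teacher logits come
    from a cache instead of a forward pass through the unpruned model.
    """
    log_p_student = log_softmax(student_logits)
    ce = -log_p_student[np.arange(len(labels)), labels].mean()
    log_q_t = log_softmax(cached_teacher_logits, T)
    log_q_s = log_softmax(student_logits, T)
    kl = (np.exp(log_q_t) * (log_q_t - log_q_s)).sum(axis=-1).mean()
    return alpha * ce + (1 - alpha) * (T ** 2) * kl

rng = np.random.default_rng(0)
teacher = rng.normal(size=(8, 32))  # stand-in for cached teacher logits
labels = rng.integers(0, 32, size=8)
print(distill_loss(teacher, teacher, labels))  # KL term is zero when student == teacher
print(distill_loss(rng.normal(size=(8, 32)), teacher, labels))
```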

Inference Optimization

  • Produces smaller models with real latency speedups (e.g., 1.58×)
  • Removes whole modules to shrink tensor dimensions (rather than merely sparsifying them)
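Why removing whole modules gives real speedups, sketched for one LLaMA-style gated (SwiGLU) MLP with random weights: dropping intermediate dimensions shrinks the weight matrices themselves, so the pruned layer runs dense matrix multiplies at half the width while matching the full layer with those dimensions zeroed. Sizes and helper names (`prune_mlp`, `swiglu_mlp`) are illustrative assumptions, not taken from the paper; attention heads shrink analogously.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def swiglu_mlp(x, w_gate, w_up, w_down):
    """LLaMA-style gated MLP: (silu(x W_gate) * (x W_up)) W_down."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

def prune_mlp(w_gate, w_up, w_down, keep):
    """Physically remove intermediate dims: drop columns of gate/up and rows of down."""
    return w_gate[:, keep], w_up[:, keep], w_down[keep, :]

rng = np.random.default_rng(0)
d, h = 16, 64
x = rng.normal(size=(4, d))
w_gate, w_up, w_down = rng.normal(size=(d, h)), rng.normal(size=(d, h)), rng.normal(size=(h, d))

keep = np.sort(rng.choice(h, size=h // 2, replace=False))  # keep 50% of intermediate dims
small = prune_mlp(w_gate, w_up, w_down, keep)

# Reference: the full-size MLP with the removed dims zeroed out.
mask = np.zeros(h)
mask[keep] = 1.0
masked_out = (silu(x @ w_gate) * (x @ w_up) * mask) @ w_down

print(small[0].shape, small[2].shape)                 # (16, 32) (32, 16)
print(np.allclose(swiglu_mlp(x, *small), masked_out))  # True
```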

Reproducibility

Data Urls

  • Wikitext-2 (public)
  • C4 (public)
  • GSM8K (public)

Data Available

Open Source Status

  • unknown

Risks & Boundaries

Limitations

  • Longer runtime for best quality: the optimal configurations take ≈4 hours vs ≈1 hour for some baselines.
  • Pruning can hurt specialized reasoning tasks (e.g., GSM8K) unless you include task data in PPA.
  • Sampling and regression are not adaptive in their current form; more samples improve accuracy at the cost of time.
  • Paper does not provide public code, so replication requires reimplementation.

When Not To Use

  • If you have abundant multi-GPU memory and prefer gradient-based, jointly optimized pruning during training.
  • When you need the absolute highest out-of-the-box reasoning performance without any fine-tuning.
  • If you require an off-the-shelf open-source tool and cannot reimplement the method.

Failure Modes

  • Too few perturbation samples can produce NaNs or catastrophic degradation in FP16 (observed with ns=50).
  • Overly aggressive per-iteration pruning (large p_iter) damages the model irrecoverably.
  • Priors derived from activation magnitude may miss modules needed for niche tasks.

Core Entities

Models

  • LLaMA-1 7B
  • LLaMA-2 7B
  • LLaMA-3 8B
  • Phi-2 3B
  • Mistral-7B

Metrics

  • perplexity
  • inference speedup (×)
  • memory during pruning (GB)

Datasets

  • Wikitext-2
  • C4
  • GSM8K

Benchmarks

  • Eleuther LLM Evaluation Harness
  • HuggingFace OpenLLM leaderboard

Context Entities

Models

  • Phi-1.5
  • Sheared-LLaMA

Metrics

  • Kendall rank correlation (used for cross-validation)
  • sample variance / fluctuation metrics (FLAP)
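Kendall rank correlation suits this use because pruning only needs the ordering of modules to be right. A minimal sketch of tau-a (no tie correction) between predicted and observed scores of held-out sub-models, which could be used to compare regression settings; the held-out values are made up, `kendall_tau_a` is a hypothetical helper, and this is an illustrative reading of "used for cross-validation", not the paper's exact protocol.

```python
from itertools import combinations

def kendall_tau_a(x, y):
    """Kendall's tau-a: (concordant - discordant) / number of pairs; ties count as neither."""
    n = len(x)
    assert n == len(y) and n >= 2
    score = 0
    for i, j in combinations(range(n), 2):
        s = (x[i] - x[j]) * (y[i] - y[j])
        score += (s > 0) - (s < 0)  # +1 concordant, -1 discordant, 0 tie
    return score / (n * (n - 1) / 2)

# Hypothetical held-out sub-models: regression-predicted vs measured scores.
predicted = [0.9, 0.4, 0.7, 0.1, 0.5]
observed = [0.85, 0.30, 0.75, 0.20, 0.35]
print(kendall_tau_a(predicted, observed))
```

Here both lists rank the five sub-models identically, so tau is 1.0; a setting whose predictions scramble the held-out ordering would score lower.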

Datasets

  • Wikitext-2 training/validation
  • C4 subsets used for pruning signal and PPA

Benchmarks

  • Wikitext-2 validation
  • GSM8K
  • ARC, Winogrande, HellaSwag, TruthfulQA, MMLU (via Eleuther harness)