SquareHead L2 distillation enables high-sparsity fine-tuning and real CPU/GPU inference speedups

October 10, 2023 · 8 min

Overview

Production Readiness

0.7

Novelty Score

0.6

Cost Impact Score

0.7

Citation Count

4

Authors

Eldar Kurtic, Denis Kuznedelev, Elias Frantar, Michael Goin, Dan Alistarh

Why It Matters For Business

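Sparse fine-tuning with SquareHead can cut LLM inference latency and cost on CPUs and GPUs. Reported CPU speedups are roughly 1.5–3x from sparsity alone and about 6–9x when sparsity is combined with INT8 quantization, with accuracy close to the dense baseline at moderate sparsity on many tasks. This enables cheaper deployment on commodity hardware.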

Summary TLDR

The paper shows a practical recipe for fine-tuning large pretrained models while setting most of their weights to zero (sparsity) without losing accuracy. The key ingredient is SquareHead, an L2 per-token feature distillation loss that stabilizes fine-tuning and recovers accuracy at high sparsity. Applied to T5 (translation), Whisper (speech), and MPT-7B (generation), the recipe yields 1.5–3x CPU speedups at 50–70% sparsity with little or no accuracy loss, and up to ~6–9x when sparsity is combined with INT8 quantization. Code and models are released.

Problem Statement

Pruning weights during brief fine-tuning often destabilizes large models or hurts accuracy. Standard losses (cross-entropy or logit KD) diverge or overfit at high sparsity. Separately, turning sparsity into real CPU/GPU speedups is nontrivial because of compute vs memory bottlenecks and GPU kernel support.

Main Contribution

Identify instability and overfitting issues when naively applying sparse fine-tuning to LLMs.

Introduce SquareHead: a normalized per-layer L2 (feature) distillation loss combined with task loss to stabilize sparse fine-tuning.

Demonstrate practical speedups from sparse models on CPU and GPU and show compatibility between sparsity and INT8 quantization.

Provide code, models, and runtime recipes for reproducing results across T5, Whisper, and MPT-7B.
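
The SquareHead term described above can be sketched as follows. This is a minimal pure-Python illustration, assuming that "normalized" means each layer's squared error between student and teacher features is divided by the teacher's squared feature magnitude, and that the per-layer terms are averaged and added to the task loss. The function names, the `eps` constant, and the equal weighting of the two terms are illustrative assumptions, not the authors' code.

```python
from typing import List

Features = List[List[float]]  # one layer: [tokens][hidden_dim]


def mse(a: Features, b: Features) -> float:
    """Mean squared error over all tokens and hidden units of one layer."""
    total, count = 0.0, 0
    for row_a, row_b in zip(a, b):
        for x, y in zip(row_a, row_b):
            total += (x - y) ** 2
            count += 1
    return total / count


def squarehead_loss(student: List[Features], teacher: List[Features],
                    eps: float = 1e-6) -> float:
    """Average over layers of MSE(teacher, student) / MSE(teacher, 0)."""
    per_layer = []
    for f_s, f_t in zip(student, teacher):
        zeros = [[0.0] * len(row) for row in f_t]
        per_layer.append(mse(f_t, f_s) / (mse(f_t, zeros) + eps))
    return sum(per_layer) / len(per_layer)


def total_loss(task_loss: float, student: List[Features],
               teacher: List[Features]) -> float:
    """Task loss plus the feature distillation term (equal weights assumed)."""
    return task_loss + squarehead_loss(student, teacher)


# Toy example: 2 layers, 2 tokens, 2 hidden units.
teacher_feats = [[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]]
student_feats = [[[1.0, 2.0], [3.0, 4.0]], [[11.0, 20.0], [30.0, 40.0]]]
loss = total_loss(0.5, student_feats, teacher_feats)
```

Under this assumed normalization, a layer whose activations are ten times larger contributes on the same scale as any other layer, so no single layer dominates the distillation term.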

Key Findings

SquareHead (L2 feature distillation) stabilizes sparse fine-tuning and recovers accuracy where CE and standard KD diverge.

T5 (EN→DE) preserves most BLEU at high sparsity: BLEU 25.91 → 24.67 at 75% sparsity.

Numbers: BLEU 25.91 (0%) → 24.67 (75%), speedup 1.00x → 2.14x

Whisper-Small can be pruned to 50–67% sparsity with similar or slightly better WER; high sparsity (>80%) increases WER.

Numbers: WER 32.57 (0%) → 30.87 (50%, 1.58x) → 31.85 (67%, 2.15x)

For MPT-7B on GSM8K, fine-tuning with SquareHead raised dense accuracy from 28.2% to 33.0%. At 60–70% sparsity, FP32 accuracy stays at or near the original 28.2% dense baseline, though below the fine-tuned dense 33.0%.

Numbers: Dense 28.2% → fine-tuned dense 33.0%; FP32 60%: 28.8% (2.07x), 70%: 28.0% (2.62x)

Combining sparsity with INT8 quantization compounds speedups: e.g., MPT-7B 60% INT8 ≈ 6.7x speedup, 70% INT8 ≈ 7.5x, 80% INT8 ≈ 9.08x (with accuracy drop).

Numbers: 60% INT8 speedup 6.70x; 70% INT8 7.49x; 80% INT8 9.08x (Table 3)

A custom GPU N:M sparse kernel achieved 1.82× speedup over dense FP16 for a 4096×12288 QKV matrix.

Numbers: GPU kernel 1.82× vs FP16 dense
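
To put the reported numbers on a common scale, the following sketch computes the relative change in each metric against its dense baseline, next to the reported speedup. The values are copied from the findings above; the helper name and the table layout are illustrative assumptions.

```python
# (task, metric, higher_is_better, dense_value, sparse_value, sparsity, speedup)
reported = [
    ("T5 EN-DE", "BLEU", True, 25.91, 24.67, 0.75, 2.14),
    ("Whisper-Small", "WER", False, 32.57, 30.87, 0.50, 1.58),
    ("Whisper-Small", "WER", False, 32.57, 31.85, 0.67, 2.15),
    ("MPT-7B GSM8K", "accuracy %", True, 28.2, 28.8, 0.60, 2.07),
    ("MPT-7B GSM8K", "accuracy %", True, 28.2, 28.0, 0.70, 2.62),
]


def relative_quality_change(dense, sparse, higher_is_better):
    """Signed relative change; positive means the sparse model is better."""
    change = (sparse - dense) / dense
    return change if higher_is_better else -change


for task, metric, hib, dense, sparse, sparsity, speedup in reported:
    delta = relative_quality_change(dense, sparse, hib)
    print(f"{task:14s} {metric:10s} sparsity={sparsity:.0%} "
          f"quality change={delta:+.1%} speedup={speedup:.2f}x")
```

Note that the MPT-7B rows compare against the original dense model (28.2%), not the fine-tuned dense model (33.0%).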

Results

T5 BLEU vs sparsity

Value: BLEU 25.91 → 24.67 at 75% sparsity

Baseline: BLEU 25.91 (0% sparsity)

Whisper WER vs sparsity

Value: WER 32.57 → 30.87 at 50% sparsity (speedup 1.58x)

Baseline: WER 32.57 (0% sparsity)

MPT-7B GSM8K accuracy vs sparsity

Value: FP32 accuracy 28.2% → 28.0% at 70% sparsity with 2.62x CPU speedup

Baseline: Accuracy 28.2% (0% sparsity, before fine-tuning)

MPT-7B INT8 compound speedup

Value: 60% INT8 ≈ 6.70x, 70% INT8 ≈ 7.49x, 80% INT8 ≈ 9.08x (with accuracy drop)

Baseline: FP32 dense 1.00x

GPU kernel speedup for QKV layer

Value: Custom N:M kernel 1.82× vs dense FP16

Baseline: Dense FP16
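
For readers unfamiliar with N:M sparsity, the format behind the GPU kernel result: in every group of M consecutive weights, at most N are nonzero. The sketch below builds such a pattern by keeping the N largest-magnitude weights in each group. Magnitude-based selection and the 2:4 setting are assumptions chosen for illustration; this summary does not state which N:M values or selection rule the authors' kernel uses.

```python
from typing import List


def prune_n_m(weights: List[float], n: int, m: int) -> List[float]:
    """Zero all but the n largest-magnitude weights in each group of m."""
    if len(weights) % m != 0:
        raise ValueError("row length must be a multiple of m")
    pruned = []
    for start in range(0, len(weights), m):
        group = weights[start:start + m]
        # Indices of the n entries with the largest absolute value.
        keep = set(sorted(range(m), key=lambda i: abs(group[i]), reverse=True)[:n])
        pruned.extend(w if i in keep else 0.0 for i, w in enumerate(group))
    return pruned


row = [0.1, -0.9, 0.4, 0.05, 0.7, 0.2, -0.3, 0.6]
sparse_row = prune_n_m(row, n=2, m=4)
sparsity = sparse_row.count(0.0) / len(sparse_row)
```

Because every group has the same number of kept entries, a kernel can store only the kept values plus a small per-group index, which is what makes the pattern friendly to GPU hardware compared with unstructured sparsity.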

Who Should Care

What To Try In 7 Days

Run SparseGPT one-shot prune then fine-tune with SquareHead on a small target dataset.

Benchmark end-to-end latency in DeepSparse at 50–70% sparsity first.

If memory-bound, try INT8 post-training quantization on the sparse model to amplify speedups.
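
A minimal timing harness for the benchmarking step might look like this. It is runtime-agnostic: `run_dense` and `run_sparse` are hypothetical callables standing in for whatever inference call your engine exposes (they are not DeepSparse APIs), and the warmup and iteration counts are arbitrary defaults.

```python
import statistics
import time
from typing import Callable


def median_latency(fn: Callable[[], object], warmup: int = 3, iters: int = 20) -> float:
    """Median wall-clock seconds per call after a few warmup runs."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup(dense_fn: Callable[[], object], sparse_fn: Callable[[], object]) -> float:
    """Dense latency divided by sparse latency; values above 1 mean faster."""
    return median_latency(dense_fn) / median_latency(sparse_fn)


# Placeholder workloads; replace with real end-to-end inference calls.
def run_dense():
    return sum(i * i for i in range(200_000))


def run_sparse():
    return sum(i * i for i in range(50_000))


print(f"measured speedup: {speedup(run_dense, run_sparse):.2f}x")
```

Using the median rather than the mean keeps a single slow run (for example, a cold cache) from skewing the comparison. Measure the full request path, not an isolated layer, since the per-layer and end-to-end speedups reported above differ.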

Optimization Features

Infra Optimization

  • commodity CPU inference
  • GPU-aware N:M format kernels

Model Optimization

  • pruning
  • sparsity

System Optimization

  • memory-bandwidth reduction via compressed sparse weights

Training Optimization

  • knowledge distillation (SquareHead)
  • one-shot pruning + fine-tune

Inference Optimization

  • CPU sparse kernels (DeepSparse)
  • GPU N:M sparse kernels
  • INT8 quantization compatibility

Reproducibility

Code Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Very high sparsity (>80%) tends to degrade accuracy; for example, Whisper-Small WER rises above that level.
  • Sparse GPU speedups need N:M formats and custom kernels; not all hardware benefits equally.
  • Fine-tuning data scarcity can still cause overfitting; SquareHead reduces but does not eliminate this risk.

When Not To Use

  • When absolute top accuracy matters and any drop is unacceptable.
  • When target hardware lacks support for the sparse formats or custom kernels used.
  • When you cannot run the extra fine-tuning cycles required after pruning.

Failure Modes

  • Training divergence or loss spikes when using only task loss or standard KD at high sparsity.
  • Compound accuracy drop when combining high sparsity and post-training quantization.
  • Runtime gains may vanish if sparse kernels are not well-optimized for your shapes/hardware.

Core Entities

Models

  • MPT-7B
  • T5-Small
  • Whisper-Small
  • SparseGPT

Metrics

  • BLEU
  • WER
  • Accuracy
  • tokens/sec
  • end-to-end speedup

Datasets

  • GSM8K
  • WMT14 (EN-DE)
  • CommonVoice (Hindi)

Benchmarks

  • GSM8K
  • WMT14
  • CommonVoice

Context Entities

Models

  • Dense teacher fine-tuned baselines