One-shot clustering of MLP subunits that preserves NTK to speed up fine-tuning of dense and MoE models

July 18, 2023 · 8 min

Overview

Decision Snapshot: Ready For Pilot

The method is practical: code is released, experiments cover dense and MoE models, and measured memory/runtime gains are reported; theory supports NTK preservation under clearly stated assumptions.

Citations: 2

Evidence Strength: 0.80

Confidence: 0.85

Risk Signals: 8

Trust Signals

Findings with numeric evidence: 4/4

Findings with evidence refs: 4/4

Results with explicit delta: 5/5

Reproducibility

Status: Code + data available

Open source: Partial

At A Glance

Cost impact: 70%

Production readiness: 70%

Novelty: 60%

Authors

Mengting Ai, Tianxin Wei, Yifan Chen, Zeming Guo, Jingrui He

Why It Matters For Business

MLP Fusion reduces GPU memory and fine-tuning time while preserving training dynamics and near-original accuracy, making low-cost SFT and smaller deployed models feasible for companies running many custom fine-tunes or deploying large MoE models.

Summary TLDR

The paper introduces MLP Fusion, a one-shot, data-agnostic compression method that clusters the bottleneck-1 sub-MLPs inside FFN/MLP modules and reconstructs a smaller MLP with a standalone scaling matrix. The design explicitly targets the Neural Tangent Kernel (NTK) of the original model, so the compressed model keeps similar training dynamics under Adam-like optimizers. Experiments on RoBERTa, GPT-2 and a Switch Transformer MoE show that MLP Fusion best preserves the NTK, keeps accuracy close to the full model (e.g., SST2: 94.61 → 93.23), and reduces fine-tuning memory and wall-clock time to roughly the levels achieved by structured SVD/pruning. Code: https://github.com/weitianxin/MLP_Fusion.
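A minimal NumPy sketch of the clustering idea as summarized here, not the code from the linked repository. Each bottleneck-1 sub-MLP i is the triple (W1[i], b1[i], W2[:, i]); the sketch clusters these concatenated vectors with plain k-means, uses the centroids as the smaller MLP, and keeps the cluster sizes as a separate diagonal scale. Assumptions: a tanh activation (RoBERTa/GPT-2 use GELU), Lloyd k-means with a deterministic farthest-point initialization, and hypothetical helper names (`mlp_forward`, `kmeans`, `mlp_fusion`).

```python
import numpy as np


def mlp_forward(x, W1, b1, W2, b2, scale=None):
    """Two-layer MLP: tanh(x W1^T + b1) [* scale] W2^T + b2, for a batch x of shape (n, d_in)."""
    h = np.tanh(x @ W1.T + b1)             # (n, p) hidden activations
    if scale is not None:
        h = h * scale                      # standalone diagonal scaling (cluster sizes)
    return h @ W2.T + b2


def kmeans(Z, c, iters=50):
    """Lloyd k-means on the rows of Z with deterministic farthest-point initialization."""
    centers = [Z[0]]
    d2 = ((Z - Z[0]) ** 2).sum(axis=1)     # squared distance to the nearest chosen center
    for _ in range(1, c):
        idx = int(np.argmax(d2))
        centers.append(Z[idx])
        d2 = np.minimum(d2, ((Z - Z[idx]) ** 2).sum(axis=1))
    C = np.stack(centers).astype(float)
    for _ in range(iters):
        dist = ((Z[:, None, :] - C[None, :, :]) ** 2).sum(axis=-1)
        labels = dist.argmin(axis=1)
        for j in range(c):
            members = Z[labels == j]
            if len(members):               # keep the old center if a cluster empties
                C[j] = members.mean(axis=0)
    return labels, C


def mlp_fusion(W1, b1, W2, c, iters=50):
    """Cluster the p sub-MLPs (W1[i], b1[i], W2[:, i]) into c centroids plus a size vector."""
    d_in = W1.shape[1]
    Z = np.concatenate([W1, b1[:, None], W2.T], axis=1)    # one row per sub-MLP
    labels, C = kmeans(Z, c, iters)
    sizes = np.bincount(labels, minlength=c).astype(float)  # diagonal of the scaling matrix
    W1f = C[:, :d_in]
    b1f = C[:, d_in]
    W2f = C[:, d_in + 1:].T                                 # (d_out, c)
    return W1f, b1f, W2f, sizes
```

With c equal to the original width every sub-MLP becomes its own centroid and the output is reproduced exactly; when sub-MLPs form tight groups, the fused layer `mlp_forward(x, W1f, b1f, W2f, b2, sizes)` stays close to the original because each centroid is the mean of its members.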

Problem Statement

Fine-tuning large PLMs is costly, especially because FFN/MLP modules (and MoE experts) dominate compute and memory. Existing one-shot compression methods focus on output approximation or sparsity but often hurt training dynamics or yield no practical speedups on modern GPUs. The paper asks: can we compress MLPs in one shot while preserving the model's training dynamics so fine-tuning remains effective and efficient?

Main Contribution

Propose an NTK-driven one-shot compression method ("MLP Fusion") that clusters an MLP's bottleneck-1 subunits (its single-hidden-neuron sub-MLPs) and reconstructs a smaller MLP while keeping a standalone cluster-scaling matrix.

Derive theoretically that clustering can preserve the Adam NTK under stated assumptions, linking output approximation to training-dynamics preservation.
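One way to write the construction in these terms, assuming a single hidden layer and a hard cluster assignment; the symbols C, \bar{C} and P are notation introduced here for illustration, and the sign-based kernel K^A is the usual proxy for an Adam-like update rather than a formula quoted from the paper:

```latex
% Original MLP written as a sum of p bottleneck-1 sub-MLPs
f(x) = W_2\,\sigma(W_1 x + b_1) + b_2
     = \sum_{i=1}^{p} W_2[:, i]\,\sigma\big(W_1[i, :]\,x + b_1[i]\big) + b_2

% Hard assignment C \in \{0,1\}^{p \times c}; cluster sizes P = C^\top C = \mathrm{diag}(n_1, \dots, n_c)
\bar{C} = C P^{-1}, \qquad
\tilde{W}_1 = \bar{C}^\top W_1, \qquad
\tilde{b}_1 = \bar{C}^\top b_1, \qquad
\tilde{W}_2 = W_2 \bar{C}

% Fused MLP: P is kept as a separate, frozen factor
\tilde{f}(x) = \tilde{W}_2\, P\, \sigma(\tilde{W}_1 x + \tilde{b}_1) + b_2

% Plain NTK and sign-based (Adam-like) kernel
K(x, x') = \big\langle \nabla_\theta f(x), \nabla_\theta f(x') \big\rangle, \qquad
K^{\mathrm{A}}(x, x') = \big\langle \nabla_\theta f(x), \operatorname{sign}\!\big(\nabla_\theta f(x')\big) \big\rangle
```

In the idealized case where every sub-MLP in cluster j equals its centroid, \tilde{f}(x) = f(x) exactly. Because the sign in K^A discards the factor n_j on one side, keeping P frozen and separate also makes K^A match the original, whereas the plain kernel K would pick up n_j^2 instead of n_j.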

Key Findings

MLP Fusion yields the lowest NTK approximation error among tested one-shot methods.

Numbers: NTK error on SST2 (RoBERTa first layer): 2826.6 ±155.1 vs. SVD 4423.4

Practical Use: Choose MLP Fusion when you want a compressed model that keeps similar training dynamics under Adam and thus better fine-tuning behavior.

Evidence Ref: Table I
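A toy numerical check of why a standalone scaling matrix helps with the Adam NTK, under loudly stated assumptions: a scalar-output tanh MLP, exact duplicate sub-MLPs so the clustering is known in advance, the sign-based kernel K(x, x') = <grad f(x), sign(grad f(x'))> used as the Adam NTK, and a Frobenius-norm error. "Absorbing P into w2" is a hypothetical stand-in for a plain clustering baseline, not necessarily the paper's Clustering baseline, and all helper names are hypothetical.

```python
import numpy as np


def param_jacobian(x, W1, b1, w2, scale):
    """Per-example gradient of f(x) = sum_j scale_j * w2_j * tanh(W1_j . x + b1_j) + b2
    with respect to the trainable parameters (W1, b1, w2, b2); scale is frozen."""
    h = np.tanh(x @ W1.T + b1)                    # (n, p)
    dh = 1.0 - h ** 2                             # tanh'
    g_pre = scale * w2 * dh                       # df / d(pre-activation), (n, p)
    g_W1 = (g_pre[:, :, None] * x[:, None, :]).reshape(len(x), -1)
    g_b1 = g_pre
    g_w2 = scale * h
    g_b2 = np.ones((len(x), 1))
    return np.concatenate([g_W1, g_b1, g_w2, g_b2], axis=1)


def ntk(J):
    """Empirical NTK for (S)GD: K[a, b] = <grad f(x_a), grad f(x_b)>."""
    return J @ J.T


def sign_ntk(J):
    """Sign-based (Adam-like) kernel: K[a, b] = <grad f(x_a), sign(grad f(x_b))>."""
    return J @ np.sign(J).T


def fuse(W1, b1, w2, labels, absorb_scale):
    """Replace each cluster of sub-MLPs by its centroid; cluster sizes form the scale."""
    c = labels.max() + 1
    sizes = np.bincount(labels, minlength=c).astype(float)
    W1f = np.stack([W1[labels == j].mean(axis=0) for j in range(c)])
    b1f = np.array([b1[labels == j].mean() for j in range(c)])
    w2f = np.array([w2[labels == j].mean() for j in range(c)])
    if absorb_scale:                              # fold sizes into the trainable w2
        return W1f, b1f, sizes * w2f, np.ones(c)
    return W1f, b1f, w2f, sizes                   # keep the scale standalone and frozen


# Toy model: 5 distinct sub-MLPs, each duplicated 3 times.
rng = np.random.default_rng(1)
d, c, reps = 6, 5, 3
labels = np.repeat(np.arange(c), reps)
W1 = rng.normal(size=(c, d))[labels]
b1 = rng.normal(size=c)[labels]
w2 = rng.normal(size=c)[labels]
x = rng.normal(size=(10, d))

K_orig = sign_ntk(param_jacobian(x, W1, b1, w2, np.ones(len(labels))))
K_fusion = sign_ntk(param_jacobian(x, *fuse(W1, b1, w2, labels, absorb_scale=False)))
K_absorbed = sign_ntk(param_jacobian(x, *fuse(W1, b1, w2, labels, absorb_scale=True)))
print("standalone P error:", np.linalg.norm(K_orig - K_fusion))
print("absorbed P error:  ", np.linalg.norm(K_orig - K_absorbed))
```

In this setup the standalone-scale variant reproduces the sign-based kernel up to floating-point error, while the absorbed variant does not. Swapping `sign_ntk` for `ntk` shows that neither variant reproduces the plain (S)GD kernel here, which is consistent with the emphasis on Adam-like optimizers.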

Compressed models retain most downstream accuracy with modest loss.

Numbers: SST2 accuracy, full RoBERTa 94.61% → MLP Fusion 93.23%; with 1 epoch of layer-wise tuning → 93.79%

Practical Use: Expect a drop of about 1–1.5 percentage points on SST2 after one-shot fusion; a short layer-wise tuning pass recovers part of the gap.

Evidence Ref: Table II; Appendix E
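Worked from the SST2 accuracies above (pp = percentage points):

```latex
\Delta_{\text{fusion}} = 94.61 - 93.23 = 1.38\ \text{pp}, \qquad
\Delta_{\text{tuned}} = 94.61 - 93.79 = 0.82\ \text{pp}

\text{share of the gap recovered by tuning} = \frac{1.38 - 0.82}{1.38} = \frac{0.56}{1.38} \approx 0.41
```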

Results

| Metric | Value | Baseline | Delta | Split / Dataset | Evidence | Evidence Ref |
|---|---|---|---|---|---|---|
| NTK approximation error (L2) | 2826.59 ±155.06 | SVD 4423.38 ±108.89 | ~1596.8 lower than SVD (better) | SST2 validation (RoBERTa first-layer MLP) | MLP Fusion has the lowest NTK error among the compared one-shot methods | Table I |
| Output approximation error (L2) | 4.83 ±0.02 | Sketch 24.48 ±0.61 | much lower than Sketch; comparable to Clustering | SST2 validation (RoBERTa first-layer MLP) | MLP Fusion matches the best output approximation among clustering methods | Table I |
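The absolute deltas in the results table translate into relative terms as follows:

```latex
\text{NTK error: } 4423.38 - 2826.59 = 1596.79, \qquad
\frac{2826.59}{4423.38} \approx 0.639 \;\;(\approx 36\%\ \text{lower than SVD})

\text{Output error: } \frac{24.48}{4.83} \approx 5.07 \;\;(\text{Sketch is} \approx 5\times\ \text{higher})
```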

What To Try In 7 Days

Run MLP Fusion on the last 8 MLP layers of your RoBERTa/GPT-2 model, shrinking the intermediate dimension to 25% (e.g., 3072 → 768).

Compare NTK approximation error and dev accuracy versus simple SVD or sketching to validate dynamics preservation.

If accuracy dips, run 1 epoch of layer-wise task-specific tuning (MSE on teacher layer outputs).
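The layer-wise tuning step can be sketched as fitting each fused layer to its teacher layer's outputs with an MSE loss while the cluster-scaling vector stays frozen. This is a minimal NumPy sketch under several assumptions: a tanh MLP, full-batch plain gradient descent instead of whatever optimizer the authors use, and a hypothetical stand-in for the clustering step that simply averages consecutive groups of sub-MLPs. The names `mlp` and `layerwise_tune` are hypothetical.

```python
import numpy as np


def mlp(x, W1, b1, W2, b2, scale):
    """tanh MLP with a frozen diagonal scale on the hidden units (scale = 1 for the teacher)."""
    h = np.tanh(x @ W1.T + b1)
    return h, (h * scale) @ W2.T + b2


def layerwise_tune(x, y, W1, b1, W2, b2, scale, lr=0.02, steps=500):
    """Fit the fused layer to teacher outputs y with MSE; scale stays frozen.
    Uses plain full-batch gradient descent for simplicity."""
    W1, b1, W2, b2 = W1.copy(), b1.copy(), W2.copy(), b2.copy()
    losses = []
    for _ in range(steps):
        h, out = mlp(x, W1, b1, W2, b2, scale)
        err = out - y
        losses.append(float((err ** 2).mean()))
        d_out = 2.0 * err / err.size                    # dL/d(out) for the mean squared error
        gW2 = d_out.T @ (h * scale)
        gb2 = d_out.sum(axis=0)
        d_pre = (d_out @ W2) * scale * (1.0 - h ** 2)   # back through the scale and tanh
        gW1 = d_pre.T @ x
        gb1 = d_pre.sum(axis=0)
        W1 -= lr * gW1
        b1 -= lr * gb1
        W2 -= lr * gW2
        b2 -= lr * gb2
    return (W1, b1, W2, b2), losses


rng = np.random.default_rng(0)
d_in, p, c, d_out = 8, 16, 4, 8
tW1 = rng.normal(size=(p, d_in)) / np.sqrt(d_in)
tb1 = rng.normal(size=p) * 0.1
tW2 = rng.normal(size=(d_out, p)) / np.sqrt(p)
tb2 = np.zeros(d_out)
x = rng.normal(size=(64, d_in))
_, y = mlp(x, tW1, tb1, tW2, tb2, np.ones(p))           # teacher layer outputs

# Hypothetical stand-in for clustering: average consecutive groups of p // c sub-MLPs.
groups = np.arange(p) // (p // c)
sW1 = np.stack([tW1[groups == j].mean(axis=0) for j in range(c)])
sb1 = np.array([tb1[groups == j].mean() for j in range(c)])
sW2 = np.stack([tW2[:, groups == j].mean(axis=1) for j in range(c)], axis=1)
P = np.bincount(groups).astype(float)                   # frozen cluster sizes

params, losses = layerwise_tune(x, y, sW1, sb1, sW2, tb2, P)
print(f"layer MSE: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Keeping `P` out of the parameter updates mirrors the "standalone" scaling described above; in a real model the same loop would run per fused layer, with inputs and targets taken from the teacher's hidden states.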

Optimization Features

Infra Optimization: works on standard GPUs without sparse-kernel support
Model Optimization: clustering-based one-shot compression of MLP subunits; a standalone diagonal cluster-scaling matrix to mimic per-cluster learning rates; applies to MoE experts (expert-wise fusion)
System Optimization: SFT; lower wall-clock fine-tuning time in experiments
Training Optimization: preserves the Adam NTK (training dynamics) under stated assumptions; short layer-wise task-specific tuning (1 epoch) to regain task knowledge
Inference Optimization: reduces model parameter count and MLP GFLOPs to a degree similar to SVD/structured pruning; no reliance on sparse matrices (keeps hardware-friendly dense ops)
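A back-of-the-envelope count of what the 3072 → 768 setting removes, assuming RoBERTa-base/GPT-2-small dimensions (hidden size 768, intermediate size 3072), counting only the MLP weights and biases, and counting multiply-accumulates (MACs) per token. The helper names are hypothetical.

```python
def mlp_params(d_model: int, d_ff: int) -> int:
    """Weights and biases of a d_model -> d_ff -> d_model MLP."""
    return d_model * d_ff + d_ff + d_ff * d_model + d_model


def mlp_macs_per_token(d_model: int, d_ff: int) -> int:
    """Multiply-accumulates for the two dense projections of one token."""
    return 2 * d_model * d_ff


d_model, d_ff, fused, layers = 768, 3072, 768, 8
full = mlp_params(d_model, d_ff)          # 4722432
small = mlp_params(d_model, fused)        # 1181184 (plus 768 frozen scaling entries)
saved = layers * (full - small)           # 28329984 parameters across 8 layers
mac_ratio = mlp_macs_per_token(d_model, fused) / mlp_macs_per_token(d_model, d_ff)  # 0.25
print(full, small, saved, mac_ratio)
```

Because the fused layer is just a narrower dense MLP (plus an elementwise scale), the savings come from smaller matrix multiplies rather than sparse kernels.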

Reproducibility

Code Available: Yes
Data Available: Yes
Open Source Status: Partial
License: Unknown

Risks & Boundaries

Limitations

Performance drops if the fused intermediate dimension is too small (experiments show rank-deficiency below ~768).

NTK preservation relies on assumptions (clustering approximates weight distribution and Adam-like sign behavior); deviations can reduce effectiveness.

When Not To Use

When you need zero accuracy loss: fusion causes small but nonzero accuracy drops.

When you must compress MLPs far below 25% retained parameters: extreme compression degrades training and accuracy.

Failure Modes

Clustering fails to capture crucial sub-MLP diversity, causing NTK mismatch and worse fine-tuning.

Too-small fused dimensions create rank-deficient layers with higher training loss.

Core Entities

Models

RoBERTa, GPT-2, Switch Transformer, DistilRoBERTa, NTK-SAP, LoRA, SVD

Metrics

Accuracy, BLEU, METEOR, TER, NTK approximation error, output approximation error, memory (GB), wall-clock runtime (s)

Datasets

SST2, MNLI, MRPC, CoLA, WebNLG, STS-B, QNLI

Benchmarks

GLUE, WebNLG

Context Entities

Models

Mixtral (cited as large MoE example)