One-shot clustering of MLP subunits that preserves NTK to speed up fine-tuning of dense and MoE models

July 18, 2023 (8 min read)

Overview

Production Readiness

0.7

Novelty Score

0.6

Cost Impact Score

0.7

Citation Count

2

Authors

Mengting Ai, Tianxin Wei, Yifan Chen, Zeming Guo, Jingrui He

Links

Abstract / PDF

Why It Matters For Business

MLP Fusion reduces GPU memory and fine-tuning time while preserving training dynamics and near-original accuracy, making low-cost SFT and smaller deployed models feasible for companies running many custom fine-tunes or deploying large MoE models.

Summary TLDR

The paper introduces MLP Fusion, a one-shot, data-agnostic compression method that clusters the bottleneck-1 sub-MLPs (single hidden units together with their input and output weights) inside FFN/MLP modules and reconstructs a smaller MLP with a standalone scaling matrix. The design explicitly targets the Neural Tangent Kernel (NTK) of the original model so that the compressed model keeps similar training dynamics under Adam-like optimizers. Experiments on RoBERTa, GPT-2 and a Switch Transformer MoE show that MLP Fusion preserves the NTK better than the other tested one-shot methods, keeps accuracy close to the full model (e.g., SST2 accuracy: 94.61 → 93.23), and reduces fine-tuning memory and wall-clock time to roughly the levels achieved by structured SVD/pruning. Code: https://github.com/weitianxin/MLP_Fusion.
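To make the clustering-and-reconstruction idea concrete, here is a minimal NumPy sketch, not the authors' implementation. It assumes a two-layer ReLU MLP, treats each hidden unit's input weights, bias and output weights as one clustering feature vector (an illustrative choice), uses a plain k-means with farthest-point initialisation, and keeps the cluster sizes as a separate diagonal scale. The names `fuse_mlp`, `kmeans`, `sq_dists` and `mlp_forward` are hypothetical.

```python
import numpy as np


def sq_dists(a, b):
    """Pairwise squared Euclidean distances between rows of a and rows of b."""
    return (a * a).sum(1)[:, None] - 2.0 * a @ b.T + (b * b).sum(1)[None, :]


def kmeans(points, k, iters=25):
    """Lloyd's k-means with deterministic farthest-point initialisation."""
    centroids = [points[0]]
    for _ in range(1, k):
        nearest = np.min([((points - c) ** 2).sum(1) for c in centroids], axis=0)
        centroids.append(points[int(nearest.argmax())])
    centroids = np.stack(centroids)  # copies, so `points` is never modified
    for _ in range(iters):
        labels = sq_dists(points, centroids).argmin(1)
        for j in range(k):
            if np.any(labels == j):  # empty clusters keep their old centroid
                centroids[j] = points[labels == j].mean(0)
    return centroids, sq_dists(points, centroids).argmin(1)


def fuse_mlp(W1, b1, W2, k):
    """Cluster the hidden units of an MLP into k fused units.

    W1: (d_in, m), b1: (m,), W2: (m, d_out).
    Returns fused W1, b1, W2 plus the cluster sizes (the diagonal scale).
    """
    d_in = W1.shape[0]
    # Assumption: one feature row per hidden unit = [input weights, bias, output weights].
    units = np.concatenate([W1.T, b1[:, None], W2], axis=1)
    centroids, labels = kmeans(units, k)
    counts = np.bincount(labels, minlength=k).astype(float)
    return centroids[:, :d_in].T, centroids[:, d_in], centroids[:, d_in + 1:], counts


def mlp_forward(x, W1, b1, W2, b2, scale=None):
    """relu(x @ W1 + b1), optionally scaled per hidden unit, then @ W2 + b2."""
    h = np.maximum(x @ W1 + b1, 0.0)
    if scale is not None:
        h = h * scale  # diagonal scaling matrix applied as an elementwise multiply
    return h @ W2 + b2
```

Calling `fuse_mlp(W1, b1, W2, k=768)` on a layer with 3072 hidden units would give a 768-unit layer to be evaluated with `mlp_forward(..., scale=counts)`. When the hidden units fall into exact duplicates, the fused output matches the original; otherwise it is an approximation whose quality depends on how well the centroids represent the units.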

Problem Statement

Fine-tuning large PLMs is costly, especially because FFN/MLP modules (and MoE experts) dominate compute and memory. Existing one-shot compression methods focus on output approximation or sparsity but often hurt training dynamics or yield no practical speedups on modern GPUs. The paper asks: can we compress MLPs in one shot while preserving the model's training dynamics so fine-tuning remains effective and efficient?

Main Contribution

Propose an NTK-driven one-shot compression method ("MLP Fusion") that clusters each MLP's bottleneck-1 subunits and reconstructs a smaller MLP while keeping a standalone cluster-scaling matrix.

Show theoretical derivations that clustering can preserve the Adam NTK under stated assumptions, linking output approximation to training-dynamics preservation.

Extend the method to Mixture-of-Experts (MoE) by compressing experts independently and keeping routing compatible with the compressed experts.

Provide extensive experiments on NLU and NLG benchmarks showing superior NTK preservation, small accuracy loss, and practical memory/runtime gains; release code.

Key Findings

MLP Fusion yields the lowest NTK approximation error among tested one-shot methods.

Numbers: NTK error on SST2 (RoBERTa first layer): 2826.6 ±155.1 vs. SVD 4423.4
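One way such a kernel error could be measured is sketched below for a toy scalar-output ReLU MLP. This is an assumption-laden illustration rather than the paper's evaluation code: it uses a sign-based kernel (gradient at one input times the sign of the gradient at another) as a stand-in for the Adam NTK, treats the cluster scale as fixed, and reports a Frobenius norm. The function names `output_grad`, `kernel` and `kernel_error` are hypothetical.

```python
import numpy as np


def output_grad(x, W1, b1, w2, scale=None):
    """Gradient of f(x) = sum_i s_i * w2_i * relu(W1[:, i] @ x + b1_i)
    with respect to (W1, b1, w2), flattened into one vector.
    The scale s is treated as a fixed constant, so it gets no gradient."""
    s = np.ones_like(b1) if scale is None else scale
    z = x @ W1 + b1
    gate = (z > 0).astype(float)          # ReLU derivative
    coeff = s * w2 * gate                 # df/db1, also the column weights of df/dW1
    return np.concatenate([np.outer(x, coeff).ravel(), coeff, s * np.maximum(z, 0.0)])


def kernel(X, W1, b1, w2, scale=None, use_sign=True):
    """Empirical kernel matrix over the rows of X.

    use_sign=True : K[a, b] = <grad f(x_a), sign(grad f(x_b))>  (assumed Adam-style kernel)
    use_sign=False: K[a, b] = <grad f(x_a), grad f(x_b)>        (standard NTK)
    """
    J = np.stack([output_grad(x, W1, b1, w2, scale) for x in X])
    return J @ (np.sign(J) if use_sign else J).T


def kernel_error(K_ref, K_approx):
    """Frobenius norm of the kernel difference (the choice of norm is an assumption)."""
    return float(np.linalg.norm(K_ref - K_approx))
```

In the idealized case where hidden units are exact duplicates and the fused layer uses the cluster sizes as its scale, the sign-based kernels of the original and fused layers coincide, while the standard kernels do not. This is consistent with the summary's emphasis on Adam-like optimizers.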

Compressed models retain most downstream accuracy with modest loss.

Numbers: SST2 accuracy: full RoBERTa 94.61% → MLP Fusion 93.23%; with 1 epoch of layer-wise tuning → 93.79%

MLP Fusion reduces fine-tuning memory and wall-clock time to a degree similar to structured SVD/pruning.

Numbers: training memory 24.02 GB → 15.46 GB; SFT runtime 6697 s → 5782 s; inference runtime 170.33 s → 162.37 s

Method applies to MoE models with similar accuracy preservation.

Numbers: Switch Transformer SST2 accuracy: full 95.60% → MLP Fusion 95.03%; with tuning 95.15%

Results

NTK approximation error (L2)

Value: 2826.59 ±155.06

Baseline: SVD 4423.38 ±108.89

Output approximation error (L2)

Value: 4.83 ±0.02

Baseline: Sketch 24.48 ±0.61

Accuracy (SST2, %)

Value: 93.23 ±0.23

Baseline: Full RoBERTa 94.61 ±0.09

Fine-tuning memory

Value: 15.46 GB

Baseline: Full 24.02 GB

Fine-tuning runtime (wall-clock)

Value: 5782.12 s ±8.90

Baseline: Full 6697.01 s ±4.76
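For a quick sense of scale, the reported values can be turned into relative reductions. The short script below only does arithmetic on the numbers listed in this section; the helper name `relative_reduction` is hypothetical.

```python
def relative_reduction(baseline, value):
    """Fraction by which `value` is lower than `baseline`."""
    return (baseline - value) / baseline


# (baseline, MLP Fusion) pairs copied from the results listed above.
reported = {
    "NTK approximation error vs. SVD": (4423.38, 2826.59),
    "fine-tuning memory (GB)": (24.02, 15.46),
    "fine-tuning runtime (s)": (6697.01, 5782.12),
}

for name, (baseline, value) in reported.items():
    print(f"{name}: {100 * relative_reduction(baseline, value):.1f}% lower")
```

This works out to roughly 36% lower NTK error than SVD, about 36% less fine-tuning memory and about 14% less fine-tuning time than the full model, against an SST2 accuracy drop of 1.38 points (94.61 → 93.23).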

Who Should Care

What To Try In 7 Days

Run MLP Fusion on the last 8 MLP layers of your RoBERTa/GPT-2 model, reducing the intermediate dimension to 25% of its original size (e.g., 3072 → 768).

Compare NTK approximation error and dev accuracy versus simple SVD or sketching to validate dynamics preservation.

If accuracy dips, run 1 epoch of layer-wise task-specific tuning (MSE on teacher layer outputs).
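Before running the experiment, it can help to estimate how many parameters the suggested setting removes. The sketch below assumes a hidden size of 768 (as in base-size RoBERTa and GPT-2), biases on both linear layers, and one extra scale entry per fused unit; `mlp_params` is a hypothetical helper.

```python
def mlp_params(hidden, intermediate, with_scale=False):
    """Parameter count of one FFN block: W1, b1, W2, b2 (+ optional diagonal scale)."""
    n = hidden * intermediate + intermediate   # W1 and b1
    n += intermediate * hidden + hidden        # W2 and b2
    if with_scale:
        n += intermediate                      # one scale entry per fused unit
    return n


hidden = 768   # assumption: base-size model
layers = 8     # the "last 8 MLP layers" from the suggestion above

full = mlp_params(hidden, 3072)
fused = mlp_params(hidden, 768, with_scale=True)
saved_total = layers * (full - fused)

print(f"per-layer MLP params: {full:,} -> {fused:,}")
print(f"removed over {layers} layers: {saved_total:,}")
```

Under these assumptions each fused layer keeps about a quarter of its MLP parameters (4,722,432 → 1,181,952), and fusing 8 layers removes about 28.3 million parameters in total.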

Optimization Features

Infra Optimization

  • works on standard GPUs without sparse-kernel support

Model Optimization

  • clustering-based one-shot compression of MLP subunits
  • standalone diagonal cluster-scaling matrix to mimic per-cluster learning rates
  • applies to MoE experts (expert-wise fusion)
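The role of the standalone scaling matrix can be illustrated with a short worked example. This is an idealized sketch, not the paper's derivation: it assumes a scalar output, a sign-based (Adam-like) update with step size \eta, only the output weights being updated, and every unit in cluster k being identical to its centroid. The symbols n_k (cluster size), a_k (centroid activation) and \delta (loss gradient with respect to the output) are introduced here for illustration.

```latex
% Original model, with all units in cluster k equal to the centroid:
f(x) = \sum_i w^{(2)}_i \,\sigma\!\big(w^{(1)\top}_i x + b_i\big)
     = \sum_k n_k \,\tilde w^{(2)}_k \, a_k ,
\qquad a_k = \sigma\!\big(\tilde w^{(1)\top}_k x + \tilde b_k\big).

% Original sign update: every unit i in cluster k moves by the same step,
w^{(2)}_i \leftarrow w^{(2)}_i - \eta \,\mathrm{sign}(\delta\, a_k)
\;\Rightarrow\;
\Delta f = -\eta \sum_k n_k \,\mathrm{sign}(\delta\, a_k)\, a_k .

% Fused model with standalone scale P = \mathrm{diag}(n_k):
\frac{\partial L}{\partial \tilde w^{(2)}_k} = n_k\, \delta\, a_k
\;\Rightarrow\;
\Delta \tilde f = -\eta \sum_k n_k \,\mathrm{sign}(\delta\, a_k)\, a_k = \Delta f .

% Fused model with the scale folded into the weights, v_k = n_k \tilde w^{(2)}_k:
\frac{\partial L}{\partial v_k} = \delta\, a_k
\;\Rightarrow\;
\Delta \tilde f_{\text{folded}} = -\eta \sum_k \mathrm{sign}(\delta\, a_k)\, a_k .
```

In this toy setting, keeping the scale separate reproduces the original update exactly, while folding it into the weights shrinks each cluster's contribution by a factor of n_k. That is one way to read the phrase "mimic per-cluster learning rates".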

System Optimization

  • SFT
  • lower wall-clock fine-tuning time in experiments

Training Optimization

  • preserves Adam NTK (training dynamics) under stated assumptions
  • short layer-wise task-specific tuning (1 epoch) to regain task knowledge

Inference Optimization

  • reduces model parameter count and MLP GFLOPs to a degree similar to SVD/structured pruning
  • no reliance on sparse matrices (keeps hardware-friendly dense ops)

Reproducibility

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Performance drops if the fused intermediate dimension is too small (experiments show rank-deficiency below ~768).
  • NTK preservation relies on assumptions (clustering approximates weight distribution and Adam-like sign behavior); deviations can reduce effectiveness.
  • All experiments run on a single V100 32GB GPU; larger-scale behavior and multi-GPU speedups are not demonstrated.

When Not To Use

  • When you need zero accuracy loss: fusion causes small but nonzero accuracy drops.
  • When you must compress MLPs far below 25% retained parameters: extreme compression degrades training and accuracy.

Failure Modes

  • Clustering fails to capture crucial sub-MLP diversity, causing NTK mismatch and worse fine-tuning.
  • Too-small fused dimensions create rank-deficient layers with higher training loss.
  • For MoE, incorrect handling of router weights (not freezing when needed) can hurt performance.

Core Entities

Models

  • RoBERTa
  • GPT-2
  • Switch Transformer
  • DistilRoBERTa
  • NTK-SAP
  • LoRA
  • SVD

Metrics

  • Accuracy
  • BLEU
  • METEOR
  • TER
  • NTK approximation error
  • output approximation error
  • memory (GB)
  • wall-clock runtime (s)

Datasets

  • SST2
  • MNLI
  • MRPC
  • CoLA
  • WebNLG
  • STS-B
  • QNLI

Benchmarks

  • GLUE
  • WebNLG

Context Entities

Models

  • Mixtral (cited as large MoE example)