Split FFNs into sparse experts + a teacher-guided router to cut FLOPs and adapt LLMs with tiny data

August 15, 2024 | 7 min

Overview

Production Readiness

0.5

Novelty Score

0.6

Cost Impact Score

0.6

Citation Count

1

Authors

Zhongyu Zhao, Menghang Dong, Rongyu Zhang, Wenzhao Zheng, Yunpeng Zhang, Huanrui Yang, Dalong Du, Kurt Keutzer, Shanghang Zhang

Links

Abstract / PDF

Why It Matters For Business

FactorLLM cuts FFN compute, and with it inference cost, while adapting to a new domain from a very small dataset. That makes task-specific LLMs cheaper and faster to deploy.

Summary TLDR

FactorLLM splits a pretrained transformer feed-forward network (FFN) into equal-size sparse subnetworks that are treated as Mixture-of-Experts (MoE) experts. A small injected router is trained with a teacher-student Prior-Approximate Router (PAR) loss so that only a few experts activate per token. On TinyLlama and MobileLlama, FactorLLM cuts FFN FLOPs by up to ~75% (configuration 1R4E1K, read here as 1 router, 4 experts, top-1 activation), lowers total compute by ~30–50% in some settings, and retains around 85% of the original accuracy after fine-tuning on 0.03–0.04% of the original training data. Code is available.

Problem Statement

Monolithic FFNs in transformers hold redundant, mixed knowledge and waste compute. We need a low-overhead way to split that knowledge so only task-relevant parts run at inference and the model can adapt with very little data.

Main Contribution

A simple factorization that permutes and partitions pretrained FFN weights into N equal subnetworks (experts) without changing weight values.

Prior-Approximate Router (PAR): a teacher-student routing loss that creates pseudo-labels from the original FFN to train a small injected router quickly.

Empirical demonstration on TinyLlama and MobileLlama: large FFN FLOPs cuts and fast adaptation using 0.03–0.04% of original training data.

Key Findings

Large FFN FLOPs can be cut heavily by activating fewer experts.

Numbers: FFN GFLOPs reduced ~75% for 1R4E1K

Tradeoff: big FLOPs savings with modest accuracy loss.

Numbers: Total compute reduced ~50% and accuracy retained >85% (1R4E2K)

Router training via the PAR teacher-student loss improves expert specialization and final accuracy.

Numbers: Router-enabled model: 85.6% vs. random expert selection: 73.9% (a gap of ~11.7 percentage points)

FactorLLM adapts using tiny amounts of data compared to original pretraining.

Numbers: Maintains >85% performance using 0.03–0.04% of training data (≈30M–50M tokens vs 3T)

Results

FFN GFLOPs reduction

Value: ~75% (FFN) for 1R4E1K

Baseline: dense FFN (K=N)
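
The ~75% figure follows from simple arithmetic if the experts are equal in size and router cost is ignored: running K of N experts leaves K/N of the FFN work. The sketch below illustrates that arithmetic only. The function names and the `ffn_share` parameter (the FFN's fraction of dense compute) are illustrative assumptions, not values from the paper.

```python
def ffn_reduction(k, n):
    """Fraction of FFN FLOPs saved when k of n equal-size experts run.

    Router overhead is ignored (an assumption of this sketch).
    """
    return 1.0 - k / n


def total_reduction(k, n, ffn_share):
    """Fraction of total FLOPs saved, given the FFN's share of dense compute.

    ffn_share is a hypothetical input; attention cost is left unchanged.
    """
    return ffn_share * ffn_reduction(k, n)


print(ffn_reduction(1, 4))  # 0.75, the 1R4E1K case
print(ffn_reduction(4, 4))  # 0.0, the dense baseline (K=N)

# Hypothetical FFN share of 60%, chosen only to show the dependence on attention.
example_total = total_reduction(1, 4, ffn_share=0.6)
```

Because attention is untouched, the overall saving is always smaller than the FFN saving and shrinks as the attention share grows.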

Total compute reduction

Value: ~50% (1R4E2K)

Baseline: original model compute

Accuracy

Value: ≈85% of original accuracy

Baseline: source model upper bound

Data efficiency for fine-tuning

Value: 0.03–0.04% of original training data (≈30M–50M tokens)

Baseline: original pretraining (≈3T tokens)

Who Should Care

What To Try In 7 Days

Take a small LLM (TinyLlama/MobileLlama), permute and split FFN into 4 experts and implement a TopK router.

Train the injected router via PAR using a frozen teacher FFN and a small in-domain dataset (tens of millions of tokens).

Benchmark 1R4E2K and 1R4E1K: measure GFLOPs, latency, and accuracy to pick a trade-off for production.

Optimization Features

Token Efficiency

  • converges with ~30M–50M tokens vs pretraining scale

Model Optimization

  • FFN factorization into equal-size experts
  • sparse expert activation (MoE-style)
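
A minimal sketch of why the factorization can leave weight values untouched: the hidden activation acts on each hidden unit independently, so any permutation and equal split of the hidden units yields experts whose outputs sum to the dense output. The sketch uses a plain two-layer ReLU FFN in pure Python for brevity (an assumption; Llama-style models use a gated FFN, where the same per-unit argument applies). All names here are illustrative.

```python
import random


def relu(v):
    return v if v > 0.0 else 0.0


def ffn(x, w_in, w_out):
    """Two-layer FFN. w_in and w_out each hold one row per hidden unit."""
    hidden = [relu(sum(w * xi for w, xi in zip(row, x))) for row in w_in]
    d_model = len(w_out[0])
    return [sum(h * w_out[j][d] for j, h in enumerate(hidden))
            for d in range(d_model)]


def factorize(w_in, w_out, n_experts, perm):
    """Permute hidden units, then cut them into n_experts equal slices.

    Weight values are copied as-is; only their grouping changes.
    """
    hidden = len(w_in)
    assert hidden % n_experts == 0, "hidden size must divide evenly"
    size = hidden // n_experts
    experts = []
    for e in range(n_experts):
        idx = perm[e * size:(e + 1) * size]
        experts.append(([w_in[j] for j in idx], [w_out[j] for j in idx]))
    return experts


random.seed(0)
d_model, hidden, n_experts = 4, 8, 4
w_in = [[random.uniform(-1, 1) for _ in range(d_model)] for _ in range(hidden)]
w_out = [[random.uniform(-1, 1) for _ in range(d_model)] for _ in range(hidden)]
x = [random.uniform(-1, 1) for _ in range(d_model)]

perm = list(range(hidden))
random.shuffle(perm)  # any permutation works; the paper's choice is not modeled here
experts = factorize(w_in, w_out, n_experts, perm)

dense = ffn(x, w_in, w_out)
expert_outs = [ffn(x, wi, wo) for wi, wo in experts]
summed = [sum(vals) for vals in zip(*expert_outs)]
max_gap = max(abs(a - b) for a, b in zip(dense, summed))
```

With all experts active, `summed` matches `dense` up to floating-point rounding. The approximation error only appears once the router activates a subset of experts.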

Training Optimization

  • Prior-Approximate Router (PAR) teacher-student loss
  • freeze teacher FFNs and fine-tune only routers+experts
  • few-step fine-tuning on small datasets (0.03–0.04% data)
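
The summary does not give the exact PAR formula, so the following is one plausible reading rather than the authors' implementation. It assumes each expert is scored by how close its output is to the frozen teacher FFN's output, the K closest experts become a multi-hot pseudo-label, and the router's logits are trained against that label with binary cross-entropy. The function names and the choice of squared error and BCE are assumptions.

```python
import math


def squared_error(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def pseudo_labels(teacher_out, expert_outs, k):
    """Multi-hot label: 1.0 for the k experts closest to the teacher output."""
    errors = [squared_error(out, teacher_out) for out in expert_outs]
    closest = sorted(range(len(errors)), key=lambda i: errors[i])[:k]
    return [1.0 if i in closest else 0.0 for i in range(len(errors))]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def router_bce(router_logits, labels):
    """Mean binary cross-entropy between router scores and pseudo-labels."""
    eps = 1e-12  # guards against log(0)
    total = 0.0
    for z, y in zip(router_logits, labels):
        p = sigmoid(z)
        total -= y * math.log(p + eps) + (1.0 - y) * math.log(1.0 - p + eps)
    return total / len(labels)


# Toy values for one token: a teacher output and four expert outputs.
teacher_out = [1.0, 0.0]
expert_outs = [[0.9, 0.1], [0.0, 1.0], [0.5, 0.5], [1.0, -0.1]]

labels = pseudo_labels(teacher_out, expert_outs, k=2)
agreeing_loss = router_bce([3.0, -3.0, -3.0, 3.0], labels)
disagreeing_loss = router_bce([-3.0, 3.0, 3.0, -3.0], labels)
```

In this toy case experts 0 and 3 are closest to the teacher, so a router that scores them highly gets the lower loss. In a real setup the teacher output would come from the frozen original FFN, and the loss gradient would update the router (and, per the list above, the experts).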

Inference Optimization

  • TopK router to activate K experts per token
  • reduced FFN compute via sparse activation
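
A minimal sketch of TopK routing for a single token, assuming the common MoE convention of softmax scores renormalized over the selected experts. Whether FactorLLM weights expert outputs this way is not stated in this summary, and the toy experts below are placeholders for the FFN slices.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def top_k_route(router_logits, k):
    """Return (expert indices, weights) for the k highest-scoring experts."""
    probs = softmax(router_logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in chosen)
    return chosen, [probs[i] / norm for i in chosen]


def sparse_ffn(x, experts, router_logits, k):
    """Run only the selected experts and mix their outputs by router weight."""
    chosen, weights = top_k_route(router_logits, k)
    out = [0.0] * len(x)
    for idx, w in zip(chosen, weights):
        y = experts[idx](x)  # unselected experts are never evaluated
        out = [o + w * yi for o, yi in zip(out, y)]
    return out, chosen


calls = []  # records which experts actually ran


def make_expert(i):
    def expert(x):
        calls.append(i)
        return [(i + 1) * xi for xi in x]  # placeholder for an FFN slice
    return expert


experts = [make_expert(i) for i in range(4)]
out, chosen = sparse_ffn([1.0, 2.0], experts, [0.1, 2.0, -1.0, 1.5], k=2)
```

Only the two selected experts appear in `calls`, which is where the FFN compute saving comes from. The router itself adds a small dense cost per token.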

Reproducibility

Code Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Experiments limited to small LLMs (TinyLlama, MobileLlama); results may differ on larger models.
  • When FFN compute is reduced, attention becomes the new bottleneck and accuracy can drop (Section 4.3).
  • Router conflicts and degraded performance can occur when adding many routers or experts (Section 4.5).
  • Paper does not release the exact training subset; reproduction may need dataset engineering.

When Not To Use

  • When you require full original accuracy for every task (FactorLLM incurs up to ~15% relative drop in some configs).
  • If attention layers dominate your compute and you cannot change them, FFN factorization yields limited overall speedup.
  • If you cannot invest even small fine-tuning steps to train the router and experts.

Failure Modes

  • Experts collapse into similar modules if no router training is used, reducing diversity and benefit (Ex0 vs Ex3).
  • Router allocation instability early in training; needs PAR pseudo-labels to stabilize (Section 4.4).
  • Scaling to many routers/experts can cause routing conflicts and worse accuracy (Section 4.5).

Core Entities

Models

  • TinyLlama
  • MobileLlama
  • FactorLLM (1R4E2K, 1R4E1K, 1R4E3K variants)

Metrics

  • Accuracy
  • GFLOPs (attention, FFN)
  • Relative maintenance (%)
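
Relative maintenance is not defined in this summary. The sketch below assumes it is the factorized model's accuracy expressed as a percentage of the source model's accuracy on the same benchmark, which matches the "≈85% of original accuracy" phrasing. The accuracy values are hypothetical placeholders, not results from the paper.

```python
def relative_maintenance(factored_acc, source_acc):
    """Factorized accuracy as a percentage of the source model's accuracy.

    Assumed definition; the source model acts as the upper bound.
    """
    if source_acc <= 0.0:
        raise ValueError("source accuracy must be positive")
    return 100.0 * factored_acc / source_acc


# Hypothetical accuracies on one benchmark, for illustration only.
maintained = relative_maintenance(factored_acc=0.43, source_acc=0.50)
```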

Datasets

  • Pajama (subset used for training)

Benchmarks

  • HellaSwag
  • OpenBookQA
  • Winogrande
  • ARC-Easy
  • ARC-Challenge
  • BoolQ
  • PIQA
  • MMLU