Swap transformer blocks for small ordered 'learners' and run only as many as each token needs, cutting inference cost with minimal accuracy loss.

December 15, 2023 · 8 min

Overview

Decision Snapshot: Needs Validation

The method converts pretrained transformers and comes with an optimized GPU implementation. It is ready for prototyping but needs per-task tuning and validation under distribution shift.

Citations: 2

Evidence Strength: 0.70

Confidence: 0.78

Risk Signals: 10

Trust Signals

Findings with numeric evidence: 6/6

Findings with evidence refs: 6/6

Results with explicit delta: 4/4

Reproducibility

Status: Code + data available

Open source: Partial

At A Glance

Cost impact: 70%

Production readiness: 60%

Novelty: 60%

Authors

Bartosz Wójcik, Alessio Devoto, Karol Pustelnik, Pasquale Minervini, Simone Scardapane


Why It Matters For Business

Adaptive Computation Modules (ACMs) reduce inference compute and GPU latency while largely preserving model accuracy, enabling cheaper, faster deployment of pretrained transformers in latency- or energy-constrained settings.


Summary (TL;DR)

The paper introduces Adaptive Computation Modules (ACMs): replace selected transformer blocks with a small ordered set of lightweight submodules (“learners”) and a per-token gate that selects how many learners to run. ACMs let the model spend less compute on easy tokens and more on hard ones. Converted ViT and Wav2Vec models show lower FLOPs and wall-clock latency for a range of user budgets while keeping accuracy roughly the same. The authors provide a three-phase conversion (distill learners, pretrain gates, end-to-end finetune) and an optimized Triton GPU implementation.
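The mechanism described above can be illustrated with a minimal NumPy sketch. It assumes that each learner is a tiny two-layer ReLU MLP, that an ACM's output is the sum of the outputs of the first k learners (each one refining the previous ones), and that the gate is a linear scorer that picks k in 1..N by argmax. `ACMSketch` and all parameter names are hypothetical; this is not the authors' code, which trains the gates and runs on Triton kernels.

```python
import numpy as np


class ACMSketch:
    """Hypothetical Adaptive Computation Module: N ordered learners plus a per-token gate."""

    def __init__(self, d_model, d_hidden, n_learners, seed=0):
        rng = np.random.default_rng(seed)
        self.n = n_learners
        # Each learner is assumed to be a tiny MLP: d_model -> d_hidden -> d_model.
        self.w1 = rng.normal(0.0, 0.02, (n_learners, d_model, d_hidden))
        self.w2 = rng.normal(0.0, 0.02, (n_learners, d_hidden, d_model))
        # Gate: a linear scorer with one logit per possible count k = 1..N.
        self.wg = rng.normal(0.0, 0.02, (d_model, n_learners))

    def choose_k(self, x):
        """x: (tokens, d_model) -> learner count k in {1, ..., N} for each token."""
        return np.argmax(x @ self.wg, axis=1) + 1

    def learner(self, j, x):
        """Output of learner j for a batch of tokens x: (tokens, d_model)."""
        return np.maximum(x @ self.w1[j], 0.0) @ self.w2[j]

    def forward(self, x, k=None):
        """Sum the outputs of the first k[t] learners for each token t."""
        k = self.choose_k(x) if k is None else np.asarray(k)
        out = np.zeros_like(x)
        for t in range(x.shape[0]):
            # Tokens with a small k touch fewer learner weights, hence fewer FLOPs.
            for j in range(int(k[t])):
                out[t] += self.learner(j, x[t:t + 1])[0]
        return out, k


# Example: 5 tokens of width 16, 4 learners.
x = np.random.default_rng(1).normal(size=(5, 16))
acm = ACMSketch(d_model=16, d_hidden=8, n_learners=4)
y, k = acm.forward(x)
print(y.shape, k)
```

The per-token loop is for clarity only. It shows where the savings come from, but a fast implementation would batch tokens rather than loop over them.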

Problem Statement

Standard transformer layers apply their full width (all parameters) to every input token. Many tokens do not need that much capacity, so compute is wasted on them. The paper asks whether layer width can be adapted per token to cut inference cost while preserving accuracy.
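A back-of-the-envelope cost model makes the motivation concrete. Assume, for illustration only (this is not a formula stated in the paper), that a converted layer has $N$ equal-sized learners costing $c_\ell$ FLOPs each, that together they match the original layer's cost, that the gate costs $c_g$ per token, and that token $t$ runs $k_t \in \{1,\dots,N\}$ learners:

```latex
\begin{aligned}
C_{\text{dense}} &= N\,c_\ell, \qquad C_{\text{ACM}}(t) = c_g + k_t\,c_\ell,\\
\frac{\mathbb{E}_t[C_{\text{ACM}}(t)]}{C_{\text{dense}}} &= \frac{c_g}{N\,c_\ell} + \frac{\mathbb{E}_t[k_t]}{N}.
\end{aligned}
```

Under these assumptions the converted layer is cheaper than the dense one only when $\mathbb{E}_t[k_t] < N - c_g/c_\ell$, that is, when most tokens use fewer than all learners and the gate is cheap relative to a single learner.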

Main Contribution

Adaptive Computation Module (ACM): an ordered set of learners plus a small gating net that selects how many learners to run per token.

A conversion recipe to turn pretrained transformers into ACMized variants: module-wise distillation, gate pretraining with artificial labels, then end-to-end finetuning.
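The first two conversion phases can be sketched as losses on precomputed outputs. This sketch assumes (it is not taken from the paper) that phase 1 trains every prefix of learners to match the original block's output under mean squared error, and that phase 2 labels each token with the smallest k whose prefix error falls below a threshold `tau`. All function names and the threshold are hypothetical.

```python
import numpy as np


def prefix_outputs(learner_outputs):
    """learner_outputs: (N, tokens, d). Index k-1 holds the sum of the first k learners."""
    return np.cumsum(learner_outputs, axis=0)


def distillation_loss(learner_outputs, teacher_out):
    """Phase 1 (assumed form): every learner prefix should reproduce the original block."""
    prefixes = prefix_outputs(learner_outputs)                       # (N, tokens, d)
    per_prefix = ((prefixes - teacher_out[None]) ** 2).mean(axis=(1, 2))
    return float(per_prefix.mean())


def artificial_gate_labels(learner_outputs, teacher_out, tau):
    """Phase 2 (assumed form): label = smallest k whose per-token error is <= tau."""
    prefixes = prefix_outputs(learner_outputs)
    err = ((prefixes - teacher_out[None]) ** 2).mean(axis=2)         # (N, tokens)
    ok = err <= tau
    n = learner_outputs.shape[0]
    # argmax finds the first True along the learner axis; fall back to N if none.
    return np.where(ok.any(axis=0), ok.argmax(axis=0) + 1, n)
```

Phase 3, end-to-end finetuning of learners and gates on the task loss, is not shown.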

Key Findings

ACMized ViT-B lies on the Pareto frontier of FLOPs vs. accuracy on ImageNet-1k.

Numbers: Advantage especially pronounced below 12.5 GFLOPs (Fig. 3)

Practical Use: Use ACMs to reduce FLOPs under tight budget targets while keeping accuracy close to the original ViT-B.

Evidence Ref: Figure 3
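Whether a model lies on the Pareto frontier can be checked mechanically from its (FLOPs, accuracy) operating points. The helper below is a generic sketch, not code from the paper, and `pareto_frontier` is a hypothetical name.

```python
def pareto_frontier(points):
    """Return the (flops, accuracy) points not dominated by any other point.

    A point is dominated if another point has <= FLOPs and >= accuracy,
    with at least one of the two strictly better.
    """
    frontier = []
    best_acc = float("-inf")
    # Sort by FLOPs ascending; among equal FLOPs, try the most accurate point first.
    for flops, acc in sorted(points, key=lambda p: (p[0], -p[1])):
        if acc > best_acc:  # strictly more accurate than every cheaper point
            frontier.append((flops, acc))
            best_acc = acc
    return frontier
```

With made-up points, `pareto_frontier([(10, 0.70), (11, 0.69), (12, 0.75)])` returns `[(10, 0.70), (12, 0.75)]`: the 11-GFLOP point is dominated because a cheaper point is more accurate.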

On CommonVoice-es speech recognition, ACMized Wav2Vec models achieve a lower word error rate (WER) than MoEfication at every tested compute budget.

Numbers: Lower WER across all evaluated budgets (Fig. 4)

Practical Use: ACMs work beyond vision; applied to speech recognition models, they cut compute while achieving lower WER than MoEfication at matched budgets.

Evidence Ref: Figure 4
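WER, the metric behind the speech result, is the word-level edit distance between reference and hypothesis divided by the number of reference words. A standard dynamic-programming sketch follows. It is not the authors' evaluation script, and real pipelines usually normalize casing and punctuation first.

```python
def word_error_rate(reference, hypothesis):
    """WER = (substitutions + deletions + insertions) / number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] = edit distance between the reference words seen so far and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of reference word r
                curr[j - 1] + 1,         # insertion of hypothesis word h
                prev[j - 1] + (r != h),  # substitution, or free match
            )
        prev = curr
    return prev[-1] / max(len(ref), 1)
```

For example, `word_error_rate("el gato come", "el perro come")` is 1/3: one substitution over three reference words.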

Results

Metric | Value | Baseline | Delta | Split / Dataset | Evidence | Evidence Ref
Accuracy | ACMized ViT-B on Pareto frontier; favorable across budgets, especially below 12.5 GFLOPs | Original ViT-B and other conditional compute methods (A-ViT, MoEfication, Zero Time Waste) | Better accuracy-vs-FLOPs trade-off under low-FLOPs targets (see Fig. 3) | ImageNet-1k validation | Figure 3 | Figure 3
Speech recognition WER vs compute | ACMized Wav2Vec achieves lower WER at every tested computational budget | MoEfication | Lower WER across budgets (see Fig. 4) | CommonVoice (es) validation | Figure 4 | Figure 4

What To Try In 7 Days

Convert one MLP block of a small ViT into an ACM with N=4 learners and run the authors' three-phase conversion (learner distillation, gate pretraining, end-to-end finetuning) for a few epochs.

Measure average FLOPs and wall-clock latency on your A100 (or target GPU) and compare to baseline.

If you use Triton, implement the gated forward pass and measure the latency gain from sorting tokens by learner count on batched inputs.
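For the latency comparison, reporting the median over repeated runs after a warm-up is usually more stable than a single measurement. This is a generic sketch, and `median_latency_ms` is a hypothetical helper. For GPU code, the timed function must block until its kernels finish (in PyTorch, `torch.cuda.synchronize()`), because kernel launches are asynchronous and would otherwise be under-counted.

```python
import statistics
import time


def median_latency_ms(fn, *args, warmup=3, iters=20):
    """Median wall-clock latency of fn(*args) in milliseconds."""
    for _ in range(warmup):
        fn(*args)  # warm caches, JIT compilation, and kernel autotuning before timing
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)  # on a GPU, fn must synchronize before returning
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)
```

Time the baseline and the converted model with the same batch shapes so the comparison reflects the gating, not input size.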

Optimization Features

Token Efficiency
Per-token variable compute allocation
Infra Optimization
Optimized A100 GPU execution patterns (batching, grouping by learner count)
Model Optimization
Conditional computation (adaptive width per token)
Module-wise knowledge distillation
System Optimization
Triton GPU kernels with autotuning and kernel fusion
Training Optimization
Three-stage conversion: distill learners, pretrain gates, end-to-end finetune
Auxiliary losses to enforce budget and diversity
Inference Optimization
Per-token gating to skip learners
Sorting tokens by chosen learner count to batch matmuls
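The trick of sorting tokens by learner count can be shown on the CPU. If tokens are sorted by k in descending order, the tokens that need learner j form a contiguous prefix of the sorted batch, so each learner runs one batched call on a slice instead of per-token work. This NumPy sketch shows only the indexing. `grouped_acm_forward` is hypothetical and says nothing about how the authors' Triton kernels fuse or autotune these steps.

```python
import numpy as np


def grouped_acm_forward(x, k, learners):
    """Run N learners over tokens grouped by their learner count.

    x: (tokens, d); k: (tokens,) learner counts in 1..N; learners: list of N callables
    mapping (m, d) -> (m, d). Learner j is needed by every token with k > j.
    """
    k = np.asarray(k)
    order = np.argsort(-k, kind="stable")  # tokens needing the most learners first
    xs, ks = x[order], k[order]
    out_sorted = np.zeros_like(x)
    for j, learner in enumerate(learners):
        n_active = int(np.count_nonzero(ks > j))  # a contiguous prefix after sorting
        if n_active == 0:
            break  # later learners are needed by even fewer tokens
        out_sorted[:n_active] += learner(xs[:n_active])
    out = np.empty_like(out_sorted)
    out[order] = out_sorted  # undo the sort so outputs align with the input tokens
    return out
```

The result should match a naive per-token loop that sums the first k learners for each token; only the amount of batching differs.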

Reproducibility

Code Available: Yes
Data Available: Yes
Open Source Status: Partial
License: Unknown

Data URLs

ImageNet-1k (public), CommonVoice (public), ImageNet-C (public)

Risks & Boundaries

Limitations

Requires the three-phase conversion and finetuning; end-to-end finetuning from randomly initialized learners converges more slowly than starting from the distilled ones.

Performance and routing are sensitive to distribution shift: on ImageNet-C, gate decisions change and accuracy drops.

When Not To Use

When you cannot retrain or finetune the model at all.

Small models where gating overhead outweighs savings.
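Whether gating pays off can be estimated with simple FLOPs bookkeeping. The sketch assumes, hypothetically, that the dense layer costs N times one learner and that every token pays the gate's cost regardless of its k. Under those assumptions the ACM is cheaper only when the mean k stays below N minus gate_flops / learner_flops. `acm_cost_ratio` is a hypothetical helper.

```python
def acm_cost_ratio(k_counts, n_learners, learner_flops, gate_flops):
    """Average ACM cost per token divided by the dense layer's cost.

    Assumes the dense layer costs n_learners * learner_flops (the learners
    partition its width) and that every token pays gate_flops.
    A ratio below 1.0 means the ACM saves compute on average.
    """
    mean_k = sum(k_counts) / len(k_counts)
    return (gate_flops + mean_k * learner_flops) / (n_learners * learner_flops)
```

With made-up costs, `acm_cost_ratio([1, 1, 2, 4], 4, 100, 20)` gives 0.55 (mean k = 2, so (20 + 200) / 400), about 45% fewer FLOPs than the dense layer. With a gate costing 250 FLOPs and every token at k = 4, the ratio exceeds 1 and gating is a net loss, which is the small-model failure case above.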

Failure Modes

Gate collapse: the gates choose the same number of learners for all tokens, nullifying adaptivity.

Severe domain shift causes gates to select extreme budgets and degrade accuracy.
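Both failure modes can be watched for by logging the learner counts the gates pick. The sketch below is generic, not from the paper, and its function names are hypothetical. It computes the histogram of k, its normalized entropy (values near 0 suggest gate collapse), and the total-variation distance between histograms from two data sources, such as clean and corrupted inputs. Alert thresholds are deployment-specific and are not given here.

```python
import math
from collections import Counter


def k_histogram(k_counts, n_learners):
    """Fraction of tokens assigned each learner count 1..N (k_counts must be non-empty)."""
    c = Counter(k_counts)
    total = len(k_counts)
    return [c.get(k, 0) / total for k in range(1, n_learners + 1)]


def normalized_entropy(hist):
    """0 when every token gets the same k (collapse); 1 for uniform use of all k. Needs N >= 2."""
    h = -sum(p * math.log(p) for p in hist if p > 0)
    return h / math.log(len(hist))


def total_variation(hist_a, hist_b):
    """Distance in [0, 1] between two k-histograms, e.g. clean vs. shifted inputs."""
    return 0.5 * sum(abs(a - b) for a, b in zip(hist_a, hist_b))
```

A collapsed gate that sends every token to k = 3 of 4 gives entropy 0; a large total-variation distance between clean and shifted traffic flags routing drift before accuracy is measured.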

Core Entities

Models

Vision Transformer (ViT-B), Wav2Vec2 / XLS-R-300M (Wav2Vec family), ACMized ViT-B, ACMized Wav2Vec

Metrics

Accuracy, Word Error Rate (WER), Average FLOPs per sample, Wall-clock latency (A100 GPU)

Datasets

ImageNet-1k, CommonVoice (selected languages), ImageNet-C (corruptions)

Benchmarks

Accuracy, CommonVoice WER, ImageNet-C robustness

Context Entities

Models

A-ViT (token dropping), MoE, Zero Time Waste (early exiting)