Swap transformer blocks for small ordered 'learners' and run only as many as each token needs to cut inference cost with minimal accuracy, n

December 15, 2023 · 8 min

Overview

Production Readiness

0.6

Novelty Score

0.6

Cost Impact Score

0.7

Citation Count

2

Authors

Bartosz Wójcik, Alessio Devoto, Karol Pustelnik, Pasquale Minervini, Simone Scardapane

Why It Matters For Business

ACMs reduce inference compute and GPU latency while keeping model accuracy roughly unchanged, enabling cheaper, faster deployment of pretrained transformers in latency- or energy-constrained settings.

Summary TLDR

The paper introduces Adaptive Computation Modules (ACMs): replace selected transformer blocks with a small ordered set of lightweight submodules (“learners”) and a per-token gate that selects how many learners to run. ACMs let the model spend less compute on easy tokens and more on hard ones. Converted ViT and Wav2Vec models show lower FLOPs and wall-clock latency for a range of user budgets while keeping accuracy roughly the same. The authors provide a three-phase conversion (distill learners, pretrain gates, end-to-end finetune) and an optimized Triton GPU implementation.
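
To make the mechanism concrete, here is a minimal sketch of one ACM forward pass for a single token. It assumes the module output is the sum of the outputs of the first k learners and that the gate returns an integer k between 1 and N. The names `acm_forward`, `make_scaling_learner`, and `norm_gate` are hypothetical, and the learners and gate are toy stand-ins rather than the small networks used in the paper.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Learner = Callable[[Sequence[float]], Vector]


def acm_forward(
    token: Sequence[float],
    learners: Sequence[Learner],
    gate: Callable[[Sequence[float]], int],
) -> Tuple[Vector, int]:
    """Run only the first k learners for one token and sum their outputs."""
    # Assumption: the gate returns an integer k with 1 <= k <= len(learners).
    k = gate(token)
    if not 1 <= k <= len(learners):
        raise ValueError(f"gate returned {k}, expected 1..{len(learners)}")
    output = [0.0] * len(token)
    # Learners are ordered, so a token always uses a prefix of the list.
    for learner in learners[:k]:
        contribution = learner(token)
        output = [o + c for o, c in zip(output, contribution)]
    return output, k


def make_scaling_learner(scale: float) -> Learner:
    """Hypothetical stand-in for a small MLP: scales the token elementwise."""
    return lambda token: [scale * x for x in token]


def norm_gate(token: Sequence[float]) -> int:
    """Hypothetical gate: tokens with larger magnitude get more learners."""
    magnitude = sum(abs(x) for x in token)
    return 1 if magnitude < 1.0 else 2 if magnitude < 3.0 else 4


learners = [make_scaling_learner(s) for s in (0.5, 0.25, 0.125, 0.125)]
easy, k_easy = acm_forward([0.2, 0.2], learners, norm_gate)
hard, k_hard = acm_forward([2.0, 2.0], learners, norm_gate)
print(k_easy, easy)
print(k_hard, hard)
```

In this toy setup the low-magnitude token runs a single learner while the high-magnitude token runs all four, which is the per-token compute allocation the summary describes.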

Problem Statement

Transformer layers apply their full width (all parameters) to every input token. Many tokens do not need the full layer capacity, so the model wastes compute on them. The paper asks: can we adapt width per token to cut inference cost while keeping accuracy?

Main Contribution

Adaptive Computation Module (ACM): an ordered set of learners plus a small gating net that selects how many learners to run per token.

A conversion recipe to turn pretrained transformers into ACMized variants: module-wise distillation, gate pretraining with artificial labels, then end-to-end finetuning.

Efficient GPU implementation (Triton kernels) showing near-linear latency scaling with executed learners and negligible overhead.

Empirical results on ImageNet-1k (ViT-B) and CommonVoice speech models showing improved compute-accuracy trade-offs versus token-dropping, MoEfication, and early-exit baselines.
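
This summary does not say how the artificial labels for gate pretraining are built. One plausible rule, shown here purely as an assumption, is to label each token with the smallest number of learners whose summed output is already close to the output of all N learners. The function name `artificial_gate_label`, the relative L2 error criterion, and the `tolerance` value are all hypothetical.

```python
import math
from typing import Sequence


def artificial_gate_label(
    learner_outputs: Sequence[Sequence[float]], tolerance: float = 0.1
) -> int:
    """Return the smallest k whose partial sum is close to the full sum.

    learner_outputs[i] is learner i's output vector for one token.
    Closeness rule (an assumption): relative L2 error <= tolerance.
    """
    dim = len(learner_outputs[0])
    full = [sum(out[d] for out in learner_outputs) for d in range(dim)]
    # Guard against a zero-norm target so the division below is safe.
    full_norm = math.sqrt(sum(v * v for v in full)) or 1.0
    partial = [0.0] * dim
    for k, out in enumerate(learner_outputs, start=1):
        partial = [p + o for p, o in zip(partial, out)]
        error = math.sqrt(sum((f - p) ** 2 for f, p in zip(full, partial)))
        if error / full_norm <= tolerance:
            return k
    # Fallback: if no prefix met the tolerance, use every learner.
    return len(learner_outputs)


# Toy per-learner outputs for two tokens (made-up numbers).
easy_token = [[1.0, 0.0], [0.05, 0.0], [0.0, 0.0], [0.0, 0.0]]
hard_token = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
print(artificial_gate_label(easy_token))
print(artificial_gate_label(hard_token))
```

Labels of this kind would give the gate a supervised target before end-to-end finetuning; a looser tolerance pushes labels toward fewer learners.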

Key Findings

ACMized ViT-B lies on the Pareto frontier of FLOPs vs. accuracy on ImageNet-1k.

Numbers: Advantage especially below 12.5 GFLOPs (Fig. 3)

On CommonVoice-es speech recognition, ACMized Wav2Vec models achieve lower word error rate at every tested compute budget.

Numbers: Lower WER across evaluated budgets (Fig. 4)

Latency scales roughly linearly with the average number of executed learners, and ACM overhead is negligible on A100 GPUs.

Numbers: Single ACM-layer latency linear in avg learners; measurements on A100 (batch 128/256) (Fig. 9, Fig. 3)

Two-stage initialization (module-wise distillation + gate pretraining) speeds convergence versus end-to-end finetuning from random init.

Numbers: Pretraining stages reduce training time in experiments (Fig. 1)

ACMs are sensitive to distribution shift: corruptions change gate behavior and accuracy drops.

Numbers: Gating choices become more extreme and compute/accuracy change with ImageNet-C severity (Figs. 4-5)

The number of learners N trades granularity against performance; N=4 gave the best trade-off in the authors' ViT experiments.

Numbers: After Phase I, N=4 at full activation reached 71.98%, compared with the other N values in Table 3

Results

Accuracy

Value: ACMized ViT-B on Pareto frontier; favorable across budgets, especially <12.5 GFLOPs

Baseline: Original ViT-B and other conditional compute methods (A-ViT, MoEfication, Zero Time Waste)

Speech recognition WER vs compute

Value: ACMized Wav2Vec achieves lower WER at every tested computational budget

Baseline: MoEfication

Wall-clock latency vs average learners

Value: Latency scales roughly linearly with average number of executed learners; overhead near-negligible

Baseline: Static MLP layer / full module

Accuracy after Phase I

Value: After module-wise distillation, N=4 kN=1.0 -> 71.98% (ImageNet after Phase I)

Baseline: Other N choices (2, 8, 16) in the same experiment

Who Should Care

What To Try In 7 Days

Convert one MLP block of a small ViT to an ACM (N=4) and run the authors' three-phase conversion (distill learners, pretrain gates, end-to-end finetune) for a few epochs.

Measure average FLOPs and wall-clock latency on your A100 (or target GPU) and compare to baseline.

If you use Triton, implement the gated forward pass and measure how much sorting tokens by learner count improves latency on batched inputs.
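
For the measurement step, a small harness like the following may be enough to get started. It is a sketch under stated assumptions: `average_acm_flops` models the cost of one ACM as a fixed gate cost plus a per-learner cost times the average number of executed learners (a simplification that matches the roughly linear scaling reported above), and `median_latency_ms` times any zero-argument callable. Both function names and the FLOP numbers in the example are hypothetical.

```python
import statistics
import time
from typing import Callable, Sequence


def average_acm_flops(
    learner_counts: Sequence[int], flops_per_learner: float, gate_flops: float
) -> float:
    """Average per-token FLOPs of one ACM: gate cost plus executed learners."""
    return gate_flops + flops_per_learner * statistics.mean(learner_counts)


def median_latency_ms(
    fn: Callable[[], object], warmup: int = 3, repeats: int = 10
) -> float:
    """Median wall-clock time of fn() in milliseconds after a few warmup calls."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# Made-up gate decisions for four tokens and made-up per-learner costs.
counts = [1, 2, 4, 1]
print(average_acm_flops(counts, flops_per_learner=100.0, gate_flops=10.0))
print(median_latency_ms(lambda: sum(range(1000))))
```

On a GPU, the timed callable has to wait for the device to finish before returning (in PyTorch, by calling `torch.cuda.synchronize()`); otherwise asynchronous kernel launches make the wall-clock numbers meaningless.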

Optimization Features

Token Efficiency

  • Per-token variable compute allocation

Infra Optimization

  • Optimized A100 GPU execution patterns (batching, grouping by learner count)

Model Optimization

  • Conditional computation (adaptive width per token)
  • Module-wise knowledge distillation

System Optimization

  • Triton GPU kernels with autotuning and kernel fusion

Training Optimization

  • Three-stage conversion: distill learners, pretrain gates, end-to-end finetune
  • Auxiliary losses to enforce budget and diversity
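
The exact form of the auxiliary losses is not given in this summary. As an illustration only, a budget term could penalize the gap between the expected fraction of learners used and the target β_target. The sketch below assumes the gate outputs, for each token, a probability distribution over running 1..N learners; `expected_learner_fraction` and `budget_loss` are hypothetical names, and the squared penalty is an assumption.

```python
from typing import Sequence


def expected_learner_fraction(gate_probs: Sequence[Sequence[float]]) -> float:
    """Mean expected fraction of learners used, given per-token gate probabilities.

    gate_probs[t][k - 1] is the probability that token t runs k learners.
    """
    num_learners = len(gate_probs[0])
    expected_counts = [
        sum(k * p for k, p in enumerate(probs, start=1)) for probs in gate_probs
    ]
    return sum(expected_counts) / (len(expected_counts) * num_learners)


def budget_loss(gate_probs: Sequence[Sequence[float]], beta_target: float) -> float:
    """Squared gap between the used fraction and the target budget (assumed form)."""
    return (expected_learner_fraction(gate_probs) - beta_target) ** 2


# Two tokens, N=4: one surely runs 1 learner, the other surely runs 4.
probs = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
print(expected_learner_fraction(probs))
print(budget_loss(probs, beta_target=0.5))
```

Working with expected counts rather than hard decisions keeps the term smooth in the gate probabilities, which is what a gradient-based training loop would need.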

Inference Optimization

  • Per-token gating to skip learners
  • Sorting tokens by chosen learner count to batch matmuls
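
The sorting idea can be sketched without any GPU code. If tokens are ordered by descending learner count, learner n only needs the first few tokens in that order, so each learner runs one dense matmul over a contiguous slice. The function name `plan_batched_execution` is hypothetical, and the sketch follows the 1..N counting convention assumed for illustration; it is not the authors' Triton kernel.

```python
from typing import List, Sequence, Tuple


def plan_batched_execution(
    counts: Sequence[int], num_learners: int
) -> Tuple[List[int], List[int]]:
    """Sort tokens by learner count so each learner sees one contiguous slice.

    Returns (order, prefix_lengths): order[i] is the original index of the
    token placed at sorted position i, and prefix_lengths[n] is how many of
    the sorted tokens learner n (0-based) must process.
    """
    # Stable sort, descending by count: heavy tokens come first.
    order = sorted(range(len(counts)), key=lambda i: counts[i], reverse=True)
    # Learner n runs on every token that requested more than n learners.
    prefix_lengths = [sum(1 for c in counts if c > n) for n in range(num_learners)]
    return order, prefix_lengths


counts = [1, 4, 2, 4, 1, 2]  # made-up gate decisions for six tokens
order, prefix_lengths = plan_batched_execution(counts, num_learners=4)
for n, length in enumerate(prefix_lengths):
    print(f"learner {n} processes tokens {order[:length]}")
```

After all learners have run, the outputs would be scattered back to the original token positions using `order`.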

Reproducibility

Data Urls

  • ImageNet-1k (public)
  • CommonVoice (public)
  • ImageNet-C (public)

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Requires three-phase conversion and finetuning; finetuning from random initialization converges more slowly.
  • Performance and routing are sensitive to distribution shift (ImageNet-C showed gating changes and accuracy drops).
  • Wall-clock gains depend on a specialized Triton implementation and good batch sizes.
  • Hyperparameters (N, β_target, auxiliary weights) need tuning per model and task.

When Not To Use

  • When you cannot retrain or finetune the model at all.
  • Small models where gating overhead outweighs savings.
  • Deployment targets without efficient batching or lacking Triton-like kernels.

Failure Modes

  • Gate collapse: gates choose same number of learners for all tokens, nullifying adaptivity.
  • Severe domain shift causes gates to select extreme budgets and degrade accuracy.
  • Poor learner distillation or skipped pretraining can lead to lower final accuracy.
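
Gate collapse is easy to monitor during finetuning. One simple diagnostic, offered here as a suggestion rather than something the paper prescribes, is the normalized entropy of the histogram of chosen learner counts: it is 0 when every token gets the same count and 1 when all counts are used equally often. The function name `gate_count_entropy` is hypothetical.

```python
import math
from collections import Counter
from typing import Sequence


def gate_count_entropy(counts: Sequence[int], num_learners: int) -> float:
    """Normalized entropy in [0, 1] of the learner-count histogram.

    Assumes num_learners >= 2 and that counts is non-empty.
    """
    total = len(counts)
    histogram = Counter(counts)
    entropy = sum(
        (c / total) * math.log(total / c) for c in histogram.values()
    )
    return entropy / math.log(num_learners)


collapsed = [4, 4, 4, 4, 4, 4, 4, 4]  # every token runs all learners
spread = [1, 2, 3, 4, 1, 2, 3, 4]     # all counts equally frequent
print(gate_count_entropy(collapsed, num_learners=4))
print(gate_count_entropy(spread, num_learners=4))
```

A value that drifts toward 0 over training would suggest the gates have stopped adapting per token, at which point the module behaves like a fixed-width layer.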

Core Entities

Models

  • Vision Transformer (ViT-B)
  • Wav2Vec2 / XLS-R-300M (Wav2Vec family)
  • ACMized ViT-B
  • ACMized Wav2Vec

Metrics

  • Accuracy
  • Word Error Rate (WER)
  • Average FLOPs per sample
  • Wall-clock latency (A100 GPU)

Datasets

  • ImageNet-1k
  • CommonVoice (selected languages)
  • ImageNet-C (corruptions)

Benchmarks

  • Accuracy
  • CommonVoice WER
  • ImageNet-C robustness

Context Entities

Models

  • A-ViT (token dropping)
  • MoEfication (MoE conversion)
  • Zero Time Waste (Early Exiting)