Monarch Mixer: replace attention and MLPs with sub-quadratic GEMM-friendly layers to speed long-context models

October 18, 2023 · 9 min read

Overview

Production Readiness: 0.5

Novelty Score: 0.7

Cost Impact Score: 0.7

Citation Count: 14

Authors

Daniel Y. Fu, Simran Arora, Jessica Grogan, Isys Johnson, Sabri Eyuboglu, Armin W. Thomas, Benjamin Spector, Michael Poli, Atri Rudra, Christopher Ré


Why It Matters For Business

If you run long-context models or want a lower parameter count, M2 can cut compute or model size and raise throughput while keeping accuracy. The gains are largest at long sequence lengths and on GPUs with large, fast caches. Expect implementation and kernel work before it reaches production parity on all hardware.

Summary TLDR

Monarch Mixer (M2) is a new neural layer that replaces both attention and dense MLPs with structured "Monarch" matrices. Monarch matrices are products of block-diagonal factors and permutations. They can therefore be computed with GEMMs (fast matrix multiplies) at sub-quadratic cost in both sequence length and model dimension. In experiments M2 matches or exceeds Transformer baselines. It matches BERT GLUE quality with ~24–27% fewer parameters and beats ViT-b on ImageNet with ~50% fewer parameters. It also matches GPT-style pretraining perplexity at 360M parameters. The largest wins come at long contexts: up to 9.1× throughput at 4K tokens on A100, and larger relative gains on GPUs with larger, faster caches. Causality for autoregressive models comes from a polynomial (evaluation/interpolation) view of Monarch matrices. That view yields a causal parameterization that keeps the cost sub-quadratic.
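The sketch below makes the phrase "products of block-diagonal factors and permutations" concrete with a minimal, pure-Python order-2 Monarch matrix-vector product. It assumes the square form M = P L P^T R from the original Monarch work, with N = b^2 and both L and R built from b dense b×b blocks. For a square grid the transpose permutation is its own inverse, so P^T = P. All helper names are hypothetical. Real M2 layers batch these block products as GEMMs on the GPU instead of looping in Python. The check against the materialized dense matrix shows that the factored form computes the same linear map with 2N^{3/2} multiply-adds instead of N^2.

```python
import random


def block_diag_matvec(blocks, x):
    """Apply a block-diagonal matrix (b blocks, each b x b) to a length-b*b vector."""
    b = len(blocks)
    out = []
    for i, blk in enumerate(blocks):
        seg = x[i * b:(i + 1) * b]  # contiguous slice mixed by block i
        out.extend(sum(blk[r][c] * seg[c] for c in range(b)) for r in range(b))
    return out


def stride_perm(x, b):
    """View x as a b x b row-major grid and transpose it (applying twice is identity)."""
    return [x[c * b + r] for r in range(b) for c in range(b)]


def monarch_matvec(L, R, x):
    """Assumed factorization y = P L P R x, computed without forming the N x N matrix."""
    b = len(R)
    y = block_diag_matvec(R, x)   # mix within each contiguous block
    y = stride_perm(y, b)         # regroup: each new block takes one entry per old block
    y = block_diag_matvec(L, y)   # mix across the original blocks
    return stride_perm(y, b)


def dense_reference(L, R):
    """Materialize P @ L @ P @ R as an N x N list of lists (for checking only)."""
    b = len(R)
    n = b * b

    def block_diag(blocks):
        M = [[0.0] * n for _ in range(n)]
        for i, blk in enumerate(blocks):
            for r in range(b):
                for c in range(b):
                    M[i * b + r][i * b + c] = blk[r][c]
        return M

    def matmul(A, B):
        return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)]
                for i in range(n)]

    # Permutation matrix with (P x)[r*b + c] = x[c*b + r], matching stride_perm.
    P = [[0.0] * n for _ in range(n)]
    for r in range(b):
        for c in range(b):
            P[r * b + c][c * b + r] = 1.0
    return matmul(P, matmul(block_diag(L), matmul(P, block_diag(R))))


def random_blocks(b, rng):
    return [[[rng.gauss(0, 1) for _ in range(b)] for _ in range(b)] for _ in range(b)]


rng = random.Random(0)
b = 4
N = b * b
L, R = random_blocks(b, rng), random_blocks(b, rng)
x = [rng.gauss(0, 1) for _ in range(N)]

fast = monarch_matvec(L, R, x)
slow = [sum(a * v for a, v in zip(row, x)) for row in dense_reference(L, R)]
err = max(abs(f - s) for f, s in zip(fast, slow))
print(f"max abs difference: {err:.2e}")
print("multiply-adds:", 2 * b ** 3, "(Monarch) vs", N * N, "(dense)")  # 128 vs 256
```

Each block-diagonal factor costs b^3 = N^{3/2} multiply-adds. The permutations move data but do no arithmetic, which is why later sections treat them as a data-movement cost.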

Problem Statement

Attention scales quadratically in sequence length, and dense MLPs scale quadratically in model dimension. This makes long-context and wide models expensive. The paper asks whether a single, simple primitive can mix along both the sequence and model axes at sub-quadratic cost while preserving quality and hardware efficiency.
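Counting multiply-adds for mixing one channel along a sequence of length N shows the size of this gap. A dense (attention-like) mix costs N^2. An order-p Monarch matrix with p block-diagonal factors, each with blocks of size N^{1/p}, costs p·N^{(p+1)/p}. The minimal sketch below tabulates the ratio for order 2 and for the largest order, where blocks have size 2 (the O(N log N) end). The function names are hypothetical. The count ignores constants, permutations, and memory traffic, all of which matter in practice.

```python
import math


def dense_mixing_cost(n):
    """Multiply-adds for a dense n x n mix of one channel (attention-like)."""
    return n * n


def monarch_mixing_cost(n, p):
    """Order-p Monarch: p block-diagonal factors, blocks of size n**(1/p), n**(1/p) MACs per output."""
    return p * n ** ((p + 1) / p)


for n in (1024, 4096, 16384, 65536):
    p_max = int(math.log2(n))  # blocks of size 2: the O(N log N) end
    r2 = dense_mixing_cost(n) / monarch_mixing_cost(n, 2)
    r_log = dense_mixing_cost(n) / monarch_mixing_cost(n, p_max)
    print(f"N={n:>6}: dense/order-2 = {r2:7.1f}x, dense/order-log = {r_log:8.1f}x")
```

For order 2 the ratio grows as sqrt(N)/2. This is one reason the reported speedups increase with sequence length.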

Main Contribution

Build on Monarch matrices, generalized to order p. These are a simple class of structured matrices computable with block GEMMs and permutations, and their cost interpolates between O(N log N) and O(N^{3/2}) in length N.

Build Monarch Mixer (M2): an architecture that uses Monarch factors for both sequence and feature mixing, replacing attention and dense MLPs.

Develop a polynomial view (evaluation/interpolation) that yields a causal parameterization so M2 can be used for autoregressive (GPT-style) models with sub-quadratic cost.

Demonstrate M2 across tasks. M2-BERT matches BERT on GLUE with fewer parameters and large long-context throughput gains. M2-ViT matches or outperforms ViT with fewer parameters. M2-GPT matches or beats GPT/Hyena perplexity at the tested scales.
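The simplest, univariate form of the polynomial view can be illustrated directly. Treat the input x and a filter h as polynomial coefficients, evaluate both at roots of unity, multiply pointwise, and interpolate back. Without padding the product wraps around, so early outputs depend on later inputs. Zero-padding to length 2N keeps the first N output coefficients causal. For clarity this sketch uses a plain O(N^2) DFT, not the Monarch-factored evaluation the paper uses. The function names are hypothetical.

```python
import cmath


def dft(a, inverse=False):
    """Evaluate (or, with inverse=True, interpolate) at the n-th roots of unity, O(n^2)."""
    n = len(a)
    sign = 1 if inverse else -1
    out = [sum(a[k] * cmath.exp(sign * 2j * cmath.pi * j * k / n) for k in range(n))
           for j in range(n)]
    return [v / n for v in out] if inverse else out


def conv_via_eval_interp(x, h, pad):
    """Multiply polynomials x(z) h(z) through pointwise products of their evaluations."""
    n = len(x)
    m = 2 * n if pad else n  # padding to 2n stops high-order terms wrapping to the start
    X = dft(list(x) + [0.0] * (m - n))
    H = dft(list(h) + [0.0] * (m - n))
    y = dft([a * b for a, b in zip(X, H)], inverse=True)
    return [v.real for v in y[:n]]  # keep the first n coefficients


def direct_causal(x, h):
    """Reference: y[t] depends only on x[0..t]."""
    return [sum(h[k] * x[t - k] for k in range(t + 1)) for t in range(len(x))]


x = [1.0, 2.0, 3.0, 4.0]
h = [1.0, 0.5, 0.25, 0.125]
ref = direct_causal(x, h)
padded = conv_via_eval_interp(x, h, pad=True)
circular = conv_via_eval_interp(x, h, pad=False)
print(ref)                              # [1.0, 2.5, 4.25, 6.125]
print([round(v, 6) for v in padded])    # same values: causal
print([round(v, 6) for v in circular])  # y[0] is 4.0, not 1.0: future inputs leak in
```

The padding doubles the transform length. This is the kind of overhead the Limitations section refers to when it mentions padding.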

Key Findings

M2-BERT matches BERT-base GLUE while cutting parameters.

Numbers: GLUE 79.9 vs 79.6; 27% fewer params (M2-BERT 80M vs BERT-base 110M)

Throughput improves dramatically at long sequences.

Numbers: up to 9.1× tokens/ms at sequence length 4096 on A100 (M2 vs HF BERT)

M2-ViT achieves higher ImageNet top-1 with fewer parameters.

Numbers: top-1 79.5% (M2-ViT-b, 45M) vs 78.5% (ViT-b, 87M); ~2× fewer params

Causal M2 can match or beat attention models in pretraining perplexity.

Numbers: perplexity after 15B tokens, M2-GPT (360M) 9.0 vs Transformer (355M) 9.1

Hardware efficiency depends on GPU cache and permutations.

Numbers: FLOP utilization 25.6% on A100 and 41.4% on RTX 4090 (both at N=64k); wall-clock speedups grow with N
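The headline ratios follow directly from the reported figures. This short check recomputes them and uses only numbers quoted in this summary: 80M vs 110M and 45M vs 87M parameters, and 353.9 vs 39.0 tokens/ms at 4096 tokens.

```python
def pct_fewer(new, old):
    """Percentage reduction going from old to new."""
    return 100.0 * (1.0 - new / old)


bert_param_cut = pct_fewer(80, 110)   # M2-BERT 80M vs BERT-base 110M
vit_param_ratio = 87 / 45             # ViT-b 87M vs M2-ViT-b 45M
throughput_gain = 353.9 / 39.0        # tokens/ms at 4096 tokens, A100

print(f"M2-BERT: {bert_param_cut:.1f}% fewer parameters")   # 27.3%
print(f"M2-ViT-b: {vit_param_ratio:.2f}x fewer parameters")  # 1.93x
print(f"M2-BERT throughput: {throughput_gain:.2f}x")         # 9.07x
```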

Results

GLUE average

Value: 79.9 (M2-BERT 80M); 80.9 (M2-BERT 110M)

Baseline: 79.6 (BERT-base 110M)

Tokens/ms (throughput)

Value: 353.9 tokens/ms at 4096 tokens (M2-BERT-base 80M)

Baseline: 39.0 tokens/ms (HF BERT-base 110M)

ImageNet-1k top-1 accuracy

Value: 79.5% (M2-ViT-b, 45M)

Baseline: 78.5% (ViT-b, 87M)

Perplexity (pretraining on PILE)

Value: 9.0 (M2-GPT 360M at 15B tokens)

Baseline: 9.1 (Transformer 355M at 15B tokens)

FLOP utilization

Value: 25.6% (A100, N=64k); 41.4% (RTX 4090, N=64k)

Baseline: dense MLP ~80–98% utilization; FlashAttention ~24–37%
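A back-of-the-envelope model shows why FLOP savings do not automatically become wall-clock savings. Time is roughly FLOPs / (utilization × peak FLOP/s), so M2's speedup over a baseline equals the FLOP ratio multiplied by the utilization ratio. In this minimal sketch the FLOP counts are hypothetical. The 312 TFLOP/s A100 figure is an assumed dense FP16/BF16 tensor-core peak. Only the utilization percentages come from the results above.

```python
def est_time_ms(flops, utilization, peak_flops_per_s):
    """Wall-clock estimate: useful FLOPs over the fraction of peak actually achieved."""
    return 1e3 * flops / (utilization * peak_flops_per_s)


PEAK_A100 = 312e12  # assumed A100 dense FP16/BF16 tensor-core peak, FLOP/s

# Hypothetical FLOP counts for one sequence-mixing layer, for illustration only.
t_attn = est_time_ms(8.0e12, 0.30, PEAK_A100)   # within the 24-37% FlashAttention range
t_m2 = est_time_ms(1.0e12, 0.256, PEAK_A100)    # 25.6% reported for M2 on A100
print(f"attention {t_attn:.2f} ms, M2 {t_m2:.2f} ms, speedup {t_attn / t_m2:.2f}x")

# A Monarch MLP that halves FLOPs but runs at low utilization can still lose
# to a dense MLP running near the 80-98% utilization range.
t_dense_mlp = est_time_ms(2.0e12, 0.90, PEAK_A100)
t_monarch_mlp = est_time_ms(1.0e12, 0.256, PEAK_A100)
print(f"dense MLP {t_dense_mlp:.2f} ms vs Monarch MLP {t_monarch_mlp:.2f} ms")
```

The second comparison illustrates the failure mode listed later: without better utilization from tuned kernels, halving the FLOPs is not enough to win.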

Who Should Care

What To Try In 7 Days

Replace attention + MLP in a small BERT with M2 and run GLUE fine-tuning to validate quality.

Benchmark M2 against your current model on representative long-context inputs (2K–8K tokens) on your target GPU.

Profile FLOP utilization, memory movement, and hotspots to assess whether permutations or block sizes need tuning.
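For the benchmarking step, a minimal tokens/ms harness might look like the following. The function dummy_model is a hypothetical CPU stand-in; replace it with your own forward pass. If that pass runs asynchronously on a GPU, synchronize inside the timed call so the timer measures finished work (for example with torch.cuda.synchronize() in PyTorch). The warmup and iteration counts are arbitrary defaults.

```python
import time


def tokens_per_ms(fn, seq_len, batch=1, warmup=2, iters=5):
    """Time fn(batch, seq_len) and return processed tokens per millisecond."""
    for _ in range(warmup):  # absorb lazy initialization and cache warmup
        fn(batch, seq_len)
    start = time.perf_counter()
    for _ in range(iters):
        fn(batch, seq_len)
    elapsed_ms = (time.perf_counter() - start) * 1e3
    return batch * seq_len * iters / elapsed_ms


def dummy_model(batch, seq_len):
    """Hypothetical stand-in workload; swap in a real forward pass (plus GPU sync)."""
    return sum(i * i for i in range(batch * seq_len))


for n in (2048, 4096, 8192):
    print(n, round(tokens_per_ms(dummy_model, n), 1))
```

Running the same harness on both models at each sequence length gives the throughput ratio directly. That ratio is the quantity behind the 9.1× figure.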

Optimization Features

Token Efficiency

  • Much higher throughput on long inputs, making larger effective context windows practical

Infra Optimization

  • Stronger gains on GPUs with large/fast L2 (RTX 4090) than on A100 in baseline PyTorch
  • Manual CUDA kernels can improve utilization

Model Optimization

  • Monarch matrices (block-diagonal factors interleaved with permutations) replace dense linear layers
  • Single primitive mixes both sequence and model axes
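How much a Monarch layer saves over a dense d×d linear layer depends on how d is factored. Under the order-2 Monarch definition with d = m·n, R holds m blocks of size n×n and L holds n blocks of size m×m. That gives d·(m + n) weights versus d^2 for a dense layer, and the count is smallest near m = n = sqrt(d). The minimal sketch tabulates this for d = 768, the BERT-base hidden size. The function names are hypothetical. M2's actual layers may use different block counts, expansion factors, or biases, so the numbers are illustrative only.

```python
def dense_params(d):
    return d * d


def monarch_params(m, n):
    """Order-2 Monarch on d = m*n: R has m blocks of n x n, L has n blocks of m x m."""
    return m * n * n + n * m * m  # equals d * (m + n)


d = 768
for m in (2, 4, 8, 16, 32):
    n = d // m
    share = monarch_params(m, n) / dense_params(d)
    print(f"m={m:>2}, n={n:>3}: {monarch_params(m, n):>7} weights ({share:.1%} of dense)")
```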

System Optimization

  • Relies on GEMMs + permutations; benefits from kernel fusion and IO-aware implementations
  • Performance sensitive to permutation/data movement costs

Training Optimization

  • Sub-quadratic FLOPs reduce pretraining wall-clock for long sequences
  • Demonstrated faster time-to-quality in some long‑sequence regimes

Inference Optimization

  • Higher tokens/ms at long contexts (≥2K); up to 9.1× on A100 at 4K
  • GEMM-based blocks enable high utilization on GPUs with large fast caches

Reproducibility

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Permutation and data-movement overheads reduce gains on some GPUs (A100) unless kernels are tuned.
  • Short-sequence inference can be slower than optimized attention kernels (FlashAttention) for small context lengths.
  • Causal parameterization requires careful polynomial constructions and may need padding or enlarged embeddings for some multivariate orders (p > 2).
  • Reported implementations are proof-of-concept; further kernel and inference engineering are needed for production latency targets.

When Not To Use

  • Low-latency, short-sequence (<1K token) inference where FlashAttention is already highly optimized.
  • Environments where custom CUDA kernels or permutation-friendly memory layouts are impossible.
  • Applications that require immediate, battle-tested inference stacks with minimal engineering time.

Failure Modes

  • Permutations cause memory-bound bottlenecks, negating GEMM gains on some hardware.
  • FLOP savings do not translate to wall-clock speed without careful kernel/cache tuning.
  • Causal parameter choices or padding can increase memory or parameter overhead if misapplied.

Core Entities

Models

  • Monarch Mixer (M2)
  • M2-BERT
  • M2-ViT
  • M2-GPT
  • Monarch matrices (order p)
  • Hyena
  • ViT-b
  • BERT-base
  • BERT-large
  • FlashAttention

Metrics

  • GLUE average score
  • Accuracy
  • Perplexity on PILE
  • Tokens/ms (throughput)
  • FLOP utilization
  • Wall-clock speedup
  • Parameter count

Datasets

  • C4
  • GLUE
  • ImageNet-1k
  • The PILE
  • CIFAR-10
  • Speech-Commands-10

Benchmarks

  • GLUE
  • ImageNet-1k
  • PILE (pretraining perplexity)
  • tokens/ms throughput