Monarch Mixer: replace attention and MLPs with sub-quadratic GEMM-friendly layers to speed long-context models

October 18, 2023 (9 min read)

Overview

Decision Snapshot: Needs Validation

The idea is practically useful but needs per-GPU kernel-level and IO optimization. Experiments are promising across tasks, but implementation effort and hardware sensitivity limit immediate plug-and-play production use.

Citations: 14

Evidence Strength: 0.70

Confidence: 0.85

Risk Signals: 10

Trust Signals

Findings with numeric evidence: 5/5

Findings with evidence refs: 5/5

Results with explicit delta: 5/5

Reproducibility

Status: Code + data available

Open source: Partial

At A Glance

Cost impact: 70%

Production readiness: 50%

Novelty: 70%

Authors

Daniel Y. Fu, Simran Arora, Jessica Grogan, Isys Johnson, Sabri Eyuboglu, Armin W. Thomas, Benjamin Spector, Michael Poli, Atri Rudra, Christopher Ré


Why It Matters For Business

If you run long-context models or need a smaller parameter budget, M2 can cut compute or model size and raise throughput on many GPUs while keeping accuracy. Expect kernel and implementation work before it reaches production parity on all hardware.

Who Should Care

Summary TLDR

Monarch Mixer (M2) is a new architecture whose layers replace attention and dense MLPs with structured Monarch matrices. Monarch matrices are products of block-diagonal factors and permutations, so they are computed with GEMMs (fast matrix multiplies) and have sub-quadratic cost in both sequence length and model dimension. In experiments, M2 matches or exceeds Transformer baselines: it matches BERT GLUE quality with ~24–27% fewer parameters, beats ViT-b on ImageNet with ~50% fewer parameters, and matches GPT-style pretraining perplexity at 360M parameters. The biggest wins come at long contexts: up to 9.1× higher throughput than HF BERT at 4K tokens on A100, with larger relative gains on GPUs with larger, faster caches. Causality for autoregressive (GPT-style) models is handled through a multivariate-polynomial view of Monarch matrices, which keeps the causal variant sub-quadratic.

Problem Statement

Attention scales quadratically in sequence length and dense MLPs scale quadratically in model dimension, which makes long-context and wide models expensive. The paper asks whether a single, simple primitive can mix along both the sequence and model axes at sub-quadratic cost while preserving quality and hardware efficiency.
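The cost gap can be written out explicitly. The block below is a sketch of the standard asymptotic costs for sequence length N and model width d, assuming an order-p Monarch cost of O(p N^{(p+1)/p}) with constants omitted; the last line shows why a large order reaches O(N log N).

```latex
\begin{aligned}
\text{attention (sequence mixing)} &: O(N^2 d) \\
\text{dense MLP (feature mixing)} &: O(N d^2) \\
\text{order-}p\text{ Monarch on length } N &: O\!\left(p\, N^{(p+1)/p}\right) \\
p = 2 &\;\Rightarrow\; O\!\left(N^{3/2}\right) \\
p = \log_2 N &\;\Rightarrow\; O\!\left(\log_2 N \cdot N \cdot N^{1/\log_2 N}\right) = O(2N\log_2 N) = O(N \log N)
\end{aligned}
```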

Main Contribution

Introduce Monarch matrices: a simple class of structured matrices computable with block GEMMs and permutations, whose cost interpolates between O(N log N) and O(N^{3/2}) in input length N depending on the order p.

Build Monarch Mixer (M2): an architecture that uses Monarch factors for both sequence and feature mixing, replacing attention and dense MLPs.
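The Monarch structure can be illustrated with a minimal NumPy sketch of an order-2 Monarch matrix-vector product. It assumes the factorization M = P L Pᵀ R with b blocks of size b×b (N = b²), where P views a vector as a b×b grid and transposes it, so Pᵀ = P. The helper names (transpose_perm, block_diag_apply, monarch_matvec, dense_monarch) are hypothetical and not the authors' code; dense_monarch exists only to check the result against an explicit N×N matrix.

```python
import numpy as np


def transpose_perm(v: np.ndarray, b: int) -> np.ndarray:
    """Permutation P: view a length b*b vector as a b x b grid and transpose it.

    On a square grid this is its own inverse, so P = P^T.
    """
    return v.reshape(b, b).T.reshape(-1)


def block_diag_apply(blocks: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Multiply v by a block-diagonal matrix given as b blocks of shape (b, b)."""
    b = blocks.shape[0]
    chunks = v.reshape(b, b)  # row k holds v[k*b:(k+1)*b]
    return np.einsum("kij,kj->ki", blocks, chunks).reshape(-1)


def monarch_matvec(L: np.ndarray, R: np.ndarray, x: np.ndarray) -> np.ndarray:
    """y = P L P R x, computed with two batched small GEMMs and two permutations."""
    b = R.shape[0]
    y = block_diag_apply(R, x)   # block-diagonal factor R
    y = transpose_perm(y, b)     # permutation P
    y = block_diag_apply(L, y)   # block-diagonal factor L
    return transpose_perm(y, b)  # permutation P


def dense_monarch(L: np.ndarray, R: np.ndarray) -> np.ndarray:
    """Materialize the same N x N matrix explicitly (O(N^2) memory, for checking only)."""
    b = R.shape[0]
    n = b * b
    perm = np.arange(n).reshape(b, b).T.reshape(-1)
    P = np.eye(n)[perm]  # (P @ v)[i] == v[perm[i]], matching transpose_perm
    BL = np.zeros((n, n))
    BR = np.zeros((n, n))
    for k in range(b):
        s = slice(k * b, (k + 1) * b)
        BL[s, s] = L[k]
        BR[s, s] = R[k]
    return P @ BL @ P @ BR


rng = np.random.default_rng(0)
b = 32  # N = 1024
L = rng.standard_normal((b, b, b))
R = rng.standard_normal((b, b, b))
x = rng.standard_normal(b * b)
print(np.allclose(monarch_matvec(L, R, x), dense_monarch(L, R) @ x))  # True
print(2 * b**3, (b * b) ** 2)  # 65536 1048576 multiply-adds (Monarch vs dense)
```

The two block-diagonal factors cost 2b³ = 2N^{3/2} multiply-adds versus N² for a dense matrix, and every step is a batch of small dense GEMMs, which is what makes the structure hardware-friendly. Higher-order Monarch matrices add more factors to push the cost toward O(N log N).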

Key Findings

M2-BERT matches BERT-base GLUE while cutting parameters.

Numbers: GLUE 79.9 vs 79.6; −27% params (M2 80M vs BERT 110M)

Practical Use: You can replace attention + MLP with M2 in BERT backbones to get similar quality with ~25–27% fewer parameters; try a drop-in M2-BERT for parameter-constrained settings.

Evidence Ref: Table 3, Table 9

Throughput improves dramatically at long sequences.

Numbers: Up to 9.1× tokens/ms at sequence length 4096 on A100 (M2 vs HF BERT)

Practical Use: If your workloads use long contexts (≥2K tokens), M2 can substantially cut inference time; benchmark M2 against FlashAttention on your GPU and at your sequence lengths.

Evidence Ref: Table 5, Table 10

Results

Metric | Value | Baseline | Delta | Split / Dataset | Evidence | Evidence Ref
GLUE average | 79.9 (M2-BERT 80M); 80.9 (M2-BERT 110M) | 79.6 (BERT-base 110M) | +0.3 (vs BERT-base with −27% params) or +1.3 when param-matched | GLUE downstream | Table 3, Table 9 | Table 3
Tokens/ms (throughput) | 353.9 tokens/ms at 4096 (M2-BERT-base 80M) | 39.0 tokens/ms (HF BERT-base 110M) | 9.1× faster vs HF BERT-base at 4096 | A100 throughput by sequence length | Table 5, Table 10 | Table 5
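As a quick consistency check, the deltas in the results table follow from its own values. The snippet below is plain arithmetic on the reported numbers, with no new measurements; the variable names are illustrative.

```python
# Values copied from the results table (GLUE on BERT-base; A100 throughput at 4096 tokens).
bert_base_glue, bert_base_params_m = 79.6, 110
m2_80m_glue, m2_80m_params_m = 79.9, 80
m2_110m_glue = 80.9
m2_tokens_per_ms, hf_bert_tokens_per_ms = 353.9, 39.0

param_cut = 1 - m2_80m_params_m / bert_base_params_m
glue_delta_80m = m2_80m_glue - bert_base_glue
glue_delta_110m = m2_110m_glue - bert_base_glue
speedup = m2_tokens_per_ms / hf_bert_tokens_per_ms

print(f"parameter reduction: {param_cut:.1%}")               # 27.3%
print(f"GLUE delta, 80M vs 110M: {glue_delta_80m:+.1f}")      # +0.3
print(f"GLUE delta, param-matched: {glue_delta_110m:+.1f}")   # +1.3
print(f"throughput ratio at 4096 tokens: {speedup:.2f}x")     # 9.07x
```

The 9.07× ratio rounds to the reported 9.1×, and the 27% parameter cut is relative to BERT-base's 110M parameters.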

What To Try In 7 Days

Replace attention + MLP in a small BERT with M2 and run GLUE fine-tuning to validate quality.

Benchmark M2 vs your current model on representative long-context inputs (2K–8K tokens) on target GPU.

Profile FLOP utilization, memory movement, and hotspots to assess whether permutations or block sizes need tuning.
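The benchmarking step can start from a small tokens/ms harness like the one below. This is a hypothetical CPU-only NumPy sketch: toy_attention and make_toy_monarch_mixer are stand-ins, not M2 or FlashAttention, and CPU timings will not reflect GPU behavior. On a GPU you would swap in your real modules and synchronize the device before reading the timer.

```python
import time

import numpy as np


def tokens_per_ms(layer, x, warmup=2, iters=5):
    """Average tokens processed per millisecond by layer(x), with x of shape (n, d)."""
    for _ in range(warmup):
        layer(x)  # warm caches and allocators before timing
    start = time.perf_counter()
    for _ in range(iters):
        layer(x)
    elapsed_ms = (time.perf_counter() - start) * 1e3 / iters
    return x.shape[0] / max(elapsed_ms, 1e-9)


def toy_attention(x):
    """Single-head softmax attention with Q = K = V = x: O(n^2 d) work."""
    scores = x @ x.T / np.sqrt(x.shape[1])
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ x


def make_toy_monarch_mixer(n, rng):
    """Order-2 Monarch mixing along the sequence axis: O(n^1.5 d) work; n must be a square."""
    b = int(round(np.sqrt(n)))
    assert b * b == n, "sequence length must be a perfect square in this sketch"
    L = (rng.standard_normal((b, b, b)) / b).astype(np.float32)
    R = (rng.standard_normal((b, b, b)) / b).astype(np.float32)

    def mix(x):
        d = x.shape[1]
        y = R @ x.reshape(b, b, d)          # block-diagonal R as batched GEMMs
        y = L @ y.transpose(1, 0, 2)        # permute, then block-diagonal L
        return y.transpose(1, 0, 2).reshape(n, d)  # permute back to (n, d)

    return mix


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    for n in (1024, 4096):
        x = rng.standard_normal((n, 64)).astype(np.float32)
        mixer = make_toy_monarch_mixer(n, rng)
        print(n, f"attention {tokens_per_ms(toy_attention, x):.1f}",
              f"monarch {tokens_per_ms(mixer, x):.1f} tokens/ms")
```

Sweeping n makes the scaling difference visible: the attention stand-in's cost grows with n² while the Monarch mixer grows with n^1.5.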

Optimization Features

Token Efficiency

Much higher throughput at long contexts, which makes larger effective context windows practical.

Infra Optimization

Stronger gains on GPUs with a large, fast L2 cache (RTX 4090) than on A100 in baseline PyTorch; manual CUDA kernels can improve utilization.

Model Optimization

Structured block-diagonal Monarch matrices replace dense linear layers; a single primitive mixes both the sequence and model axes.

System Optimization

Relies on GEMMs plus permutations, so it benefits from kernel fusion and IO-aware implementations; performance is sensitive to permutation and data-movement costs.

Training Optimization

Sub-quadratic FLOPs reduce pretraining wall-clock time for long sequences; faster time-to-quality was demonstrated in some long-sequence regimes.

Inference Optimization

Higher tokens/ms at long contexts (≥2K tokens), up to 9.1× vs HF BERT on A100 at 4K; GEMM-based blocks enable high utilization on GPUs with large, fast caches.
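The "single primitive for both axes" idea can be sketched in NumPy: the same Monarch routine mixes tokens (standing in for attention) and features (standing in for a dense MLP weight). This is a simplified illustration under stated assumptions. The elementwise gate, the GELU, and the names monarch_apply and m2_style_block are hypothetical choices, and the paper's actual M2 layers contain additional structure that this sketch omits.

```python
import numpy as np


def monarch_apply(L, R, x):
    """Apply an order-2 Monarch matrix P L P R along axis 0 of x (shape (n, c), n = b*b)."""
    b = R.shape[0]
    c = x.shape[1]
    y = R @ x.reshape(b, b, c)              # block-diagonal R: one small GEMM per block
    y = L @ y.transpose(1, 0, 2)            # permutation P, then block-diagonal L
    return y.transpose(1, 0, 2).reshape(b * b, c)  # permutation P back to (n, c)


def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def m2_style_block(x, seq_factors, dim_factors, gate):
    """Toy block: Monarch mixing over tokens, elementwise gate, Monarch mixing over features.

    x: (n, d). seq_factors / dim_factors: (L, R) pairs for n = b_s**2 and d = b_d**2.
    """
    y = monarch_apply(*seq_factors, x) * gate   # sequence mixing (in place of attention)
    z = monarch_apply(*dim_factors, y.T).T      # feature mixing (in place of a dense weight)
    return gelu(z)


rng = np.random.default_rng(0)
n, d = 64, 16  # b_s = 8, b_d = 4
seq = (rng.standard_normal((8, 8, 8)), rng.standard_normal((8, 8, 8)))
dim = (rng.standard_normal((4, 4, 4)), rng.standard_normal((4, 4, 4)))
x = rng.standard_normal((n, d))
gate = rng.standard_normal((n, d))
out = m2_style_block(x, seq, dim, gate)
print(out.shape)  # (64, 16)
```

Because both mixing steps reuse one batched-GEMM routine, the same kernel-tuning work (block sizes, permutation layout) applies to the sequence and feature axes alike.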

Reproducibility

Code Available: Yes
Data Available: Yes
Open Source Status: Partial
License: Unknown

Risks & Boundaries

Limitations

Permutation and data-movement overheads reduce gains on some GPUs (A100) unless kernels are tuned.

Short-sequence inference can be slower than optimized attention kernels (FlashAttention) for small context lengths.

When Not To Use

Low-latency, short-sequence (<1K token) inference where FlashAttention is already highly optimized.

Environments where custom CUDA kernels or permutation-friendly memory layouts are impossible.

Failure Modes

Permutations can cause memory-bound bottlenecks that negate GEMM gains on some hardware.

FLOP savings do not translate to wall-clock speed without careful kernel/cache tuning.
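A back-of-the-envelope roofline estimate shows why FLOP savings alone may not speed things up. The sketch assumes approximate A100 (40GB) peaks (312 TFLOP/s dense FP16/BF16 tensor-core throughput, about 1.555 TB/s HBM bandwidth), 2-byte activations, and no weight traffic; the function names are hypothetical. It compares running the four Monarch stages (R, P, L, P) as separate kernels, each reading and writing the full activation, against one fused kernel.

```python
import math

# Approximate A100 (40GB) peaks; substitute your GPU's numbers.
PEAK_FLOPS = 312e12         # dense FP16/BF16 tensor-core FLOP/s
HBM_BYTES_PER_S = 1.555e12  # HBM bandwidth in bytes/s


def roofline_time_s(flops, bytes_moved):
    """Lower bound on runtime: compute time or memory time, whichever is larger."""
    return max(flops / PEAK_FLOPS, bytes_moved / HBM_BYTES_PER_S)


def monarch_seq_mix_cost(n, d, kernels, bytes_per_elem=2):
    """FLOPs and HBM bytes for order-2 Monarch sequence mixing of an (n, d) activation.

    kernels: number of separate launches the four stages are split into; each
    launch reads and writes the whole activation once. Weights are ignored.
    """
    b = math.isqrt(n)
    assert b * b == n, "sketch assumes n is a perfect square"
    flops = 2 * 2 * n * b * d  # two block-diagonal factors; one multiply-add = 2 FLOPs
    bytes_moved = kernels * 2 * n * d * bytes_per_elem
    return flops, bytes_moved


n, d = 4096, 768
balance = PEAK_FLOPS / HBM_BYTES_PER_S  # FLOP/byte needed to be compute-bound
for kernels in (4, 1):
    flops, bytes_moved = monarch_seq_mix_cost(n, d, kernels)
    intensity = flops / bytes_moved
    bound = "memory" if intensity < balance else "compute"
    t_us = roofline_time_s(flops, bytes_moved) * 1e6
    print(f"{kernels} kernel(s): {intensity:.0f} FLOP/byte, {bound}-bound, >= {t_us:.1f} us")
```

Under these assumptions the unfused version reaches 16 FLOP/byte and the fused one 64 FLOP/byte, both below the roughly 200 FLOP/byte balance point, so memory traffic rather than arithmetic sets the runtime. Fusing stages, or folding permutations into strided loads, cuts that traffic, which is why IO-aware kernels matter for M2.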

Core Entities

Models

Monarch Mixer (M2), M2-BERT, M2-ViT, M2-GPT, Monarch matrices (order p), Hyena, ViT-b, BERT-base, BERT-large, FlashAttention

Metrics

GLUE average score, Accuracy, Perplexity on PILE, Tokens/ms (throughput), FLOP utilization, Wall-clock speedup, Parameter count

Datasets

C4, GLUE, ImageNet-1k, The PILE, CIFAR-10, Speech-Commands-10

Benchmarks

GLUE, ImageNet-1k, PILE (pretraining perplexity), tokens/ms throughput