Fast Multipole Attention: a physics-inspired multilevel attention that cuts attention cost to O(n log n) or O(n)

October 18, 2023 · 8 min

Overview

Production Readiness

0.7

Novelty Score

0.7

Cost Impact Score

0.8

Citation Count

0

Authors

Yanming Kang, Giang Tran, Hans De Sterck

Why It Matters For Business

FMA lowers GPU memory use and inference latency for long text and high-resolution images, so teams can train bigger models or use longer contexts without buying more hardware.

Summary TLDR

This paper introduces Fast Multipole Attention (FMA), a learned multilevel attention mechanism that computes exact attention within a small local neighborhood and represents distant context through learned group summaries. FMA reduces self-attention cost from quadratic to O(n log n) and, with query downsampling, to O(n). The authors provide 1D and 2D implementations and report better accuracy than other efficient attention methods on long-context language modeling (enwik8, WikiText-103), plus higher accuracy with lower memory on image tasks (ImageNet, ADE20K). The authors state that code will be released on GitHub.

Problem Statement

Full self-attention gives a global receptive field but costs O(n^2) time and memory, which blocks training or inference on long text and high-resolution images. The paper aims to keep global context while reducing cost to practical levels.
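
To make the quadratic cost concrete, the short Python sketch below computes the size of one dense n x n attention score matrix. The function name and the assumption of 2-byte (fp16) scores for a single head and a single sequence are illustrative choices, not figures from the paper.

```python
def attention_matrix_bytes(n: int, bytes_per_score: int = 2) -> int:
    """Bytes needed to store one dense n x n attention score matrix.

    Assumes one head, one sequence, and fp16 scores (2 bytes) by default.
    """
    return n * n * bytes_per_score


for n in (1024, 4096, 8192):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"n={n:5d}: {mib:.0f} MiB per head per sequence")
```

Doubling the sequence length quadruples this figure, and the total grows further with the number of heads, layers, and batch size.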

Main Contribution

Fast Multipole Attention (FMA): a learned multilevel attention that keeps exact attention in a fixed local window and aggregates distant tokens via learned summaries.

Theoretical complexity: O(n log n) time/memory and O(n) with query downsampling, for both 1D (text) and 2D (images).

Practical 1D and 2D implementations with TVM/CUDA kernels that avoid materializing far-field blocks on device memory.

Empirical results showing FMA outperforms other efficient attention methods on language benchmarks and improves ImageNet and ADE20K accuracy while lowering GPU memory.

Hyperparameter study of base cell size r and coarse rank p that clarifies performance vs memory trade-offs.
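
The core idea in the first contribution, exact attention nearby and summarized attention far away, can be sketched in a few lines of plain Python. This is a deliberately simplified illustration and not the authors' implementation: it uses a single coarse level, replaces the learned summaries with plain averages (one summary per cell, so p = 1), and weights each summary by its cell size inside the softmax. The function names and the choice of "own cell plus one neighbor on each side" as the near field are assumptions made for this example.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def mean_vec(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def softmax_mix(scores, counts, values):
    """Softmax-weighted average of values.

    counts[j] is how many tokens entry j stands for (1 for an exact key,
    the cell size for a summary), so a summary is weighted like its members.
    """
    top = max(scores)
    w = [c * math.exp(s - top) for s, c in zip(scores, counts)]
    z = sum(w)
    return [sum(wj * vj[d] for wj, vj in zip(w, values)) / z
            for d in range(len(values[0]))]


def full_attention(q, k, v):
    """Reference dense softmax attention (quadratic in sequence length)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    return [softmax_mix([scale * dot(qi, kj) for kj in k], [1] * len(k), v)
            for qi in q]


def two_level_attention(q, k, v, r):
    """Exact attention within neighboring cells of size r, averaged summaries elsewhere."""
    n = len(q)
    scale = 1.0 / math.sqrt(len(q[0]))
    cells = [list(range(s, min(s + r, n))) for s in range(0, n, r)]
    # Stand-in for learned summaries: average the keys and values of each cell.
    k_sum = [mean_vec([k[j] for j in cell]) for cell in cells]
    v_sum = [mean_vec([v[j] for j in cell]) for cell in cells]
    out = []
    for i in range(n):
        keys, vals, counts = [], [], []
        for c, cell in enumerate(cells):
            if abs(c - i // r) <= 1:      # near field: token-level keys
                keys += [k[j] for j in cell]
                vals += [v[j] for j in cell]
                counts += [1] * len(cell)
            else:                         # far field: one summary per cell
                keys.append(k_sum[c])
                vals.append(v_sum[c])
                counts.append(len(cell))
        scores = [scale * dot(q[i], kj) for kj in keys]
        out.append(softmax_mix(scores, counts, vals))
    return out


random.seed(0)
n, d, r = 16, 4, 4
q, k, v = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
           for _ in range(3))
approx, exact = two_level_attention(q, k, v, r), full_attention(q, k, v)
err = max(abs(a - b) for ra, rb in zip(approx, exact) for a, b in zip(ra, rb))
print(f"max abs deviation from full attention: {err:.3f}")
```

Each query here touches at most 3r exact keys plus one summary per remaining cell, rather than all n keys. A larger r keeps more of the sequence exact at higher cost, which is the kind of trade-off the hyperparameter study examines.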

Key Findings

FMA changes attention complexity from quadratic to log-linear or linear.

Numbers: Complexity reduced from O(n^2) to O(n log n); O(n) with query downsampling

On character-level enwik8 with context 4096, FMA achieves the best bpc among the efficient-attention models compared.

Numbers: FMA bpc=1.133 at n=4096; Reformer 1.180

On WikiText-103 masked language modeling at context 2048, FMA nearly matches full attention perplexity.

Numbers: FMA ppl=8.95 vs Full attention ppl=8.70 at n=2048

FMA2D improves ImageNet top-1 accuracy and lowers memory vs ViT and Swin baselines.

Numbers: FMA2D-B top-1=84.0% (224) vs ViT-B/16 81.8%; memory 14.6GB vs 17.2GB

FMA2D improves segmentation quality on ADE20K.

Numbers: FMA2D-L-SegFormer mIoU=52.4% vs SegFormer-B5 51.5% (+0.9)

Memory and latency scale much better than quadratic attention in practice.

Numbers: FMA-linear shows linear memory growth and stays competitive up to 8192 tokens (Figure 9)

Results

bits-per-character (bpc)

Value: 1.133 (FMA, n=4096)

Baseline: Reformer 1.180 (n=4096)

perplexity (ppl)

Value: 8.95 (FMA, n=2048)

Baseline: Full attention 8.70 (n=2048)

Accuracy

Value: 84.0 (FMA2D-B, 224)

Baseline: ViT-B/16 81.8 (224)

mIoU (%)

Value: 52.4 (FMA2D-L-SegFormer)

Baseline: SegFormer-B5 51.5

peak GPU memory (GB)

Value: 14.6 (FMA2D-B, 224)

Baseline: ViT-B/16 17.2 (224)

Who Should Care

What To Try In 7 Days

Replace standard self-attention with FMA in one Transformer layer to measure memory and accuracy changes.

Run a short language eval (e.g., WikiText-103 subset) with FMA and FMA-linear to compare perplexity and peak memory.

For vision, swap attention blocks in a ViT backbone with FMA2D on a small ImageNet subset to check top-1 and GPU footprint.
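
When running the language comparisons suggested above, most training loops report a mean cross-entropy loss in nats, while the summary quotes perplexity and bits-per-character. The helper names below are hypothetical; the conversions themselves are the standard definitions (perplexity is the exponential of the mean negative log-likelihood, and bpc is that loss divided by ln 2 when the loss is averaged per character).

```python
import math


def perplexity(mean_nll_nats: float) -> float:
    """Perplexity from a mean negative log-likelihood measured in nats."""
    return math.exp(mean_nll_nats)


def bits_per_character(mean_nll_nats: float) -> float:
    """Bits-per-character from a per-character mean NLL measured in nats."""
    return mean_nll_nats / math.log(2)


# Example: the loss values (in nats) that correspond to the figures quoted above.
print(f"ppl 8.95  <-> loss {math.log(8.95):.4f} nats")
print(f"bpc 1.133 <-> loss {1.133 * math.log(2):.4f} nats")
```

Comparing runs on the same metric and the same evaluation subset matters more than the absolute values, since a short subset evaluation will not reproduce the paper's numbers.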

Optimization Features

Token Efficiency

  • Multilevel summarization preserves local detail while compressing distant context

Infra Optimization

  • Lower peak GPU memory enables larger batch sizes or higher resolution inputs on the same hardware

Model Optimization

  • Learned low-rank group summaries (coarse rank p)
  • Separable learned aggregation kernels for 2D groups
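
As a rough illustration of why separable aggregation is attractive for 2D groups, the sketch below assumes that "separable" means the 2D aggregation weight factorizes into a product of one weight per row and one weight per column; the paper's exact parameterization may differ. The function names are made up for this example, and scalar features are used for brevity.

```python
def aggregate_dense(patch, kernel):
    """Summarize a 2D group with a full kernel (one weight per position)."""
    return sum(kernel[i][j] * patch[i][j]
               for i in range(len(patch)) for j in range(len(patch[0])))


def aggregate_separable(patch, row_w, col_w):
    """Summarize a 2D group with per-row and per-column weights.

    Equivalent to a dense kernel with kernel[i][j] = row_w[i] * col_w[j].
    """
    # Collapse each row along the horizontal axis, then combine the rows.
    collapsed = [sum(cw * x for cw, x in zip(col_w, row)) for row in patch]
    return sum(rw * s for rw, s in zip(row_w, collapsed))


patch = [[1.0, 2.0], [3.0, 4.0]]
row_w = [0.25, 0.75]
col_w = [0.5, 0.5]
kernel = [[rw * cw for cw in col_w] for rw in row_w]
print(aggregate_dense(patch, kernel), aggregate_separable(patch, row_w, col_w))
```

For an s x s group, the factorized form needs 2s weights per summary instead of s^2. Under the same assumption, a coarse rank of p would correspond to p such weight pairs per group.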

System Optimization

  • Custom TVM schedule and optimized CUDA kernels to realize theoretical speedups

Training Optimization

  • End-to-end learning of aggregation kernels; same training recipes as baselines

Inference Optimization

  • Avoids materializing off-diagonal blocks; on-the-fly generation of block contributions
  • Query downsampling option to reach strict O(n) inference
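
The difference between the log-linear variant and the query-downsampled variant can be illustrated with a simple counting model. Everything in it is an assumption made for illustration rather than a detail taken from the paper: a near field of about 3r exact keys per query, coarse levels whose cell size doubles each time, a fixed number of cells seen per level (`groups_per_level`), and, for the linear variant, half as many queries at each coarser level.

```python
def score_count(n, r, p, groups_per_level=2, downsample_queries=False):
    """Approximate number of query-key scores under the assumed model."""
    levels = max(0, (n // r).bit_length() - 1)  # about log2(n / r) coarse levels
    total = n * 3 * r                           # level 0: exact near field
    for level in range(1, levels + 1):
        # Without downsampling every query visits every level, giving n * log n.
        # With downsampling the per-level work shrinks geometrically, giving O(n).
        queries = n >> level if downsample_queries else n
        total += queries * groups_per_level * p
    return total


print("n, full, log-linear, linear")
for n in (1024, 4096, 16384):
    print(n, n * n, score_count(n, 64, 4),
          score_count(n, 64, 4, downsample_queries=True))
```

In this model the downsampled total is bounded by n * (3r + groups_per_level * p) regardless of depth, which is where the strict O(n) behavior comes from. Real memory and latency also depend on kernel implementation details that the model ignores.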

Reproducibility

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Two hyperparameters, base cell size r and coarse rank p, require tuning and control the memory/accuracy trade-off.
  • FMA-linear trades a modest accuracy loss for memory savings compared to the log-linear variant.
  • Current CUDA kernels may not be fully optimized; runtime could change with further engineering.

When Not To Use

  • For short sequences or small images where full quadratic attention fits easily and yields slightly better accuracy.
  • When very tight, deterministic latency bounds are required and the implementation has not been production-optimized.

Failure Modes

  • Overly aggressive downsampling (small p or large r) can hurt accuracy by over-compressing long-range detail.
  • FMA-linear may lose some accuracy compared to FMA (log-linear) as reported in language experiments.

Core Entities

Models

  • FMA
  • FMA-linear
  • FMA2D-B
  • FMA2D-L
  • FMA2D-B-SegFormer
  • FMA2D-L-SegFormer
  • H-Transformer-1D
  • Reformer
  • Linear Transformer
  • Swin Transformer
  • ViT
  • SegFormer

Metrics

  • bits-per-character (bpc)
  • perplexity (ppl)
  • Accuracy
  • mIoU (%)
  • peak GPU memory (GB)

Datasets

  • enwik8
  • WikiText-103
  • ImageNet-1K
  • ImageNet-22K
  • ADE20K

Benchmarks

  • character-level language modeling
  • masked language modeling
  • ImageNet classification
  • ADE20K semantic segmentation