A linear-attention LLM that matches or beats Transformers while running faster and using less memory

July 27, 2023 · 9 min

Overview

Production Readiness

0.6

Novelty Score

0.7

Cost Impact Score

0.8

Citation Count

5

Authors

Zhen Qin, Dong Li, Weigao Sun, Weixuan Sun, Xuyang Shen, Xiaodong Han, Yunshen Wei, Baohong Lv, Xiao Luo, Yu Qiao, Yiran Zhong

Links

Abstract / PDF

Why It Matters For Business

TransNormerLLM can lower compute and memory needs for long-context LLM training and serving while keeping or improving accuracy, letting teams run larger contexts or reduce hardware costs without sacrificing model quality.

Summary TLDR

TransNormerLLM is an improved linear-attention LLM that combines LRPE-d positional encoding, Lightning Attention (an IO-aware blocked algorithm), gating (SGLU/GLA), and a simple RMS normalization to deliver similar or better accuracy than Transformer LLMs while cutting runtime and memory. The authors train 385M, 1B, and 7B models on a proprietary 6 TB (≈2T token) corpus and report at least 2× faster attention, up to 4× lower attention memory, and lower perplexity than comparable Transformer baselines on the evaluated benchmarks.

Problem Statement

Softmax attention gives good accuracy but costs O(n^2) time and memory with sequence length. Prior linear-attention variants either lose language modeling quality or fail to show real speed wins. This paper asks: can a linear-attention LLM match Transformer accuracy while improving runtime and memory in real training and inference?
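The cost difference comes down to the order of matrix multiplication. Once the softmax is dropped, (QKᵀ)V and Q(KᵀV) are the same product, but the second never builds an n × n matrix. The NumPy sketch below illustrates only that reordering, assuming a single head, no causal mask, no normalization, and illustrative sizes; it is not the model's actual attention layer.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 512, 64          # sequence length and head dimension (illustrative values)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Quadratic order: build the n x n score matrix first, O(n^2 * d) work.
out_quadratic = (Q @ K.T) @ V

# Linear order: build the d x d summary K^T V first, O(n * d^2) work.
out_linear = Q @ (K.T @ V)

# Both orders agree up to floating-point rounding.
max_abs_diff = np.max(np.abs(out_quadratic - out_linear))
```

With n much larger than d, the second ordering is the one whose cost grows linearly in sequence length.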

Main Contribution

TransNormerLLM: a linear-attention LLM design that adds LRPE-d positional encoding, gating (GLA/SGLU), and SimpleRMSNorm to improve quality and stability.

Lightning Attention: an IO-aware blocked algorithm for training linear attention that reduces runtime and memory (blocks inputs to SRAM).

System recipe for scale: model-parallel scheme, FSDP, activation checkpointing, AMP/BFloat16, enabling training up to 175B parameters.

Robust inference algorithm and recurrent K⊤V updates that stabilize numeric behavior and keep inference cost roughly constant with context length.
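The blocking idea behind Lightning Attention can be illustrated on the CPU: compute the masked attention output one tile at a time so the full n × n score matrix is never held at once. The sketch below is a NumPy illustration of that tiling pattern only. It is not the authors' Triton kernel and says nothing about real SRAM traffic; the decay mask `lam ** (i - j)`, the value `lam = 0.99`, the block size, and the function names are assumptions made for the example.

```python
import numpy as np

def masked_linear_attention_reference(Q, K, V, lam):
    """Reference: materialize the full n x n decay-masked score matrix."""
    n = Q.shape[0]
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]                      # i - j
    mask = np.where(diff >= 0, lam ** np.maximum(diff, 0), 0.0)
    return ((Q @ K.T) * mask) @ V

def masked_linear_attention_blocked(Q, K, V, lam, block=64):
    """Same result, computed one (block x block) tile at a time."""
    n = Q.shape[0]
    out = np.zeros((n, V.shape[1]))
    for i0 in range(0, n, block):
        i1 = min(i0 + block, n)
        rows = np.arange(i0, i1)
        # Key/value tiles that lie entirely in the future are skipped.
        for j0 in range(0, i1, block):
            j1 = min(j0 + block, n)
            cols = np.arange(j0, j1)
            diff = rows[:, None] - cols[None, :]
            tile_mask = np.where(diff >= 0, lam ** np.maximum(diff, 0), 0.0)
            out[i0:i1] += ((Q[i0:i1] @ K[j0:j1].T) * tile_mask) @ V[j0:j1]
    return out

rng = np.random.default_rng(0)
n, d = 200, 16                                              # illustrative sizes
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
ref = masked_linear_attention_reference(Q, K, V, lam=0.99)
blk = masked_linear_attention_blocked(Q, K, V, lam=0.99, block=64)
```

On a GPU, the point of this structure is that each tile fits in fast on-chip memory; the NumPy version only shows that the tiled loop reproduces the unblocked result.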

Key Findings

TransNormerLLM yields lower perplexity than Transformer baselines at small and medium scales.

Numbers: 385M model: PPL 4.77 vs Transformer 5.16; 1B model: PPL 3.729 vs Transformer 4.765

Lightning Attention runs noticeably faster and uses much less memory than a PyTorch NormAttention baseline during training.

Numbers: >=2× faster runtime; up to 4× lower attention memory at sequence length 8192

Simplified normalization and other design simplifications speed up the model.

Numbers: the simpler normalization and related changes are reported to give >20% acceleration

Inference time and memory grow much less with sequence length compared to standard Transformer behavior.

Numbers: Inference runtime and memory remain consistent as sequence length increases (figures show near-constant behavior)
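A back-of-the-envelope calculation shows why decoding memory can stay flat. A softmax Transformer keeps one key and one value vector per token in its cache, while a recurrent linear-attention decoder keeps a single fixed-size KᵀV state per head. The sketch below uses a hypothetical 7B-class shape (32 layers, 32 heads, head dimension 128, 2-byte values, batch size 1); these numbers are assumptions for illustration and are not taken from the paper.

```python
def kv_cache_bytes(seq_len, n_layers, n_heads, head_dim, bytes_per_value=2):
    """Softmax decoding: one K and one V vector per token, layer, and head."""
    return 2 * seq_len * n_layers * n_heads * head_dim * bytes_per_value

def recurrent_state_bytes(n_layers, n_heads, head_dim, bytes_per_value=2):
    """Linear-attention decoding: one head_dim x head_dim state per layer and head."""
    return n_layers * n_heads * head_dim * head_dim * bytes_per_value

# Hypothetical model shape, not the paper's configuration.
shape = dict(n_layers=32, n_heads=32, head_dim=128)

for seq_len in (2048, 8192, 32768):
    cache_mib = kv_cache_bytes(seq_len, **shape) / 2**20
    state_mib = recurrent_state_bytes(**shape) / 2**20
    print(f"n={seq_len:>6}: KV cache {cache_mib:8.0f} MiB, recurrent state {state_mib:5.0f} MiB")
```

Under these assumptions the cache grows in proportion to sequence length while the recurrent state does not depend on it at all.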

Results

Perplexity (385M)

Value: TransNormerLLM PPL 4.77 vs Transformer PPL 5.16

Baseline: Transformer 385M

Perplexity (1B)

Value: TransNormerLLM PPL 3.729 vs Transformer PPL 4.765

Baseline: Transformer 1B

Lightning Attention speed

Value: >=2× faster runtime for forward+backward

Baseline: PyTorch NormAttention implementation

Lightning Attention memory

Value: Up to 4× lower attention memory at sequence length 8192

Baseline: PyTorch NormAttention implementation

Inference scaling

Value: Inference time and memory roughly constant across growing sequence lengths

Baseline: Transformer behavior (grows with sequence length)

Benchmarks (7B aggregated)

Value: Competitive accuracy versus top open-source 7B models (e.g., LLaMA, Baichuan)

Baseline: Multiple open-source 7B models

Who Should Care

What To Try In 7 Days

Run the released TransNormerLLM code to profile Lightning Attention vs your attention kernel on a dev GPU.

Swap in SRMSNorm and SGLU in a small model to measure speed and validation loss differences.

Benchmark inference latency and memory with longer input contexts to see production gains.
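For the benchmarking step, a small harness that sweeps sequence length and records time and peak memory is usually enough to see the scaling trend. The sketch below uses only the Python standard library and two stand-in workloads, so it measures host-side Python allocations rather than GPU kernels; `profile_over_lengths` and both workloads are hypothetical names introduced for this example. Note that `tracemalloc` adds overhead, so the timings are only useful relative to each other.

```python
import time
import tracemalloc

def profile_over_lengths(fn, lengths, repeats=3):
    """Return {n: (best_seconds, peak_bytes)} for fn(n) at each length n."""
    results = {}
    for n in lengths:
        best = float("inf")
        peak = 0
        for _ in range(repeats):
            tracemalloc.start()
            start = time.perf_counter()
            fn(n)
            elapsed = time.perf_counter() - start
            _, run_peak = tracemalloc.get_traced_memory()
            tracemalloc.stop()
            best = min(best, elapsed)
            peak = max(peak, run_peak)
        results[n] = (best, peak)
    return results

# Stand-in workloads (hypothetical): one allocates O(n^2) memory, the other O(n).
def quadratic_workload(n):
    return [[0.0] * n for _ in range(n)]

def linear_workload(n):
    return [0.0] * n

lengths = [128, 256, 512]
quad = profile_over_lengths(quadratic_workload, lengths)
lin = profile_over_lengths(linear_workload, lengths)

for n in lengths:
    print(n, "quadratic peak bytes:", quad[n][1], "linear peak bytes:", lin[n][1])
```

To profile a real model, replace the stand-in workloads with a function that runs one forward pass at context length n, and swap in your framework's own GPU timing and memory counters.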

Optimization Features

Token Efficiency

  • Trained with long context length (8192) and supports longer contexts in training and inference

Infra Optimization

  • A100 80G clusters (NVLink) tested
  • Triton and PyTorch implementations tuned for speed

Model Optimization

  • Linear attention (NormAttention form)
  • LRPE-d positional encoding
  • Gated Linear Attention (GLA) and Simple GLU (SGLU)
  • SRMSNorm (Simple RMSNorm)
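Two of the components listed above are simple enough to sketch directly. The code below assumes SRMSNorm divides each vector by its root-mean-square with no learned scale, and that SGLU multiplies two linear projections elementwise with no activation before the output projection. The `eps` term, the weight names, and the layer sizes are assumptions for illustration, not the released implementation.

```python
import numpy as np

def srms_norm(x, eps=1e-8):
    """Parameter-free RMS norm: divide each row by its root-mean-square."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True))
    return x / (rms + eps)   # eps is an assumed safeguard against division by zero

def sglu(x, w_v, w_u, w_o):
    """Gated linear unit with no activation function on either branch."""
    return ((x @ w_v) * (x @ w_u)) @ w_o

rng = np.random.default_rng(0)
d_model, d_hidden = 8, 32                        # illustrative sizes
x = rng.standard_normal((4, d_model))
w_v = rng.standard_normal((d_model, d_hidden)) / np.sqrt(d_model)
w_u = rng.standard_normal((d_model, d_hidden)) / np.sqrt(d_model)
w_o = rng.standard_normal((d_hidden, d_model)) / np.sqrt(d_hidden)

normed = srms_norm(x)
y = sglu(normed, w_v, w_u, w_o)
row_rms = np.sqrt(np.mean(normed * normed, axis=-1))   # close to 1.0 for every row
```

Because neither function has learned scale parameters or an activation, both are cheap to swap into a small model for a speed and validation-loss comparison.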

System Optimization

  • Model-parallel split for GLA and SGLU (Megatron-style)
  • Triton kernels and SRAM blocking for Lightning Attention
  • IO-aware blocking to move work to on-chip SRAM

Training Optimization

  • Lightning Attention (IO-aware blocked attention)
  • Fully Sharded Data Parallel (FSDP)
  • Activation checkpointing
  • Automatic mixed precision / BFloat16

Inference Optimization

  • Robust recurrent K⊤V inference algorithm
  • Recurrent K⊤V updates (constant-time per token)
  • LRPE-d compatible with linear RNN-style inference
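The recurrent update can be sketched as a running d × d state that is decayed and then incremented by the outer product of each new key and value. The NumPy example below checks that this token-by-token form matches a parallel form with a `lam ** (i - j)` causal decay mask. It assumes a single head, a fixed hypothetical decay `lam = 0.95`, and omits positional encoding and normalization, so it is a simplified illustration rather than the paper's full inference algorithm.

```python
import numpy as np

def decode_recurrent(Q, K, V, lam):
    """Token-by-token decoding with a fixed-size d_k x d_v state."""
    n, d_k = Q.shape
    d_v = V.shape[1]
    kv = np.zeros((d_k, d_v))
    out = np.empty((n, d_v))
    for t in range(n):
        kv = lam * kv + np.outer(K[t], V[t])    # decay old state, add the new token
        out[t] = Q[t] @ kv                      # per-token cost is independent of t
    return out

def decode_parallel(Q, K, V, lam):
    """Reference: full n x n scores with a lam**(i-j) causal decay mask."""
    n = Q.shape[0]
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]
    mask = np.where(diff >= 0, lam ** np.maximum(diff, 0), 0.0)
    return ((Q @ K.T) * mask) @ V

rng = np.random.default_rng(0)
n, d = 64, 8                                    # illustrative sizes
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out_rec = decode_recurrent(Q, K, V, lam=0.95)
out_par = decode_parallel(Q, K, V, lam=0.95)
```

Multiplying the state by `lam` at every step keeps all intermediate values bounded, whereas a formulation that needs explicit powers such as `lam ** (-t)` can overflow for long sequences.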

Reproducibility

Code Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Primary training corpus is proprietary (6 TB cleaned, ≈2T tokens) and is not released; replication may be hard.
  • Benchmarks and ablations are reported mainly up to 7B for accuracy comparisons; large-scale (175B) accuracy claims are not fully demonstrated.
  • Numerical instability arises if the decay λ is made learnable; authors fix λ to avoid NaNs.

When Not To Use

  • When you require models trained on open, standard corpora for strict comparability.
  • When your application depends on properties unique to softmax attention or on existing pretrained Transformer checkpoints.

Failure Modes

  • Making the decay λ learnable can cause training NaNs and numerical instability.
  • Some activation choices (e.g., 1+elu) caused NaNs in the authors' 7B runs, so activation changes can break stability at scale.
  • Lightning Attention requires careful IO-block tuning and hardware support (on-chip SRAM behavior matters).

Core Entities

Models

  • TransNormerLLM
  • TransNormer
  • Transformer
  • RWKV
  • Pythia
  • OPT
  • LLaMA
  • Falcon
  • Baichuan
  • ChatGLM
  • GPT-Neo
  • GPT-J
  • MPT

Metrics

  • Perplexity (PPL)
  • Validation Loss
  • Tokens/sec (throughput)
  • Inference runtime (ms)
  • Memory footprint (GB)
  • Accuracy

Datasets

  • Proprietary 6 TB cleaned corpus (~2T tokens)
  • MMLU
  • CMMLU
  • C-Eval
  • BoolQ
  • PIQA
  • HellaSwag
  • WinoGrande
  • ARC-e
  • ARC-c
  • OpenBookQA

Benchmarks

  • MMLU
  • CMMLU
  • C-Eval
  • Commonsense Reasoning (BoolQ, PIQA, HellaSwag, WinoGrande, ARC, OBQA)

Context Entities

Models

  • TransNormerLLM variants: 385M, 1B, 3B, 7B, 13B, 65B, 175B

Datasets

  • Corpus categories: Academic Writings, Books, Code, Encyclopedia, Filtered Webpages, Others