INT4 (4-bit) gives big latency wins for encoder models with little accuracy loss, but breaks decoder-only generators; optimized INT4 kernels

January 27, 2023 · 8 min

Overview

Production Readiness

0.7

Novelty Score

0.4

Cost Impact Score

0.8

Citation Count

6

Authors

Xiaoxia Wu, Cheng Li, Reza Yazdani Aminabadi, Zhewei Yao, Yuxiong He

Why It Matters For Business

On Ampere GPUs, INT4 computation can sharply reduce latency and cost for encoder-based workloads (search, classification, embedding). But it is risky to use for autoregressive generation (chatbots, text generation) until activation-quantization problems are solved.

Summary TLDR

The paper shows 4-bit weight+activation (W4A4) quantization can keep accuracy for encoder (BERT) and encoder-decoder (BART) models while enabling large latency wins with optimized GPU kernels. Decoder-only models (GPT) suffer large quality loss from 4-bit activation quantization. The authors release a tuned INT4 encoder inference pipeline (CUTLASS-based) that achieves up to 8.5× latency speedup over FP16 and improves prior INT8 performance by up to 1.7×.

Problem Statement

Can full INT4 computation (weights and activations) be used for transformer inference to double hardware throughput and reduce latency without unacceptable quality loss? And how can fast, end-to-end INT4 inference be implemented on GPUs?

Main Contribution

System: an end-to-end, highly optimized INT4 encoder inference pipeline (CUTLASS kernels, fused quant/dequant, FlashAttention, CUDA graph).

Empirical: a broad QAT+KD study of W4A4 across model types, showing that encoder and encoder-decoder models tolerate W4A4 while decoder-only models do not.

Analysis and composability: root-cause analysis of decoder failures (activation range, attention, the role of layer normalization) and compatibility tests combining INT4 with pruning and layer reduction.

Key Findings

Encoder models (BERT) keep accuracy under W4A4 QAT+KD.

Numbers: BERT-base MNLI 84.20 (FP32) → 84.31 (W4A4 symmetric)

Encoder-decoder models (BART) show only small quality drops under W4A4.

Numbers: BART-base RLsum 42.87 (FP32) → 41.92 (W4A4 symmetric), drop ≤1 point

Decoder-only models (GPT2) degrade substantially when activations are quantized to INT4.

Numbers: GPT2-base PTB PPL 19.31 (FP32) → 22.17 (W4A4 symmetric); early-token positional PPL gap >100 on generation

End-to-end INT4 encoder inference can give large latency and throughput gains on Ampere GPUs.

Numbers: Up to 8.5× latency speedup and up to 3× throughput over HuggingFace FP16; up to 1.7× over FasterTransformer INT8

INT4 GEMM speedup varies by layer shape: some GEMMs run nearly 2× faster than INT8, while others see smaller gains.

Numbers: At bs×seq=12288 (BERT-large): MLP output GEMM 1.96× vs INT8, attention output GEMM 1.46×

INT4 composes well with moderate pruning and layer reduction for encoder models with limited accuracy loss.

Numbers: W4A4 + 50% Pair-(2:4) sparsity: ~0.5 GLUE points drop; 75% sparsity causes 0.79/1.6 drop on MNLI m/mm
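
To make the 2:4 pattern in the last finding concrete, here is a minimal sketch of generic N:M magnitude pruning: in every group of 4 consecutive weights, only the 2 with the largest magnitude are kept. This is the standard semi-structured scheme, not necessarily the authors' exact Pair-(2:4) procedure; the function name `prune_n_of_m` and the sample weights are made up for illustration.

```python
def prune_n_of_m(row, n=2, m=4):
    """Zero all but the n largest-magnitude weights in each group of m."""
    if len(row) % m:
        raise ValueError("row length must be a multiple of m")
    pruned = []
    for start in range(0, len(row), m):
        group = row[start:start + m]
        # Indices of the n largest-magnitude entries in this group.
        keep = sorted(range(m), key=lambda i: abs(group[i]), reverse=True)[:n]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return pruned


# Hypothetical weight row: two groups of four.
weights = [0.1, -0.9, 0.3, 0.05, 1.0, 2.0, -3.0, 0.5]
print(prune_n_of_m(weights))
```

With `n=2, m=4` exactly half of the weights are zeroed, which corresponds to the 50% sparsity level quoted above.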

Results

BERT-base accuracy (MNLI)

Value: 84.20 (FP32) → 84.31 (W4A4 symmetric)

Baseline: FP32

BART-base Rouge Lsum

Value: 42.87 (FP32) → 41.92 (W4A4 symmetric)

Baseline: FP32

GPT2-base perplexity (PTB)

Value: 19.31 (FP32) → 22.17 (W4A4 symmetric)

Baseline: FP32

E2E latency speedup (encoder)

Value: Up to 8.5× (latency) and up to 3× (throughput)

Baseline: HuggingFace FP16

Improvement over FasterTransformer INT8 (BERT)

Value: Up to 1.7× faster

Baseline: FasterTransformer INT8

Who Should Care

What To Try In 7 Days

Benchmark W4A4 encoder inference on a representative bs×seq using the authors' INT4 pipeline or CUTLASS kernels.

If using BERT/BART for classification or summarization, run QAT+KD with W4A4 on a held-out dataset and measure quality vs FP16.

Do not quantize GPT activations to INT4 yet; test weight-only 4-bit quantization (W4) or mixed W4A8 as a safer first step.

Optimization Features

Token Efficiency

  • Token-wise dynamic activation quantization (min/max per token)
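
As a rough illustration of token-wise dynamic quantization, the sketch below computes a separate min/max range for each token (row) at inference time and maps it onto the 16 levels of a 4-bit code. It assumes an asymmetric scheme with a floating-point offset, uses plain Python lists to stay dependency-free, and the function names and sample values are hypothetical; the paper's kernels may differ in rounding and layout.

```python
def quantize_token(row, bits=4):
    """Asymmetric min/max quantization of one token's activation vector."""
    levels = (1 << bits) - 1              # 15 code steps for 4 bits
    lo, hi = min(row), max(row)
    # A constant row has no range; fall back to scale 1.0 so every code is 0.
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [min(levels, max(0, round((v - lo) / scale))) for v in row]
    return codes, scale, lo


def dequantize_token(codes, scale, lo):
    """Map integer codes back to approximate floating-point values."""
    return [c * scale + lo for c in codes]


def quantize_activations(acts, bits=4):
    """Quantize each token (row) independently, with its own scale and offset."""
    return [quantize_token(row, bits) for row in acts]


# Two tokens with very different ranges; values are made up for illustration.
acts = [[-1.0, 0.0, 0.5, 2.0], [-40.0, 10.0, 25.0, 80.0]]
for codes, scale, lo in quantize_activations(acts):
    print(codes, dequantize_token(codes, scale, lo))
```

Because each token gets its own scale, one token with a wide range does not reduce the resolution available to the others.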

Infra Optimization

  • Targeted to NVIDIA Ampere GPUs (A6000); relies on CUTLASS INT4 support

Model Optimization

  • W4A4 quantization (weights+activations) via QAT+KD
  • Group-wise row quantization for weights
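
One way to picture group-wise row quantization: each weight row is split into fixed-size groups, and every group gets its own symmetric scale. The sketch below assumes a symmetric signed range of [-7, 7] and a group size of 4 purely for illustration; both choices and the function names are hypothetical, not taken from the paper.

```python
def quantize_row_groupwise(row, group_size=4, bits=4):
    """Symmetric quantization of one weight row with one scale per group."""
    qmax = (1 << (bits - 1)) - 1          # 7 for 4 bits
    codes, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        absmax = max(abs(v) for v in group)
        # An all-zero group has no range; use scale 1.0 so codes stay 0.
        scale = absmax / qmax if absmax > 0 else 1.0
        scales.append(scale)
        codes.extend(max(-qmax, min(qmax, round(v / scale))) for v in group)
    return codes, scales


def dequantize_row_groupwise(codes, scales, group_size=4):
    """Rescale each code with the scale of the group it belongs to."""
    return [c * scales[i // group_size] for i, c in enumerate(codes)]


# Hypothetical row: a group of small weights followed by a group of large ones.
row = [0.1, -0.2, 0.05, 0.0, 8.0, -4.0, 2.0, 1.0]
codes, scales = quantize_row_groupwise(row, group_size=4)
print(codes, scales)
print(dequantize_row_groupwise(codes, scales, group_size=4))
```

With a single scale for the whole row (`group_size=8` here), the small weights in the first half would all round to zero; per-group scales keep them representable.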

System Optimization

  • Pre-tuned GEMM schedules with CUTLASS profiler
  • Packing INT4 into INT8 tensors for current PyTorch support
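
The packing step can be sketched as storing two signed 4-bit values in each byte, since an INT8 container is the smallest integer type the framework exposes. The nibble order below (first value in the low nibble) is an assumption for illustration; the layout expected by the actual CUTLASS kernels may differ, and `pack_int4` / `unpack_int4` are hypothetical helper names.

```python
def pack_int4(values):
    """Pack signed 4-bit integers (-8..7) two per byte, low nibble first."""
    if len(values) % 2:
        raise ValueError("need an even number of values")
    if any(not -8 <= v <= 7 for v in values):
        raise ValueError("value out of INT4 range")
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        # Masking with 0xF keeps the two's-complement 4-bit pattern.
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)


def unpack_int4(packed):
    """Inverse of pack_int4: recover the signed 4-bit values."""
    def sign_extend(nibble):
        return nibble - 16 if nibble >= 8 else nibble

    values = []
    for byte in packed:
        values.append(sign_extend(byte & 0xF))
        values.append(sign_extend(byte >> 4))
    return values


packed = pack_int4([-8, 7, 0, -1])
print(packed.hex(), unpack_int4(packed))
```

Packing halves the memory footprint relative to storing each 4-bit code in its own INT8 slot.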

Training Optimization

  • Quantization-aware training with knowledge distillation
  • Exhaustive hyperparameter search per model
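
For readers unfamiliar with the distillation half of QAT+KD, the sketch below shows a common form of the loss: the quantized student is trained on the true label and, at the same time, pushed toward the full-precision teacher's softened output distribution. This is the generic logit-distillation formulation, assumed here for illustration; the paper's exact loss, `alpha`, and `temperature` are not given in this summary.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=2.0):
    """alpha * cross-entropy(label) + (1 - alpha) * T^2 * KL(teacher || student)."""
    ce = -math.log(softmax(student_logits)[label])
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(t * math.log(t / s) for t, s in zip(p_teacher, p_student) if t > 0)
    return alpha * ce + (1 - alpha) * temperature ** 2 * kl


# Hypothetical logits: the quantized student drifts slightly from the teacher.
teacher = [3.0, 0.5, -1.0]
student = [2.4, 0.9, -0.8]
print(kd_loss(student, teacher, label=0))
```

In quantization-aware training, the student's forward pass would run with simulated low-precision weights and activations, so this loss directly penalizes the error introduced by quantization.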

Inference Optimization

  • Custom CUTLASS INT4 GEMM kernels
  • Fused quantize/dequantize kernels to avoid extra memory traffic
  • FlashAttention integration for FP16 attention
  • CUDA graph to reduce kernel launch overhead
  • Per-GEMM tunable quantization strategy (modular enable/disable)

Reproducibility

Data Urls

  • Public datasets (MNLI, QQP, CNNDailyMail, XSum, PTB, Wikitext-2/103)

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • W4A4 fails or degrades decoding/generation (GPT) due to activation quantization sensitivity.
  • Results target NVIDIA Ampere GPUs and CUTLASS INT4; other hardware may not match gains.
  • No end-to-end INT4 performance numbers reported for decoder-only models (study focuses on encoders).
  • Approach builds on existing QAT+KD techniques; novelty is in engineering and end-to-end evaluation.

When Not To Use

  • Autoregressive text generation (GPT-style) where generation quality matters.
  • Non-Ampere GPUs or hardware without efficient INT4 support.
  • Very small bs×seq shapes where quant/dequant overhead dominates and E2E speedups disappear.
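
The last point can be reasoned about with a simple Amdahl-style estimate: only the GEMM share of latency gets faster, while quantize/dequantize work adds a cost of its own. The sketch below is a back-of-the-envelope model, not a measurement; the latency fractions are hypothetical, and only the 1.96× GEMM speedup is borrowed from the findings above.

```python
def end_to_end_speedup(gemm_fraction, gemm_speedup, overhead_fraction):
    """Amdahl-style estimate of end-to-end speedup.

    gemm_fraction:     share of baseline latency spent in GEMMs that get faster
    gemm_speedup:      speedup of those GEMMs (e.g. 1.96 for INT4 vs INT8)
    overhead_fraction: added quant/dequant cost, as a share of baseline latency
    """
    new_time = (1.0 - gemm_fraction) + gemm_fraction / gemm_speedup + overhead_fraction
    return 1.0 / new_time


# Hypothetical large shape: GEMMs dominate, overhead is small.
print(end_to_end_speedup(gemm_fraction=0.9, gemm_speedup=1.96, overhead_fraction=0.02))
# Hypothetical small shape: GEMMs are a minor share, overhead is comparable.
print(end_to_end_speedup(gemm_fraction=0.3, gemm_speedup=1.96, overhead_fraction=0.2))
```

A result below 1.0 means the quantized path is slower than the baseline, which is the regime to rule out by benchmarking on representative shapes.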

Failure Modes

  • Activation quantization causes large early-token perplexity spikes in GPT (positional PPL gap >100 on early tokens).
  • Pretrained models can have wider activation ranges than models trained from scratch, which makes them harder to quantize.
  • Layer-normalization variant (Pre-LN vs Post-LN) is not a single fix for decoder failures; sensitivity persists.

Core Entities

Models

  • BERT-base
  • BERT-large
  • BART-base
  • BART-large
  • GPT2-base
  • GPT2-medium

Metrics

  • Accuracy
  • F1
  • Rouge Lsum
  • Perplexity
  • Latency speedup

Datasets

  • MNLI
  • QQP
  • CNNDailyMail
  • XSum
  • PTB
  • Wikitext-2
  • Wikitext-103
  • GLUE

Benchmarks

  • GLUE
  • Summarization (CNNDailyMail, XSum)
  • Causal language modeling (PTB, Wikitext)