Compress ViT with GPU-friendly 2:4 sparsity + quantization to cut size/FLOPs and speed up real GPU inference

May 18, 2023 · 7 min

Overview

Production Readiness

0.7

Novelty Score

0.6

Cost Impact Score

0.8

Citation Count

0

Authors

Chong Yu, Tao Chen, Zhongxue Gan, Jiayuan Fan

Links

Abstract / PDF

Why It Matters For Business

GPUSQ-ViT cuts model size by 6.4–12.7× and FLOPs by roughly 30–62× while delivering measured GPU speedups; this reduces cloud/GPU costs, eases edge deployment, and preserves accuracy on standard vision tasks.

Summary TLDR

GPUSQ-ViT applies GPU-native 2:4 fine-grained structured pruning plus quantization-aware training (INT8/INT4) with knowledge distillation to Vision Transformers. On ImageNet/COCO/ADE20K it reduces model size by 6.4–12.7× and FLOPs by ~30–62× with minimal accuracy loss, and yields 1.39–1.79× lower latency and 2.11–3.43× higher throughput on NVIDIA A100 and AGX Orin when using TensorRT sparse kernels.

Problem Statement

Vision Transformers are large and rely heavily on matrix multiplications (GEMMs). Common pruning/quantization methods reduce FLOPs or size but often produce unstructured sparsity or exotic bit-widths that give little real GPU speedup. The paper targets practical, GPU-accelerated compression that matches NVIDIA Tensor Core 2:4 sparse support and common low-precision formats.
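To make the 2:4 pattern concrete, here is a minimal sketch of magnitude-based 2:4 pruning on a single weight row: in every consecutive group of four values, the two with the smallest magnitude are zeroed. The function name `prune_2_4` and the magnitude criterion are assumptions for illustration; the summary does not state how the paper selects which weights to keep.

```python
def prune_2_4(row):
    """Zero the two smallest-magnitude entries in each consecutive group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest magnitudes; ties go to the earlier index.
        keep = sorted(range(4), key=lambda i: (-abs(group[i]), i))[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return pruned


# Hypothetical weight row with two groups of four.
weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.0]
sparse = prune_2_4(weights)
```

Because exactly half of each group survives, the kept values can be stored compactly with a small position index, which is what lets sparse Tensor Core kernels skip the zeros.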

Main Contribution

Design of GPUSQ-ViT: combine 2:4 GPU-supported structured pruning with sparse-aware QAT and knowledge distillation.

A sparse-distillation-aware QAT that weights feature distillation by layer importance to reduce quantization error impact.

Demonstrated broad applicability: classification, detection, segmentation and unsupervised distillation, with real A100 and Orin speedups.
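The idea of weighting feature distillation by layer importance can be sketched as a weighted sum of per-layer feature errors between the compressed student and the dense teacher. This is only an illustration: the mean-squared-error form, the function names, and the example weights are assumptions, since the summary gives neither the paper's exact loss nor how layer importance is computed.

```python
def mse(a, b):
    """Mean squared error between two equal-length feature vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def weighted_feature_kd(student_feats, teacher_feats, layer_weights):
    """Sum of per-layer feature MSE, each scaled by a per-layer weight."""
    if not (len(student_feats) == len(teacher_feats) == len(layer_weights)):
        raise ValueError("need one weight per layer")
    return sum(
        w * mse(s, t)
        for s, t, w in zip(student_feats, teacher_feats, layer_weights)
    )


# Hypothetical features from two layers, flattened to plain lists.
student = [[1.0, 2.0], [0.0, 0.0]]
teacher = [[1.0, 4.0], [1.0, 1.0]]
weights = [0.5, 2.0]  # assumed importance: the second layer counts more
loss = weighted_feature_kd(student, teacher, weights)
```

In this toy case the second layer's mismatch dominates the loss, so training would prioritise closing that gap first.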

Key Findings

Model size cut using GPU-friendly compression

Numbers: 6.4–12.7× reduction in Params (Tables 1, 3, 4)

Compute (FLOPs) reduction from pruning+quantization

Numbers: 30.3–62× reduction in FLOPs on evaluated models (Tables 1, 3, 4)

Real deployment speedup on NVIDIA A100 GPU

Numbers: 1.39–1.79× lower latency and 3.22–3.43× higher throughput for INT4 (Table 2)

Real deployment speedup on AGX Orin edge device

Numbers: 1.57–1.69× lower latency and 2.11–2.51× higher throughput for INT4 (Table 2)

Accuracy after aggressive compression is preserved

Numbers: Top-1 drops often ≤1% or within ±0.5% on ImageNet / mAP / mIoU (Tables 1–4)

Results

Params reduction

Value: 6.4–12.7× smaller

Baseline: Dense FP32 models

FLOPs reduction

Value: 30.3–62× smaller

Baseline: Dense FP32 models

Accuracy

Value: typically within ±1%

Baseline: Dense FP32 models

A100 latency (batch=1) INT4 vs FP32

Value: 1.39–1.79× lower latency (higher FPS)

Baseline: FP32

A100 throughput (large batch) INT4 vs FP32

Value: 3.22–3.43× higher throughput

Baseline: FP32

AGX Orin latency & throughput INT4 vs FP32

Value: 1.57–1.69× lower latency, 2.11–2.51× higher throughput

Baseline: FP32
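A back-of-envelope model helps show where a size reduction of this magnitude can come from. The sketch below is not the paper's accounting: it assumes every weight is both 2:4 pruned and quantized, and that each kept value carries a 2-bit position index as in NVIDIA's 2:4 compressed storage. The function name and parameters are hypothetical.

```python
def params_compression(value_bits, dense_bits=32, keep=2, group=4, index_bits=2):
    """Ratio of dense storage to 2:4 sparse, quantized storage for one group."""
    dense = group * dense_bits                # four FP32 weights
    sparse = keep * (value_bits + index_bits)  # two kept values plus their indices
    return dense / sparse


int8_ratio = params_compression(8)                  # 128 / 20
int4_ratio = params_compression(4)                  # 128 / 12
ideal_int4 = params_compression(4, index_bits=0)    # 128 / 8, ignoring indices
```

Under these assumptions the INT8 case gives 6.4× and the INT4 case about 10.7×, with 16× as the index-free ceiling for INT4. Reported figures for a real model need not follow this model exactly, since they depend on which layers are compressed and how storage is counted.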

Who Should Care

What To Try In 7 Days

Run a baseline ViT on your A100/Orin and measure FP32 latency/throughput with TensorRT.

Apply 2:4 structured pruning to linear layers (Q/K/V, projections, FFN) using the available training data or a small fine-tuning set.

Fine-tune with quantization-aware training to INT8/INT4 using feature-based KD from the original model, and test accuracy trade-offs on a validation split.
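For the measurement step, a small timing harness along these lines can keep the FP32 baseline and the compressed model comparable. It is a sketch under assumptions: `infer` is a hypothetical callable standing in for whatever runs your engine, and `dummy_infer` exists only so the example executes without a GPU.

```python
import time


def benchmark(infer, batch, warmup=3, iters=20):
    """Return (mean seconds per call, samples per second) for infer(batch)."""
    for _ in range(warmup):
        infer(batch)  # warm-up calls are excluded from timing
    start = time.perf_counter()
    for _ in range(iters):
        infer(batch)
    # On a GPU, synchronize the device before reading the clock here.
    per_call = (time.perf_counter() - start) / iters
    return per_call, len(batch) / per_call


def dummy_infer(batch):
    # Hypothetical stand-in for a real engine call.
    return [sum(x * x for x in sample) for sample in batch]


sample = [0.5] * 256
latency_s, _ = benchmark(dummy_infer, [sample])           # batch = 1
_, throughput = benchmark(dummy_infer, [sample] * 32)     # larger batch
```

Measuring latency at batch size 1 and throughput at a large batch mirrors how the results above are split, and running the same harness before and after compression gives the speedup ratio directly.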

Optimization Features

Infra Optimization

  • Measured on NVIDIA A100 and Jetson AGX Orin

Model Optimization

  • 2:4 fine-grained structured pruning
  • INT8 and INT4 quantized weights

System Optimization

  • Match compression pattern to GPU hardware (2:4)

Training Optimization

  • Quantization Aware Training (QAT)
  • Knowledge Distillation (hard label, soft logits, feature-based)
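The core operation inside quantization-aware training is a quantize-then-dequantize step in the forward pass, so the network learns under the rounding error it will see at inference. The sketch below shows one common variant, symmetric per-tensor quantization; whether the paper uses this exact scheme is an assumption, and `fake_quantize` is a hypothetical name. The backward pass, which typically uses a straight-through estimator, is not shown.

```python
def fake_quantize(values, bits):
    """Symmetric per-tensor quantize-dequantize to signed `bits`-bit integers."""
    qmax = 2 ** (bits - 1) - 1  # 127 for INT8, 7 for INT4
    max_abs = max(abs(v) for v in values)
    if max_abs == 0:
        return list(values)  # nothing to scale
    scale = max_abs / qmax
    out = []
    for v in values:
        q = max(-qmax, min(qmax, round(v / scale)))  # integer code, clamped
        out.append(q * scale)                        # back to float
    return out


# Hypothetical weights; INT4 has far fewer levels than INT8.
w = [1.0, -0.4, 0.26, 0.0]
w_int8 = fake_quantize(w, 8)
w_int4 = fake_quantize(w, 4)
```

With only 15 usable levels, the INT4 output deviates from the originals more than the INT8 output, which is one reason INT4 leans more heavily on distillation to recover accuracy.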

Inference Optimization

  • Sparse GEMM on NVIDIA Tensor Cores
  • TensorRT sparse kernels
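Sparse kernels only apply when the weights actually satisfy the 2:4 constraint, so a quick sanity check before export can save debugging time. This is a minimal sketch with a hypothetical helper name, operating on one flattened weight row; it is not part of any runtime's API.

```python
def is_2_4_sparse(row):
    """True if every consecutive group of four has at most two nonzero values."""
    if len(row) % 4 != 0:
        return False
    return all(
        sum(1 for v in row[i:i + 4] if v != 0) <= 2
        for i in range(0, len(row), 4)
    )


ok = is_2_4_sparse([1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0])
too_dense = is_2_4_sparse([1.0, 2.0, 3.0, 0.0])
```

A row that fails this check would have to run on dense kernels, in which case the pruning saves nothing at inference time.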

Reproducibility

Data Urls

  • ImageNet-1K (public)
  • COCO (public)
  • ADE20K (public)

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Requires hardware and a runtime that support 2:4 structured sparsity (NVIDIA Ampere or newer and TensorRT).
  • Needs access to training/fine-tuning data for 2:4 pruning and QAT; PTQ-only scenarios may not reach the same accuracy.
  • Tied to the 2:4 pattern; if the hardware supports different sparse patterns, the method must be adapted.

When Not To Use

  • Your target hardware lacks 2:4 sparse Tensor Core support.
  • You cannot fine-tune with representative data (no access to training or calibration set).
  • You require unstructured sparsity or highly irregular pruning patterns for accuracy.

Failure Modes

  • INT4 models may incur larger accuracy drops if distillation or layer-weighting is disabled (ablation shows sensitivity).
  • If runtime or drivers lack optimized sparse kernels, theoretical FLOPs reduction won't translate to speedup.
  • Using a dense FP32 teacher for QAT may reduce distillation effectiveness when formats differ.

Core Entities

Models

  • DeiT
  • Swin Transformer
  • Mask R-CNN
  • DETR
  • Deformable-DETR
  • UPerNet

Metrics

  • Params
  • FLOPs
  • Top-1 Acc
  • Top-5 Acc
  • Latency (FPS)
  • Throughput (FPS)
  • bbox mAP
  • segm mAP
  • Mean IoU
  • Accuracy

Datasets

  • ImageNet-1K
  • COCO
  • ADE20K

Benchmarks

  • Image classification
  • Object detection
  • Semantic segmentation