GTSP: prune tokens, heads, layers, and weights to cut Graph Transformer compute with little or no accuracy loss

December 9, 2023 (7 min read)

Overview

Production Readiness: 0.6

Novelty Score: 0.5

Cost Impact Score: 0.7

Citation Count: 0

Authors

Chuang Liu, Yibing Zhan, Xueqi Ma, Liang Ding, Dapeng Tao, Jia Wu, Wenbin Hu, Bo Du

Why It Matters For Business

GTSP can cut Graph Transformer FLOPs by roughly 30% to nearly 50% and parameters by up to about 46% on the evaluated benchmarks, with small accuracy losses and occasional gains, which could make training and deployment cheaper on constrained hardware.

Summary TLDR

The paper introduces GTSP, a practical mask-based framework that sparsifies Graph Transformers along four axes: input tokens (nodes), attention heads, layers, and weights. GTSP uses learnable masks with Gumbel-softmax top-k selection for tokens, gradient-based importance scores for heads, stochastic layer dropping for layers, and gradual magnitude pruning plus regrowth for weights. On GraphTrans, Graphormer, and GraphGPS across NCI1, OGBG-HIV, and OGBG-Molpcba, GTSP cuts FLOPs by ~30–47% while keeping accuracy similar and sometimes improving it (e.g., ROC-AUC 0.7633→0.7773 on OGBG-HIV at 50% weight sparsity). The paper studies each pruning axis separately and notes that joint pruning and per-graph ratio tuning remain open.

Problem Statement

Graph Transformers match or exceed GNNs but cost much more compute and memory because of multi-head self-attention. Existing GNN pruning methods target edges or channels and node-classification tasks; they don't directly transfer to Graph Transformers used for graph-level tasks. The paper asks: can we safely sparsify Graph Transformers across tokens, heads, layers, and weights to cut compute while keeping performance?

Main Contribution

A systematic analysis of redundancy in Graph Transformers across four components: input tokens, attention heads, layers, and weights.

GTSP: a unified, mask-based sparsification framework with differentiable token selection, head importance scoring, stochastic layer dropping, and gradual weight pruning with regrowth.

Extensive experiments on three GT architectures (GraphTrans, Graphormer, GraphGPS) and three datasets showing large FLOPs/parameter savings with small accuracy impact and occasional accuracy gains.

Key Findings

Weight pruning (50% sparsity) can increase AUC on OGBG-HIV while cutting compute.

Numbers: ROC-AUC 0.7633 → 0.7773 (+0.014); FLOPs −30.2%

Token pruning gives the largest FLOPs cuts but risks accuracy loss on small graphs.

Numbers: FLOPs −47.4% on NCI1; accuracy 83.71 → 82.77 (−0.94 pts)

Many attention heads and some layers are redundant; pruning them often keeps accuracy.

Numbers: Layer halving: params −32.9% to −45.6%, FLOPs −46.6% to −48.8%, accuracy drop 0.2–3.3%

Results

OGBG-HIV ROC-AUC

Value: Baseline 0.7633 → GTSP-WP (50%) 0.7773

Baseline: 0.7633

NCI1 FLOPs reduction (token pruning)

Value: FLOPs −47.4%

Baseline: 0% FLOPs saving

Parameters and FLOPs when halving layers

Value: Params −32.9% to −45.6%; FLOPs −46.6% to −48.8%

Baseline: full-depth models

What To Try In 7 Days

Run 40–50% weight magnitude pruning with regrowth on your Graph Transformer; compare ROC-AUC and FLOPs.

If graphs are large, test token (node) pruning with top-k selection and validate accuracy on a holdout set.

Try stochastic layer dropping to halve depth and measure inference latency and accuracy trade-offs.
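Stochastic layer dropping can be sketched in plain Python as below, assuming residual layers of the form x ← x + f(x) and a LayerDrop-style policy: during training each layer is skipped with probability `drop_prob`, and at inference a fixed subset of layers is kept. The helper names `residual_forward` and `keep_every_other` are hypothetical, and keeping every other layer is only one possible way to halve depth; the summary does not say which layers GTSP keeps.

```python
import random
from typing import Callable, List, Optional, Sequence

# A layer maps a feature vector to an update of the same length.
Layer = Callable[[List[float]], List[float]]


def residual_forward(x: List[float], layers: Sequence[Layer],
                     drop_prob: float = 0.0, training: bool = False,
                     rng: Optional[random.Random] = None) -> List[float]:
    """Apply residual layers x <- x + f(x), skipping each layer with
    probability `drop_prob` when `training` is True (LayerDrop-style)."""
    rng = rng or random.Random()
    for layer in layers:
        if training and rng.random() < drop_prob:
            continue  # dropped: the residual path passes x through unchanged
        update = layer(x)
        x = [a + b for a, b in zip(x, update)]
    return x


def keep_every_other(layers: Sequence[Layer]) -> List[Layer]:
    """Hypothetical inference policy: keep layers 0, 2, 4, ... to halve depth."""
    return list(layers[::2])
```

Because a skipped layer falls back to the identity through the residual connection, training already exposes the model to shallower sub-networks, which is what makes removing layers at inference less damaging.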

Optimization Features

Token Efficiency

  • learnable token selector using GCN scores
  • Gumbel-softmax + straight-through top-k for differentiability
  • score perturbation to avoid local selection bias
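A minimal sketch of the token-selection step, assuming per-node scores have already been produced (in GTSP they come from a GCN scorer, omitted here). Gumbel noise plays the role of the score perturbation, a hard top-k mask is used in the forward pass, and a softmax over the perturbed scores stands in for the soft relaxation a straight-through estimator would backpropagate through. `gumbel_topk_mask` and its parameters are hypothetical names, and the paper's exact relaxation may differ.

```python
import math
import random
from typing import List, Optional, Tuple


def gumbel_topk_mask(scores: List[float], k: int, tau: float = 1.0,
                     rng: Optional[random.Random] = None,
                     perturb: bool = True) -> Tuple[List[int], List[float]]:
    """Pick k tokens from per-node scores.

    Returns (hard_mask, soft_probs): the 0/1 top-k mask used in the forward
    pass and a softmax over the perturbed scores used as its soft surrogate.
    """
    if not 0 < k <= len(scores):
        raise ValueError("k must be between 1 and len(scores)")
    rng = rng or random.Random()
    perturbed = []
    for s in scores:
        if perturb:
            u = max(rng.random(), 1e-12)     # avoid log(0)
            noise = -math.log(-math.log(u))  # Gumbel(0, 1) sample
        else:
            noise = 0.0
        perturbed.append((s + noise) / tau)
    # Soft surrogate: temperature-scaled softmax, shifted by the max for stability.
    peak = max(perturbed)
    exps = [math.exp(p - peak) for p in perturbed]
    total = sum(exps)
    soft = [e / total for e in exps]
    # Hard selection: indices of the k largest perturbed scores.
    top = sorted(range(len(scores)), key=lambda i: perturbed[i], reverse=True)[:k]
    hard = [0] * len(scores)
    for i in top:
        hard[i] = 1
    return hard, soft
```

In an autograd framework the straight-through trick is usually written as `hard + soft - soft.detach()` (PyTorch notation), so the forward value equals the hard mask while gradients follow `soft`. Nodes with a 0 in the hard mask are then dropped before attention, which is where the savings come from.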

Model Optimization

  • weight pruning (magnitude-based, gradual)
  • attention-head pruning with importance scores
  • layer dropping (stochastic drop)
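Gradient-based head importance can be sketched as follows, assuming the common formulation from Michel et al. (2019): each head's output is multiplied by a gate fixed at 1.0, and a head's importance is the average absolute gradient of the loss with respect to that gate. The per-example gate gradients are taken as given (computing them needs an autograd framework), the function names are hypothetical, and GTSP's exact scoring rule may differ.

```python
from typing import List, Sequence


def head_importance(gate_grads: Sequence[Sequence[float]]) -> List[float]:
    """Average |dL/d(gate_h)| over examples.

    gate_grads[n][h] is the gradient of the loss on example n with respect
    to the gate of head h, evaluated with all gates set to 1.0.
    """
    n = len(gate_grads)
    num_heads = len(gate_grads[0])
    return [sum(abs(g[h]) for g in gate_grads) / n for h in range(num_heads)]


def prune_heads(importance: Sequence[float], prune_ratio: float) -> List[int]:
    """Return a 0/1 head mask that zeroes the least-important heads."""
    num_prune = int(round(prune_ratio * len(importance)))
    order = sorted(range(len(importance)), key=lambda h: importance[h])
    mask = [1] * len(importance)
    for h in order[:num_prune]:
        mask[h] = 0
    return mask
```

The resulting mask multiplies each head's output during fine-tuning; heads with a 0 can later be removed outright, shrinking the Q/K/V and output projections.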

System Optimization

  • adds element-wise mask multiplications; their overhead is small compared with the attention compute saved
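The trade-off between mask overhead and attention savings can be checked with a back-of-the-envelope FLOP count. This sketch assumes a standard dense Transformer layer (2 FLOPs per multiply-add, feed-forward width 4d) and ignores softmax, normalization, biases, and the token scorer; it is not the FLOP counter used in the paper, so it will not reproduce the reported percentages. The function names are hypothetical.

```python
def layer_flops(num_tokens: int, hidden: int, ffn_mult: int = 4) -> int:
    """Approximate FLOPs of one dense Transformer layer (2 FLOPs per multiply-add).

    Counts the Q/K/V and output projections, the two N x N attention matmuls,
    and the two feed-forward linear layers.
    """
    n, d = num_tokens, hidden
    projections = 2 * n * d * (3 * d) + 2 * n * d * d  # Q, K, V, then output
    attention = 2 * n * n * d + 2 * n * n * d           # Q @ K^T, then weights @ V
    ffn = 2 * 2 * n * d * (ffn_mult * d)                # d -> 4d -> d
    return projections + attention + ffn


def token_mask_overhead(num_tokens: int, hidden: int) -> int:
    """One multiply per element when a 0/1 token mask is applied to N x d features."""
    return num_tokens * hidden
```

For N = 100 tokens and d = 64, keeping 50 tokens drops the per-layer estimate from 12,390,400 to 5,555,200 FLOPs, while the mask multiply costs 6,400. The saving only materializes if masked tokens are actually removed from the sequence; multiplying features by zero alone leaves the N × N attention cost in place. Because the attention term grows with N², the relative gain is larger on bigger graphs.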

Training Optimization

  • sparse training with regrowth (gradient-based)
  • gradual magnitude pruning schedule
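A minimal sketch of gradual magnitude pruning with gradient-based regrowth on a flat weight list. It assumes the cubic sparsity schedule of Zhu and Gupta (2017) and a RigL-style regrowth rule (drop the smallest active weights, regrow the pruned positions with the largest gradient magnitude); GTSP's actual schedule and regrowth criterion may differ, and all function names are hypothetical.

```python
from typing import List, Sequence


def cubic_sparsity(step: int, start: int, end: int,
                   final: float, initial: float = 0.0) -> float:
    """Target sparsity at `step`: ramps from `initial` to `final` between
    `start` and `end`, pruning fastest early (cubic schedule)."""
    if step <= start:
        return initial
    if step >= end:
        return final
    progress = (step - start) / (end - start)
    return final + (initial - final) * (1.0 - progress) ** 3


def magnitude_mask(weights: Sequence[float], sparsity: float) -> List[int]:
    """0/1 mask that keeps the largest-magnitude weights and zeroes a
    `sparsity` fraction of them."""
    num_keep = len(weights) - int(round(sparsity * len(weights)))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
    mask = [0] * len(weights)
    for i in order[:num_keep]:
        mask[i] = 1
    return mask


def prune_and_regrow(weights: Sequence[float], grads: Sequence[float],
                     mask: Sequence[int], swap_fraction: float) -> List[int]:
    """Drop the smallest active weights and regrow as many pruned positions,
    chosen by largest |gradient|, so the overall sparsity stays the same."""
    active = [i for i, m in enumerate(mask) if m]
    inactive = [i for i, m in enumerate(mask) if not m]
    num_swap = min(int(swap_fraction * len(active)), len(inactive))
    drop = sorted(active, key=lambda i: abs(weights[i]))[:num_swap]
    grow = sorted(inactive, key=lambda i: abs(grads[i]), reverse=True)[:num_swap]
    new_mask = list(mask)
    for i in drop:
        new_mask[i] = 0
    for i in grow:
        new_mask[i] = 1  # regrown weights would typically be re-initialized to zero
    return new_mask
```

A training loop would call `cubic_sparsity` every few steps, rebuild the mask with `magnitude_mask` as the target rises, and call `prune_and_regrow` periodically so that weights pruned too early can come back. RigL initializes regrown weights to zero, so a regrowth step changes the mask without perturbing the current outputs.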

Inference Optimization

  • token (node) selection to shorten attention length
  • remove attention heads at inference

Reproducibility

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Paper focuses on pruning each component separately; joint pruning interactions are not evaluated.
  • Token pruning can harm accuracy on small graphs (20–30 nodes); per-dataset tuning is needed.
  • No latency or energy measurements on real hardware are reported.
  • Regrowth and mask hyperparameters require tuning per model and dataset.

When Not To Use

  • On small graphs where removing tokens removes crucial nodes.
  • When you need an out-of-the-box joint pruning recipe; the paper treats the axes independently.
  • If you cannot retrain or fine-tune after pruning.

Failure Modes

  • Over-pruning can drop accuracy sharply if important heads or nodes are removed.
  • Token selector may get stuck in local structures without proper score perturbation.
  • Premature pruning without regrowth can lose useful parameters early in training.

Core Entities

Models

  • GraphTrans
  • Graphormer
  • GraphGPS
  • Graph Transformer (general)

Metrics

  • Accuracy
  • ROC-AUC
  • Number of parameters
  • FLOPs

Datasets

  • NCI1
  • OGBG-HIV
  • OGBG-Molpcba
  • Open Graph Benchmark (OGB)

Benchmarks

  • graph classification (NCI1, OGBG-HIV, OGBG-Molpcba)