Scale distillation by token confidence to train ternary-weight generative LMs with <1.0 PPL hit

August 13, 2023 | 7 min

Overview

Production Readiness

0.7

Novelty Score

0.6

Cost Impact Score

0.8

Citation Count

2

Authors

Minsoo Kim, Sihwa Lee, Janghwan Lee, Sukjin Hong, Du-Seong Chang, Wonyong Sung, Jungwook Choi

Why It Matters For Business

TSLD lets you quantize decoder LMs to 2-bit ternary weights with near full-precision quality and little extra training cost, reducing model size and inference memory while preserving reasoning accuracy.

Summary TLDR

The paper introduces Token-Scaled Logit Distillation (TSLD): a simple, memory-light distillation method that weights logit distillation per token by the teacher's token cross-entropy. TSLD enables quantization-aware training (QAT) down to ternary (2-bit) weights for decoder language models (GPT-2, OPT, LLaMA, GPT-Neo). On evaluated tasks, ternary QAT with TSLD keeps perplexity within ~1.0 of full-precision and improves or matches reasoning and NLU accuracy versus other QAT/PTQ baselines, while adding almost no training overhead compared to plain logit distillation.
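
The per-token scaling described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' implementation. It assumes that the token scale is a softmax over the teacher's per-token cross-entropy with a temperature `tau`, and that the per-token distillation term is a KL divergence between teacher and student distributions. The function names, the `tau` default, and both of those choices are assumptions made for this sketch.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one token's vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def token_scales(teacher_logits, labels, tau=10.0):
    """Per-token weights from the teacher's cross-entropy against the labels.

    Assumption: weights are a softmax over per-token CE with temperature tau,
    so tokens where the teacher is less confident receive larger weights.
    """
    ce = [-log_softmax(row)[y] for row, y in zip(teacher_logits, labels)]
    m = max(ce)
    w = [math.exp((c - m) / tau) for c in ce]
    total = sum(w)
    return [x / total for x in w]

def tsld_loss(student_logits, teacher_logits, labels, tau=10.0):
    """Token-scaled distillation loss: sum over tokens of scale * KL(teacher || student)."""
    scales = token_scales(teacher_logits, labels, tau)
    loss = 0.0
    for s_row, t_row, scale in zip(student_logits, teacher_logits, scales):
        t_logp = log_softmax(t_row)
        s_logp = log_softmax(s_row)
        kl = sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_logp, s_logp))
        loss += scale * kl
    return loss

# Toy sequence of two tokens over a three-word vocabulary.
teacher = [[4.0, 0.0, 0.0], [0.5, 0.4, 0.3]]  # confident token, uncertain token
student = [[3.0, 0.5, 0.0], [0.1, 0.6, 0.3]]
labels = [0, 0]
scales = token_scales(teacher, labels)
loss = tsld_loss(student, teacher, labels)
```

In this toy case the second token, where the teacher spreads probability almost evenly, gets the larger weight. A smaller `tau` sharpens that contrast, while a very large `tau` approaches uniform weighting, which corresponds to plain logit distillation up to a constant factor.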

Problem Statement

In generative decoder models, masked causal attention makes quantization error accumulate unevenly across token positions, and combining logit distillation with a ground-truth loss causes overfitting. Existing QAT and PTQ methods either degrade perplexity or, in the case of layer-to-layer KD, require high memory. The paper asks: can a lightweight change to distillation avoid overfitting and recover token predictions for ternary-weight generative language models (GLMs)?

Main Contribution

Token-Scaled Logit Distillation (TSLD): scale logit KD per token by teacher token cross-entropy to reduce overfitting and emphasize uncertain tokens.

First large-scale evaluation of ternary-weight (2-bit) QAT on decoder GLMs up to ~7B parameters with <1.0 PPL degradation on evaluated benchmarks.

Empirical analysis: shows that causal attention accumulates quantization error toward later tokens, and demonstrates why logit KD, rather than layer-to-layer (L2L) KD, better restores final token logits.

Implementation notes: memory-efficient QAT pipeline (PyTorch Pipe) and custom low-bit CUDA kernels for faster inference.

Key Findings

TSLD keeps PPL degradation under 1.0 vs full-precision on evaluated models with ternary weights.

Numbers: OPT-6.7B on PTB, FP16 PPL 10.21 → TSLD PPL 11.00 (+0.79)

TSLD improves downstream reasoning and QA accuracy compared to plain logit distillation.

Numbers: OPT-2.7B ARC_challenge accuracy, Logit KD 31.91 → TSLD 33.45 (+1.54)

Naively combining logit KD with ground-truth loss causes overfitting; token scaling fixes it.

Numbers: GPT-2 0.1B on PTB, Logit+GT PPL 21.51 vs TSLD PPL 19.95 (TSLD is 1.56 PPL lower)

TSLD adds negligible memory or speed overhead compared to plain logit distillation.

Numbers: GPT-2 0.3B, 1.57 iter/sec and 22,622 MiB GPU memory for both Logit KD and TSLD

Results

Perplexity (PTB)

Value: OPT-6.7B TSLD 11.00

Baseline: OPT-6.7B FP16 10.21

Perplexity (PTB)

Value: GPT-2 0.1B TSLD 19.95

Baseline: GPT-2 0.1B FP16 20.91

Accuracy

Value: OPT-2.7B TSLD 75.62

Baseline: OPT-2.7B FP16 76.71

Accuracy (ARC_challenge)

Value: OPT-2.7B TSLD 33.45

Baseline: OPT-2.7B Logit KD 31.91

Training overhead

Value: TSLD ≈ Logit KD

Baseline: Logit KD

What To Try In 7 Days

Run TSLD QAT on a task-fine-tuned 1–7B decoder model to test ternary-weight inference.

Compare PPL and task accuracy vs your current FP and PTQ baselines on a small validation set.

If using logit KD+GT and seeing overfitting, replace logit KD with TSLD (token-scaled weights).
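
For the comparison step in the list above, a small helper can turn per-token negative log-likelihoods into perplexity and report the gap against the full-precision baseline. This is a sketch under stated assumptions: the function names are hypothetical, the inputs are natural-log NLL values collected from your own evaluation loop, and the 1.0 PPL budget simply mirrors the degradation bound reported in this summary.

```python
import math

def perplexity(token_nlls):
    """Perplexity is the exponential of the mean per-token negative log-likelihood."""
    return math.exp(sum(token_nlls) / len(token_nlls))

def ppl_gap(baseline_nlls, quantized_nlls):
    """Perplexity of the quantized model minus perplexity of the baseline."""
    return perplexity(quantized_nlls) - perplexity(baseline_nlls)

def within_budget(baseline_nlls, quantized_nlls, budget=1.0):
    """True when the quantized model degrades perplexity by less than `budget`."""
    return ppl_gap(baseline_nlls, quantized_nlls) < budget

# Illustration using the OPT-6.7B PTB figures quoted in this summary:
# a constant NLL of ln(PPL) reproduces that perplexity exactly.
fp16_nlls = [math.log(10.21)] * 4
tsld_nlls = [math.log(11.00)] * 4
gap = ppl_gap(fp16_nlls, tsld_nlls)
```

Run both models on the same validation tokens so the two NLL lists are directly comparable. Comparing perplexities computed over different tokenizations or different text is not meaningful.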

Optimization Features

Token Efficiency

  • per-token distillation scaling by teacher cross-entropy

Infra Optimization

  • A100 GPU experiments; kernels target A100-style hardware

Model Optimization

  • ternary weight quantization (2-bit, ±1/0)
  • 4-bit quantization evaluation
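
To make the ternary format concrete, the sketch below maps a list of weights to codes in {-1, 0, +1} plus one shared scale. It is an illustration only: the threshold rule (0.7 times the mean absolute weight) and the scale (mean magnitude of the weights above the threshold) follow a common ternary-weight heuristic and are assumptions here, not necessarily the quantizer used in the paper. The function names are hypothetical.

```python
def ternarize(weights, threshold_ratio=0.7):
    """Map weights to ternary codes and a shared scale.

    Assumed heuristic: weights with magnitude at or below
    threshold_ratio * mean(|w|) become 0, and the rest keep only their sign.
    """
    mean_abs = sum(abs(w) for w in weights) / len(weights)
    delta = threshold_ratio * mean_abs
    kept = [abs(w) for w in weights if abs(w) > delta]
    alpha = sum(kept) / len(kept) if kept else 0.0
    codes = [1 if w > delta else -1 if w < -delta else 0 for w in weights]
    return codes, alpha

def dequantize(codes, alpha):
    """Reconstruct approximate weights as alpha * code."""
    return [alpha * c for c in codes]

weights = [0.9, -0.8, 0.05, -0.02, 0.6]
codes, alpha = ternarize(weights)
approx = dequantize(codes, alpha)
```

In quantization-aware training, a mapping like this is applied in the forward pass while full-precision weights are kept for the update. How gradients pass through the rounding step is a separate design choice that this sketch does not cover.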

System Optimization

  • pipeline parallelism via PyTorch Pipe API

Training Optimization

  • quantization-aware training (QAT)
  • token-scaled logit distillation (TSLD) for KD

Inference Optimization

  • custom low-bit CUDA kernels for 2/4/8-bit matrix multiply
  • weight packing to reduce load overhead
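
Weight packing can be illustrated on the host side without any CUDA code. The sketch below stores four ternary codes per byte using 2-bit fields, then unpacks them again. The bit layout (lowest bits first, codes offset by +1) and the function names are assumptions for illustration; the paper's kernels may use a different layout.

```python
def pack_ternary(codes):
    """Pack ternary codes (-1, 0, +1) into bytes, four 2-bit fields per byte."""
    packed = bytearray((len(codes) + 3) // 4)
    for i, c in enumerate(codes):
        if c not in (-1, 0, 1):
            raise ValueError(f"not a ternary code: {c!r}")
        field = c + 1  # map -1, 0, +1 to 0, 1, 2
        packed[i // 4] |= field << (2 * (i % 4))
    return bytes(packed)

def unpack_ternary(packed, count):
    """Recover the first `count` ternary codes from packed bytes."""
    return [((packed[i // 4] >> (2 * (i % 4))) & 0b11) - 1 for i in range(count)]

codes = [1, -1, 0, 0, 1]
packed = pack_ternary(codes)
restored = unpack_ternary(packed, len(codes))
```

Five codes fit in two bytes here, compared with ten bytes if each weight were stored in FP16. The original count has to be stored alongside the packed bytes, because unused fields in the last byte would otherwise decode as -1.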

Reproducibility

Data Urls

  • Penn Treebank (PTB): public
  • PIQA, OpenbookQA, ARC, GSM8K, GLUE: public

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Experiments use A100 GPUs and pipeline parallelism; smaller infra may need engineering.
  • L2L KD still outperforms in encoder models; TSLD targets decoder GLMs specifically.
  • Accumulation of token quantization error can vary by depth and sequence length; effect weakens in some deep/long cases.

When Not To Use

  • If you need exact full-precision behavior for every example (for instance, safety-critical outputs).
  • If you cannot run QAT or lack the GPUs for teacher-student training.

Failure Modes

  • Using plain logit KD + GT naively can cause overfitting and worse eval loss.
  • L2L KD can run out of GPU memory for models >1.3B on 40GB GPUs.
  • Incorrect initialization of clipping scales (QuantGPT style) can clip many weights and harm 4-bit QAT.

Core Entities

Models

  • GPT-2
  • OPT
  • LLaMA
  • GPT-Neo

Metrics

  • Perplexity (PPL)
  • Accuracy

Datasets

  • Penn Treebank (PTB)
  • PIQA
  • OpenbookQA
  • ARC_easy
  • ARC_challenge
  • GSM8K
  • GLUE

Benchmarks

  • PTB language modeling
  • Commonsense QA (PIQA, OpenbookQA, ARC)
  • GSM8K arithmetic reasoning
  • GLUE NLU