Scale distillation by token confidence to train ternary-weight generative LMs with <1.0 PPL hit

August 13, 2023 · 7 min

Overview

Decision Snapshot: Ready For Pilot

The method is a modest algorithmic change (token scaling) with clear empirical gains and low overhead; results are consistent across multiple GLMs and public tasks.

Citations: 2

Evidence Strength: 0.80

Confidence: 0.85

Risk Signals: 8

Trust Signals

Findings with numeric evidence: 4/4

Findings with evidence refs: 4/4

Results with explicit delta: 5/5

Reproducibility

Status: Code + data available

Open source: Partial

At A Glance

Cost impact: 80%

Production readiness: 70%

Novelty: 60%

Authors

Minsoo Kim, Sihwa Lee, Janghwan Lee, Sukjin Hong, Du-Seong Chang, Wonyong Sung, Jungwook Choi


Why It Matters For Business

TSLD lets you quantize decoder LMs to 2-bit ternary weights with near full-precision quality and little extra training cost, reducing model size and inference memory while preserving reasoning accuracy.
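The size reduction can be checked with back-of-the-envelope arithmetic. The sketch below assumes every parameter is stored at the stated bit width and ignores per-channel scales, embeddings kept in higher precision, activations, and the KV cache, so real savings will be somewhat smaller. The helper name `weight_memory_gb` and the rounded 6.7B parameter count are illustrative, not taken from the paper.

```python
def weight_memory_gb(num_params: float, bits_per_weight: float) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


# Rounded parameter count for a 6.7B-parameter model (assumption: all weights quantized).
params = 6.7e9

fp16_gb = weight_memory_gb(params, 16)    # 16 bits per weight
ternary_gb = weight_memory_gb(params, 2)  # ternary values stored in 2 bits each
ratio = fp16_gb / ternary_gb

print(f"FP16 weights:  {fp16_gb:.1f} GB")
print(f"2-bit weights: {ternary_gb:.3f} GB")
print(f"Reduction:     {ratio:.0f}x")
```

Under these assumptions the weight footprint shrinks by the ratio of the bit widths (16 / 2 = 8), independent of model size.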

Summary TLDR

The paper introduces Token-Scaled Logit Distillation (TSLD): a simple, memory-light distillation method that weights logit distillation per token by the teacher's token cross-entropy. TSLD enables quantization-aware training (QAT) down to ternary (2-bit) weights for decoder language models (GPT-2, OPT, LLaMA, GPT-Neo). On evaluated tasks, ternary QAT with TSLD keeps perplexity within ~1.0 of full-precision and improves or matches reasoning and NLU accuracy versus other QAT/PTQ baselines, while adding almost no training overhead compared to plain logit distillation.

Problem Statement

Generative decoder models suffer from uneven, cumulative quantization error, which arises from masked causal attention, and they overfit when logit distillation is combined with the ground-truth loss. Existing QAT and PTQ methods either degrade perplexity or need a lot of memory (layer-to-layer KD). The paper asks: can a lightweight change to distillation avoid overfitting and recover token predictions for ternary-weight GLMs?

Main Contribution

Token-Scaled Logit Distillation (TSLD): scale logit KD per token by teacher token cross-entropy to reduce overfitting and emphasize uncertain tokens.

First large-scale evaluation of ternary-weight (2-bit) QAT on decoder GLMs up to ~7B parameters with <1.0 PPL degradation on evaluated benchmarks.
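To make the token-scaling idea concrete, here is a minimal pure-Python sketch. It rests on assumptions: the summary only states that logit KD is scaled per token by the teacher's token cross-entropy so that uncertain tokens are emphasized. The softmax normalization over the sequence, the `temperature` parameter, and the function names (`tsld_loss`, `log_softmax`, `softmax`) are illustrative choices here, not confirmed details of the authors' implementation.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one token's vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def softmax(values, temperature=1.0):
    """Softmax over a list of scalars (here: per-token teacher cross-entropies)."""
    scaled = [v / temperature for v in values]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def tsld_loss(teacher_logits, student_logits, targets, temperature=1.0):
    """Token-scaled logit distillation over one sequence (illustrative sketch).

    teacher_logits, student_logits: lists of shape [seq_len][vocab_size]
    targets: ground-truth token ids, length seq_len
    Returns (scalar loss, per-token scales).
    """
    teacher_ce, token_kd = [], []
    for t_row, s_row, y in zip(teacher_logits, student_logits, targets):
        t_logp = log_softmax(t_row)
        s_logp = log_softmax(s_row)
        # Teacher cross-entropy against the ground-truth token: high when the teacher is unsure.
        teacher_ce.append(-t_logp[y])
        # Plain logit KD for this token: cross-entropy of the student under teacher soft labels.
        token_kd.append(-sum(math.exp(tp) * sp for tp, sp in zip(t_logp, s_logp)))
    # Assumed normalization: softmax of teacher CE across the sequence.
    scales = softmax(teacher_ce, temperature)
    loss = sum(w * kd for w, kd in zip(scales, token_kd))
    return loss, scales


# Token 0: teacher is confident. Token 1: teacher is uncertain.
teacher = [[5.0, 0.0, 0.0], [0.5, 0.4, 0.3]]
student = [[3.0, 0.5, 0.0], [0.2, 0.6, 0.1]]
loss, scales = tsld_loss(teacher, student, targets=[0, 1])
print(scales, loss)
```

In this toy sequence the second token, where the teacher's cross-entropy is higher, receives the larger scale, so its distillation term dominates the loss. Only teacher logits are needed to compute the scales, which is consistent with the claim that the method adds almost no overhead over plain logit distillation.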

Key Findings

TSLD keeps PPL degradation under 1.0 vs full-precision on evaluated models with ternary weights.

Numbers: OPT-6.7B PTB: FP16 PPL 10.21 → TSLD PPL 11.00 (+0.79)

Practical Use: You can deploy 2-bit ternary GLMs with TSLD and expect near-FP perplexity on language modeling benchmarks.

Evidence Ref: Table 1

TSLD improves downstream reasoning and QA accuracy compared to plain logit distillation.

Numbers: OPT-2.7B ARC_challenge ACC: Logit 31.91 → TSLD 33.45 (+1.54)

Practical Use: When accuracy on reasoning or QA matters, use TSLD instead of logit-only KD to avoid subtle reasoning failures.

Evidence Ref: Table 2 (top)

Results

Metric | Value | Baseline | Delta | Split / Dataset | Evidence | Evidence Ref
Perplexity (PTB) | OPT-6.7B TSLD 11.00 | OPT-6.7B FP16 10.21 | +0.79 | PTB | Table 1 (PPL comparison) | Table 1
Perplexity (PTB) | GPT-2 0.1B TSLD 19.95 | GPT-2 0.1B FP16 20.91 | -0.96 | PTB | Table 1 shows TSLD can even improve PPL for some small GPT-2 sizes | Table 1
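For readers reproducing these rows, perplexity is the exponential of the mean per-token negative log-likelihood, and the delta column is the quantized value minus the FP16 baseline. The sketch below uses that standard definition; the helper names `perplexity` and `ppl_delta` are illustrative, and the PPL figures are the ones reported in the rows above.

```python
import math


def perplexity(token_nlls):
    """Perplexity = exp(mean negative log-likelihood per token), NLL in nats."""
    return math.exp(sum(token_nlls) / len(token_nlls))


def ppl_delta(quantized_ppl, baseline_ppl):
    """Delta as used in the results rows: quantized minus full-precision baseline."""
    return round(quantized_ppl - baseline_ppl, 2)


# Deltas recomputed from the reported PTB perplexities.
opt_delta = ppl_delta(11.00, 10.21)   # OPT-6.7B, TSLD vs FP16
gpt2_delta = ppl_delta(19.95, 20.91)  # GPT-2 0.1B, TSLD vs FP16
print(opt_delta, gpt2_delta)

# Sanity check of the definition: a model that assigns probability 1/2
# to every token has perplexity 2.
print(perplexity([math.log(2)] * 4))
```

A positive delta means the quantized model is worse than the baseline; a negative delta, as in the GPT-2 row, means it is better.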

What To Try In 7 Days

Run TSLD QAT on a task-fine-tuned 1–7B decoder model to test ternary-weight inference.

Compare PPL and task accuracy vs your current FP and PTQ baselines on a small validation set.

If you train with logit KD plus the ground-truth (GT) loss and see overfitting, replace plain logit KD with TSLD (token-scaled weights).

Optimization Features

Token Efficiency: per-token distillation scaling by teacher cross-entropy
Infra Optimization: A100 GPU experiments; kernels target A100-style hardware
Model Optimization: ternary weight quantization (2-bit, weights in {-1, 0, +1}); 4-bit quantization evaluation
System Optimization: pipeline parallelism via PyTorch Pipe API
Training Optimization: quantization-aware training (QAT); token-scaled logit distillation (TSLD) for KD
Inference Optimization: custom low-bit CUDA kernels for 2/4/8-bit matrix multiply; weight packing to reduce load overhead
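As an illustration of ternary quantization and weight packing, the following sketch ternarizes a weight vector and stores four 2-bit codes per byte. The summary does not specify the quantizer or the packing layout, so two assumptions are made here: the threshold follows the heuristic from Ternary Weight Networks (0.7 times the mean absolute weight), and the bit encoding (0 → 00, +1 → 01, -1 → 10, least-significant pair first) is arbitrary. All function names are hypothetical.

```python
def ternarize(weights, threshold_ratio=0.7):
    """Map weights to codes in {-1, 0, +1} plus one scale alpha (TWN-style heuristic)."""
    mean_abs = sum(abs(w) for w in weights) / len(weights)
    delta = threshold_ratio * mean_abs
    codes = [0 if abs(w) <= delta else (1 if w > 0 else -1) for w in weights]
    kept = [abs(w) for w, c in zip(weights, codes) if c != 0]
    # alpha is the mean magnitude of the weights that survived the threshold.
    alpha = sum(kept) / len(kept) if kept else 0.0
    return codes, alpha


def pack_ternary(codes):
    """Pack four ternary codes into each byte, 2 bits per code."""
    encode = {0: 0b00, 1: 0b01, -1: 0b10}
    padded = codes + [0] * (-len(codes) % 4)  # pad with zeros to a multiple of 4
    out = bytearray()
    for i in range(0, len(padded), 4):
        byte = 0
        for j, code in enumerate(padded[i:i + 4]):
            byte |= encode[code] << (2 * j)
        out.append(byte)
    return bytes(out)


def unpack_ternary(packed, count):
    """Inverse of pack_ternary; the unused bit pattern 0b11 is decoded as 0."""
    decode = {0b00: 0, 0b01: 1, 0b10: -1, 0b11: 0}
    codes = []
    for byte in packed:
        for j in range(4):
            codes.append(decode[(byte >> (2 * j)) & 0b11])
    return codes[:count]


weights = [0.9, -0.05, -0.7, 0.1, 0.4]
codes, alpha = ternarize(weights)
packed = pack_ternary(codes)
restored = unpack_ternary(packed, len(codes))
dequantized = [alpha * c for c in restored]
print(codes, alpha, list(packed), dequantized)
```

Five weights fit into two bytes here, versus ten bytes in FP16. A production kernel would unpack inside the matrix multiply rather than materializing the codes, which is where the reduced load overhead comes from.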

Reproducibility

Code Available: Yes
Data Available: Yes
Open Source Status: Partial
License: Unknown

Data URLs

Penn Treebank (PTB): public
PIQA, OpenbookQA, ARC, GSM8K, GLUE: public

Risks & Boundaries

Limitations

Experiments use A100 GPUs and pipeline parallelism; smaller infra may need engineering.

Layer-to-layer (L2L) KD still performs better on encoder models; TSLD targets decoder GLMs specifically.

When Not To Use

If you need exact full-precision behavior on every example, e.g., for safety-critical outputs.

If you cannot run QAT or lack the GPUs for teacher-student training.

Failure Modes

Using plain logit KD + GT naively can cause overfitting and worse eval loss.

L2L KD can run out of GPU memory for models >1.3B on 40GB GPUs.

Core Entities

Models

GPT-2, OPT, LLaMA, GPT-Neo

Metrics

Perplexity (PPL), Accuracy

Datasets

Penn Treebank (PTB), PIQA, OpenbookQA, ARC_easy, ARC_challenge, GSM8K, GLUE

Benchmarks

PTB language modeling; Commonsense QA (PIQA, OpenbookQA, ARC); GSM8K arithmetic reasoning; GLUE NLU