Overview
Production Readiness: 0.7
Novelty Score: 0.6
Cost Impact Score: 0.8
Citation Count: 2
Why It Matters For Business
TSLD lets you quantize decoder language models to ternary (2-bit) weights with near-full-precision quality at little extra training cost. This reduces model size and inference memory while preserving reasoning accuracy.
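To make the size reduction concrete, the arithmetic below counts weight storage only. It is a back-of-envelope sketch, assuming each ternary weight is stored in 2 bits (three values need log2(3) ≈ 1.58 bits, rounded up to 2). It ignores scale factors, activations and the KV cache, so real savings will be somewhat smaller.

```python
# Back-of-envelope weight storage for a decoder LM at different bit widths.
# Assumption: only weights are counted; per-channel scales, activations and
# the KV cache are ignored.


def weight_storage_gb(num_params: float, bits_per_weight: int) -> float:
    """Return weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


if __name__ == "__main__":
    n = 7e9  # a ~7B-parameter model, the largest size evaluated in the summary
    for bits in (16, 4, 2):
        print(f"{bits:>2}-bit: {weight_storage_gb(n, bits):.2f} GB")
```

For a ~7B model this gives 14.00 GB at 16-bit, 3.50 GB at 4-bit and 1.75 GB at 2-bit, an 8× reduction in weight storage versus FP16.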
Summary TLDR
The paper introduces Token-Scaled Logit Distillation (TSLD), a simple, memory-light distillation method that weights each token's logit-distillation loss by the teacher's cross-entropy on that token. TSLD enables quantization-aware training (QAT) down to ternary (2-bit) weights for decoder language models (GPT-2, OPT, LLaMA, GPT-Neo). On the evaluated tasks, ternary QAT with TSLD keeps perplexity within about 1.0 of the full-precision model and matches or improves reasoning and NLU accuracy relative to other QAT/PTQ baselines, while adding almost no training overhead over plain logit distillation.
Problem Statement
In generative decoder models, masked causal attention makes quantization error uneven across tokens and cumulative along the sequence, and combining logit distillation with the ground-truth loss leads to overfitting. Existing QAT and PTQ methods either degrade perplexity or need large amounts of memory (as with layer-to-layer (L2L) KD). The paper asks whether a lightweight change to distillation can avoid overfitting and recover token predictions for ternary-weight generative language models (GLMs).
Main Contribution
Token-Scaled Logit Distillation (TSLD): scale logit KD per token by teacher token cross-entropy to reduce overfitting and emphasize uncertain tokens.
First large-scale evaluation of ternary-weight (2-bit) QAT on decoder GLMs up to ~7B parameters with <1.0 PPL degradation on evaluated benchmarks.
Empirical analysis: shows that causal attention accumulates quantization error toward later tokens, and explains why logit KD restores the final token logits better than L2L KD.
Implementation notes: memory-efficient QAT pipeline (PyTorch Pipe) and custom low-bit CUDA kernels for faster inference.
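The core TSLD idea, scaling each token's logit-distillation term by how uncertain the teacher is on that token, can be sketched in plain Python for a single sequence. This is a minimal sketch, not the authors' code. It assumes the token scales are a softmax over the teacher's per-token cross-entropy (with a hypothetical temperature `tau`) and that the per-token distillation term is KL(teacher || student); the paper's exact normalization may differ. All function names are illustrative.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def softmax(values: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def tsld_loss(
    teacher_logits: Sequence[Sequence[float]],
    student_logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    tau: float = 1.0,  # hypothetical temperature on the token scales
) -> float:
    """Token-scaled logit distillation loss for one sequence (sketch).

    teacher_logits, student_logits: [seq_len][vocab] logits per position
    targets: [seq_len] ground-truth next-token ids
    """
    # 1) Teacher's per-token cross-entropy against the ground-truth tokens.
    teacher_ce = [-log_softmax(t)[y] for t, y in zip(teacher_logits, targets)]
    # 2) Token scales (assumed form): softmax over the sequence, so tokens the
    #    teacher finds harder receive more weight. The scales sum to 1.
    scales = softmax([ce / tau for ce in teacher_ce])
    # 3) Per-token logit KD as KL(teacher || student).
    per_token_kd = []
    for t_logits, s_logits in zip(teacher_logits, student_logits):
        log_p_t = log_softmax(t_logits)
        log_p_s = log_softmax(s_logits)
        kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
        per_token_kd.append(kl)
    # 4) Scale-weighted sum of the per-token KD terms.
    return sum(s * kd for s, kd in zip(scales, per_token_kd))


if __name__ == "__main__":
    # Toy 2-token sequence over a 3-word vocabulary (made-up logits).
    teacher = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]
    student = [[1.5, 0.7, -0.5], [0.0, 0.4, 0.1]]
    print(tsld_loss(teacher, student, targets=[0, 2]))
```

In a real QAT loop the same computation would run on batched logit tensors from the full-precision teacher and the ternary student; plain lists keep the sketch dependency-free. Note that the loss is used in place of the naive "logit KD + ground-truth loss" combination.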
Key Findings
TSLD keeps PPL degradation under 1.0 vs full-precision on evaluated models with ternary weights.
TSLD improves downstream reasoning and QA accuracy compared to plain logit distillation.
Naively combining logit KD with ground-truth loss causes overfitting; token scaling fixes it.
TSLD adds negligible memory or speed overhead compared to plain logit distillation.
Results
- Perplexity (PTB)
- Accuracy
- Training overhead
Who Should Care
What To Try In 7 Days
Run TSLD QAT on a task-fine-tuned 1–7B decoder model to test ternary-weight inference.
Compare PPL and task accuracy vs your current FP and PTQ baselines on a small validation set.
If you combine logit KD with a ground-truth (GT) loss and see overfitting, replace plain logit KD with TSLD (token-scaled logit KD).
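For the PPL comparison step, perplexity is the exponential of the mean per-token negative log-likelihood. The sketch below assumes you already have per-token NLLs (in nats) from both models on the same validation text. The numbers in the usage lines are made up for illustration, and the 1.0 threshold only mirrors the gap the summary reports; it is not a standard cutoff.

```python
import math
from typing import Sequence


def perplexity(token_nlls: Sequence[float]) -> float:
    """exp(mean per-token negative log-likelihood); NLLs in nats."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


def ppl_gap(fp_nlls: Sequence[float], quant_nlls: Sequence[float]) -> float:
    """How much higher the quantized model's perplexity is than full precision."""
    return perplexity(quant_nlls) - perplexity(fp_nlls)


if __name__ == "__main__":
    # Made-up NLLs chosen so the FP model has PPL 20.0 and the ternary one 20.8.
    fp = [math.log(20.0)] * 4
    ternary = [math.log(20.8)] * 4
    gap = ppl_gap(fp, ternary)
    print(f"{gap:.2f}", gap < 1.0)  # prints: 0.80 True
```

Compare models under the same tokenizer, context length and evaluation script, since perplexity values are sensitive to all three.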
Optimization Features
Token Efficiency
- per-token distillation scaling by teacher cross-entropy
Infra Optimization
- A100 GPU experiments; kernels target A100-style hardware
Model Optimization
- ternary weight quantization (scaled values in {-1, 0, +1}, stored in 2 bits)
- 4-bit quantization evaluation
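As an illustration of ternary weight quantization, the sketch below uses the threshold rule from Ternary Weight Networks (TWN): weights whose magnitude is at most 0.7 times the mean absolute weight become 0, the rest keep their sign, and a single scale `alpha` is fit to the surviving magnitudes. This rule is an assumption for illustration; the summary does not say which ternarization rule or scale granularity (per tensor or per channel) the paper uses.

```python
from typing import List, Sequence, Tuple


def ternarize(weights: Sequence[float], ratio: float = 0.7) -> Tuple[List[int], float]:
    """TWN-style ternarization of a flattened weight tensor.

    Returns codes in {-1, 0, +1} and one scale alpha so that w ~= alpha * code.
    """
    if not weights:
        raise ValueError("empty weight tensor")
    mean_abs = sum(abs(w) for w in weights) / len(weights)
    delta = ratio * mean_abs  # magnitudes at or below this threshold become 0
    codes = [0 if abs(w) <= delta else (1 if w > 0 else -1) for w in weights]
    kept = [abs(w) for w, c in zip(weights, codes) if c != 0]
    alpha = sum(kept) / len(kept) if kept else 0.0  # mean of surviving magnitudes
    return codes, alpha


def dequantize(codes: Sequence[int], alpha: float) -> List[float]:
    """Map ternary codes back to real values used in the forward pass."""
    return [alpha * c for c in codes]


if __name__ == "__main__":
    w = [0.9, -0.05, 0.4, -0.8, 0.02, -0.3]
    codes, alpha = ternarize(w)
    print(codes, round(alpha, 3))  # prints: [1, 0, 1, -1, 0, -1] 0.6
```

During QAT, a quantizer like this is typically applied in the forward pass while gradients update the full-precision weights through a straight-through estimator.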
System Optimization
- pipeline parallelism via PyTorch Pipe API
Training Optimization
- quantization-aware training (QAT)
- token-scaled logit distillation (TSLD) for KD
Inference Optimization
- custom low-bit CUDA kernels for 2/4/8-bit matrix multiply
- weight packing to reduce memory-load overhead
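Weight packing can be illustrated with a generic layout that stores each ternary code in 2 bits, four codes per byte. This is a hypothetical layout chosen for clarity (codes shifted to {0, 1, 2}, lowest bits first within each byte); the paper's CUDA kernels may pack weights differently.

```python
from typing import List, Sequence


def pack_ternary(codes: Sequence[int]) -> bytes:
    """Pack ternary codes {-1, 0, +1} into 2 bits each, 4 codes per byte."""
    out = bytearray((len(codes) + 3) // 4)
    for i, c in enumerate(codes):
        if c not in (-1, 0, 1):
            raise ValueError(f"not a ternary code: {c}")
        # Store c + 1 (in {0, 1, 2}) in the 2-bit slot for position i.
        out[i // 4] |= (c + 1) << (2 * (i % 4))
    return bytes(out)


def unpack_ternary(packed: bytes, n: int) -> List[int]:
    """Inverse of pack_ternary; n is the original number of codes."""
    return [((packed[i // 4] >> (2 * (i % 4))) & 0b11) - 1 for i in range(n)]


if __name__ == "__main__":
    codes = [1, 0, 1, -1, 0, -1]
    packed = pack_ternary(codes)
    print(len(packed), unpack_ternary(packed, len(codes)) == codes)  # prints: 2 True
```

Unused slots in the last byte are zero bits, so `n` must be stored alongside the packed buffer to decode correctly.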
Reproducibility
Code Urls
Data Urls
- Penn Treebank (PTB): public
- PIQA, OpenbookQA, ARC, GSM8K, GLUE: public
Code Available
Data Available
Open Source Status
- partial
Risks & Boundaries
Limitations
- Experiments use A100 GPUs and pipeline parallelism; reproducing them on smaller infrastructure may require extra engineering.
- L2L KD still performs better on encoder models; TSLD specifically targets decoder GLMs.
- Accumulation of token quantization error can vary with depth and sequence length; the effect is weaker in some deep-model or long-sequence cases.
When Not To Use
- If you need exact full-precision behavior on every example (e.g., safety-critical outputs).
- If you cannot run QAT or lack the GPUs for teacher-student training.
Failure Modes
- Using plain logit KD + GT naively can cause overfitting and worse eval loss.
- L2L KD can run out of GPU memory for models >1.3B on 40GB GPUs.
- Incorrect initialization of clipping scales (QuantGPT style) can clip many weights and harm 4-bit QAT.
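The clipping-scale failure mode can be checked cheaply before 4-bit QAT starts. The sketch below is a hypothetical diagnostic, not the QuantGPT initialization itself: it applies symmetric 4-bit fake quantization with clipping value `alpha` and reports the fraction of weights whose magnitude exceeds `alpha` and therefore gets clipped.

```python
from typing import Sequence


def quantize_symmetric(w: float, alpha: float, bits: int = 4) -> float:
    """Fake-quantize one weight: clip to [-alpha, alpha], round to a uniform grid."""
    levels = 2 ** (bits - 1) - 1  # symmetric grid -7..7 for 4-bit
    step = alpha / levels
    q = max(-levels, min(levels, round(w / step)))  # clamping = clipping
    return q * step


def clipped_fraction(weights: Sequence[float], alpha: float) -> float:
    """Fraction of weights whose magnitude exceeds the clipping value."""
    return sum(1 for w in weights if abs(w) > alpha) / len(weights)


if __name__ == "__main__":
    w = [0.9, -0.05, 0.4, -0.8, 0.02, -0.3]
    for alpha in (0.1, max(abs(x) for x in w)):
        print(alpha, round(clipped_fraction(w, alpha), 3))
```

On these toy weights a too-small `alpha` clips two thirds of the values, while setting `alpha` to the maximum magnitude clips none at the cost of a coarser grid.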
Core Entities
Models
- GPT-2
- OPT
- LLaMA
- GPT-Neo
Metrics
- Perplexity (PPL)
- Accuracy
Datasets
- Penn Treebank (PTB)
- PIQA
- OpenbookQA
- ARC_easy
- ARC_challenge
- GSM8K
- GLUE
Benchmarks
- PTB language modeling
- Commonsense QA (PIQA, OpenbookQA, ARC)
- GSM8K arithmetic reasoning
- GLUE NLU

