Use per-token unstructured pruning and a bitmap sparse kernel to cut the KV cache to ~45% of its dense size and speed up decoding by up to 2.23×

May 28, 2025 · 8 min read

Overview

Production Readiness: 0.7

Novelty Score: 0.7

Cost Impact Score: 0.7

Citation Count: 0

Authors

Donghyeon Joo, Helya Hosseini, Ramyad Hadidi, Bahar Asgari

Why It Matters For Business

Mustafar cuts KV cache memory and can roughly double decode throughput (up to 2.23×) on batch-friendly workloads, enabling longer contexts or lower cloud costs without fine-tuning the model.

Summary TLDR

Mustafar shows that element-wise (unstructured) pruning applied per token compresses both the Key and Value caches to high sparsity with little accuracy loss. The paper pairs simple per-token magnitude pruning with a bitmap-based compressed format and a custom sparse attention CUDA kernel that computes directly on the compressed KV cache. On LongBench and RULER, Mustafar preserves accuracy at 50% sparsity and often at 70%, shrinks KV memory to as little as 45% of dense at 70% sparsity, and raises tokens/sec by up to 2.23× (Llama-3-8B, batch 8). Code is available.
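To make "per-token magnitude pruning" concrete, here is a minimal sketch, assuming one attention head's cache is a list of token vectors and that the kept count is rounded from the target sparsity. The threshold is chosen independently for each token's vector; `prune_token` and `prune_cache` are hypothetical names, not the released code's API.

```python
def prune_token(vec, sparsity):
    """Zero the smallest-magnitude entries of one token's K or V vector.

    Keeps the (len(vec) - round(sparsity * len(vec))) largest entries by
    absolute value; ties are broken by position so results are deterministic.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    d = len(vec)
    keep = d - int(round(sparsity * d))
    # Indices ordered from largest to smallest magnitude.
    order = sorted(range(d), key=lambda i: (-abs(vec[i]), i))
    kept = set(order[:keep])
    return [x if i in kept else 0.0 for i, x in enumerate(vec)]


def prune_cache(cache, sparsity):
    """Apply pruning row by row: each token gets its own magnitude threshold."""
    return [prune_token(row, sparsity) for row in cache]
```

For example, `prune_token([0.1, -2.0, 0.5, 3.0], 0.5)` keeps the two largest magnitudes and returns `[0.0, -2.0, 0.0, 3.0]`. Because the cut is per row, every token retains the same fraction of its entries, regardless of how large its values are relative to other tokens.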

Problem Statement

KV cache size is the main memory bottleneck for long-context decoding. We need a pruning and runtime strategy that (1) removes a large fraction of KV elements without breaking task accuracy, and (2) compresses and computes over the resulting arbitrary sparsity efficiently enough that overall latency improves.
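A quick size calculation shows why the KV cache dominates memory. This sketch assumes the published Llama-3-8B configuration (32 layers, 8 grouped-query KV heads, head dimension 128) and 16-bit storage; the helper name `kv_cache_bytes` is illustrative.

```python
def kv_cache_bytes(seq_len, batch, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Dense KV cache size in bytes.

    Factor 2 = one Key and one Value tensor per layer; each stores
    n_kv_heads * head_dim elements per token per sequence.
    """
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * seq_len * batch


# Assumed Llama-3-8B shape: 32 layers, 8 KV heads, head_dim 128, fp16.
per_token = kv_cache_bytes(1, 1, 32, 8, 128)
batch8_8k = kv_cache_bytes(8192, 8, 32, 8, 128)
print(per_token // 1024, "KiB per token")
print(batch8_8k / 2**30, "GiB for 8 sequences x 8192 tokens")
```

Each token costs 128 KiB, so eight 8K-token sequences already need 8 GiB of cache, roughly half the size of the model's ~16 GB of 16-bit weights, and the cache keeps growing linearly with context and batch.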

Main Contribution

Show per-token magnitude-based unstructured pruning preserves accuracy better than structured pruning for both Key and Value caches

Introduce a bitmap-based compressed KV format and a custom CUDA sparse attention kernel that runs directly on compressed caches

Demonstrate practical gains: KV cache compressed to ~45% of dense at 70% sparsity and up to 2.23× tokens/sec vs dense on Llama-3-8B (batch 8)

Show compatibility with orthogonal methods (H2O token eviction, KIVI quantization) and release an open-source implementation
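The bitmap idea behind the compressed format can be sketched in a few lines. This assumes 64-element tiles (echoing the 1×64 tile layout listed under System Optimization), one integer bitmap per tile, and packed nonzero values; the real format's metadata layout may differ, and `compress_row`/`decompress_row` are hypothetical names.

```python
TILE = 64  # assumed tile width


def compress_row(row):
    """Encode one pruned token vector as a list of (bitmap, values) tiles.

    Bit i of a tile's bitmap is set iff element i of that tile is nonzero;
    only the nonzero values are stored, in element order.
    """
    tiles = []
    for start in range(0, len(row), TILE):
        bitmap, values = 0, []
        for i, x in enumerate(row[start:start + TILE]):
            if x != 0:
                bitmap |= 1 << i
                values.append(x)
        tiles.append((bitmap, values))
    return tiles


def decompress_row(tiles, length):
    """Rebuild the dense row (inverse of compress_row)."""
    row = [0.0] * length
    for t, (bitmap, values) in enumerate(tiles):
        vals = iter(values)
        for i in range(TILE):
            if (bitmap >> i) & 1:
                row[t * TILE + i] = next(vals)
    return row
```

The bitmap costs one bit per element no matter how sparse the row is, while the value array shrinks with sparsity; positions never need to be stored explicitly because they are implied by the set bits.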

Key Findings

Per-token magnitude-based unstructured pruning retains accuracy far better than structured pruning across LongBench

Numbers: Llama-3-8B LongBench average: dense 43.19 vs K0.5 V0.5 (50% Key and 50% Value sparsity) 42.65 (Δ −0.54)

Mustafar compresses KV cache to as low as 45% of dense at 70% joint Key+Value sparsity

Numbers: KV cache size 45% of dense at K0.7 V0.7

End-to-end decode throughput can improve up to 2.23× vs dense inference for Llama-3-8B (batch size 8)

Numbers: 2.23× tokens/sec vs dense (Llama-3-8B, batch 8)

Runtime pruning and compression overhead is small and outweighed by sparse SpMV speedup

Numbers: Llama-2: pruning 1.84% + compression 6.25% of cuBLAS time; SpMV takes 81.07% (at 50% sparsity) or 61.87% (at 70% sparsity) of cuBLAS time

Joint use with quantization and token eviction keeps accuracy at 50% but degrades at 70% on some tasks

Numbers: KIVI+Mustafar: 50% pruning retains task scores; 70% shows larger drops, especially on summarization
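A back-of-the-envelope estimate helps read the 45% figure. This sketch assumes 16-bit values and exactly one bitmap bit per element, and counts nothing else; `ideal_bitmap_ratio` is an illustrative helper, not from the paper.

```python
def ideal_bitmap_ratio(sparsity, value_bits=16, bitmap_bits_per_elem=1):
    """Compressed size divided by dense size, counting only the stored
    nonzero values plus one bitmap bit per element (no other metadata)."""
    density = 1.0 - sparsity
    return density + bitmap_bits_per_elem / value_bits


for s in (0.0, 0.5, 0.7):
    print(f"sparsity {s:.0%}: {ideal_bitmap_ratio(s):.2%} of dense")
```

At 70% sparsity this idealized floor is 36.25% of dense, and at 50% it is 56.25%. The reported 45% sits above the floor, so the actual format evidently carries overhead this estimate ignores (per-tile metadata or the dense local window are plausible contributors, but that attribution is an assumption). The estimate also shows the bitmap only pays off once sparsity exceeds 1/16 = 6.25%.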

Results

LongBench average (Llama-3-8B-Instruct)

Value: Dense 43.19; K0.5 V0.5 42.65

Baseline: Dense model

KV cache compression ratio

Value: 45% of dense (K0.7 V0.7)

Baseline: Dense KV cache (100%)

Throughput (tokens/sec)

Value: up to 2.23× vs dense (Llama-3-8B, batch 8)

Baseline: Dense inference (FlashAttention)

Runtime overhead (prune + compress)

Value: pruning 1.84% + compression 6.25% of cuBLAS time (Llama-2)

Baseline: cuBLAS dense matrix-vector (MV) time
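The Llama-2 kernel breakdown can be combined into a rough per-step figure. This sketch makes a simplifying assumption that is not confirmed by the source: pruning, compression, and SpMV all run every decode step and their cuBLAS-normalized costs simply add. The helper names are illustrative.

```python
def normalized_step_cost(prune_pct, compress_pct, spmv_pct):
    """Total kernel time as a percentage of the dense cuBLAS MV time,
    assuming the three costs are additive within one decode step."""
    return prune_pct + compress_pct + spmv_pct


def implied_speedup(total_pct):
    """Speedup of the sparse path over the dense MV baseline (100%)."""
    return 100.0 / total_pct


for label, spmv in (("50%", 81.07), ("70%", 61.87)):
    total = normalized_step_cost(1.84, 6.25, spmv)
    print(f"{label} sparsity: {total:.2f}% of cuBLAS -> {implied_speedup(total):.2f}x")
```

Under that assumption the sparse path costs 89.16% of cuBLAS time at 50% sparsity (about 1.12×) and 69.96% at 70% (about 1.43×). These are kernel-level ratios for Llama-2 and do not by themselves reproduce the end-to-end 2.23× figure measured on Llama-3-8B at batch 8.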

Who Should Care

What To Try In 7 Days

Clone the repo and run the provided kernel on a small model (Llama-2-7B) to reproduce K0.5 V0.5 throughput

Measure accuracy on LongBench (or a single representative LongBench task) at 50% KV sparsity; treat that as a conservative production setting

Combine Mustafar with your existing 4-bit quantization pipeline and validate end-to-end latency and accuracy on a representative workload (start with K0.5 V0.0, then K0.5 V0.5)
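The staged rollout above boils down to picking the sparsest setting whose accuracy drop stays within a budget. A minimal sketch follows; the selection rule, the tolerance, and the `pick_setting` name are assumptions. Only the dense 43.19 and K0.5 V0.5 42.65 scores come from the paper; the other scores in the usage example are made up.

```python
def pick_setting(dense_score, candidates, max_drop):
    """Return the sparsest (key_sparsity, value_sparsity) setting whose
    score drop versus dense is at most max_drop.

    candidates maps (key_sparsity, value_sparsity) -> measured score.
    Falls back to (0.0, 0.0), i.e. dense, if nothing qualifies.
    """
    ok = [kv for kv, score in candidates.items() if dense_score - score <= max_drop]
    if not ok:
        return (0.0, 0.0)
    # Prefer the highest total sparsity; break ties toward higher Value sparsity.
    return max(ok, key=lambda kv: (kv[0] + kv[1], kv[1]))


results = {
    (0.5, 0.0): 43.00,  # hypothetical score
    (0.5, 0.5): 42.65,  # LongBench average reported for Llama-3-8B
    (0.7, 0.7): 40.00,  # hypothetical score
}
print(pick_setting(43.19, results, max_drop=1.0))
```

With a 1-point budget this selects K0.5 V0.5; tightening the budget to 0.3 falls back to K0.5 V0.0, and an even tighter budget returns dense.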

Optimization Features

Token Efficiency

  • compatible with token eviction (H2O)
  • local dense window of last 32 tokens
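One way to keep a local dense window is sketched below, assuming tokens are pruned one at a time as they age out of the window; the actual implementation may compress in batches instead. `WindowedKVCache` and its `prune_fn` hook are hypothetical names, and `prune_fn` stands in for any per-token pruner plus compressor.

```python
from collections import deque

WINDOW = 32  # most recent tokens kept dense, per the feature list


class WindowedKVCache:
    """Keep the newest `window` token vectors dense; older ones are handed to
    prune_fn (hypothetical hook, e.g. magnitude pruning plus bitmap encoding)
    and stored in sparse form."""

    def __init__(self, prune_fn, window=WINDOW):
        self.prune_fn = prune_fn
        self.window = window
        self.dense = deque()
        self.sparse = []

    def append(self, vec):
        self.dense.append(vec)
        if len(self.dense) > self.window:
            # The oldest dense token leaves the window and is pruned/compressed.
            self.sparse.append(self.prune_fn(self.dense.popleft()))

    def __len__(self):
        return len(self.sparse) + len(self.dense)
```

Recent tokens, which attention often weights heavily, stay exact, while the bulk of a long context lives in compressed form.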

Infra Optimization

  • optimizes global memory traffic to SMs on NVIDIA GPUs
  • works best with batch sizes that saturate SMs (batch≥4)

Model Optimization

  • per-token magnitude-based unstructured pruning
  • compatibility with quantization (KIVI)

System Optimization

  • Triton GPU compression kernel
  • tile-wise shared-memory decompression
  • warp-thread 1×64 thread-tile layout

Inference Optimization

  • bitmap-based compressed KV format
  • custom CUDA sparse attention kernel (SpMV on compressed tiles)
  • load-as-compressed, compute-as-dense pipeline
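The "load-as-compressed, compute-as-dense" data flow for the Key side of attention can be modeled in plain Python. This is a sketch of the idea only, assuming 64-element tiles with an integer bitmap and packed nonzeros per tile; the real kernel does this in CUDA with shared memory, and `expand_tile`/`sparse_scores` are hypothetical names. The Value-side product is omitted.

```python
TILE = 64  # assumed tile width


def expand_tile(bitmap, values, width=TILE):
    """Scatter a tile's packed nonzeros into a dense buffer (the
    'compute-as-dense' step that follows a compressed load)."""
    buf = [0.0] * width
    vals = iter(values)
    for i in range(width):
        if (bitmap >> i) & 1:
            buf[i] = next(vals)
    return buf


def sparse_scores(query, compressed_keys):
    """Attention logits q . k_t for every cached token, reading keys only in
    compressed form. compressed_keys holds one list of (bitmap, values)
    tiles per token."""
    scores = []
    for tiles in compressed_keys:
        s = 0.0
        for t, (bitmap, values) in enumerate(tiles):
            dense = expand_tile(bitmap, values)
            q = query[t * TILE:(t + 1) * TILE]
            s += sum(a * b for a, b in zip(q, dense))
        scores.append(s)
    return scores
```

Memory traffic scales with the nonzeros plus the bitmaps, while the arithmetic runs on an ordinary dense tile, which is why the approach helps most when decoding is memory-bound.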

Reproducibility

Code Available

Open Source Status

  • yes

Risks & Boundaries

Limitations

  • Kernel currently does not support low-bit precision compute on compressed tiles
  • Small batch sizes (batch=1) can be slower due to underutilized GPU
  • Some tasks and models show accuracy drops at 70%+ key sparsity; per-layer or per-head sparsity not explored
  • Time-to-first-token increases due to prefill pruning/compression overhead

When Not To Use

  • Latency-sensitive single-request workloads (batch size 1)
  • Models or tasks where key cache is highly sensitive to element removal at high sparsity
  • Environments requiring immediate low-bit compute on compressed KV tiles (not supported yet)

Failure Modes

  • Large accuracy loss when applying >70% key sparsity on models sensitive to key magnitudes
  • Throughput drop for small batches due to GPU underutilization
  • Excessive prefill delay if pruning/compression is not amortized by long decode
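Whether the added prefill work pays off can be checked with a simple break-even count. This sketch assumes a fixed extra prefill cost and constant per-step decode times; all inputs in the usage line are made-up numbers, and `break_even_tokens` is an illustrative helper.

```python
import math


def break_even_tokens(extra_prefill_ms, dense_step_ms, sparse_step_ms):
    """Decode steps needed before per-step savings repay the extra prefill
    time spent on pruning and compression."""
    saving = dense_step_ms - sparse_step_ms
    if saving <= 0:
        return math.inf  # sparse decode is no faster (e.g. batch 1): never pays off
    return math.ceil(extra_prefill_ms / saving)


# Hypothetical numbers: 50 ms extra prefill, 20 ms vs 15 ms per decode step.
print(break_even_tokens(50, 20, 15))
```

With these made-up figures the overhead is repaid after 10 generated tokens. Short generations, or small batches where the sparse step is not faster, never reach break-even, which matches the failure modes above.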

Core Entities

Models

  • Llama-3-8B-Instruct
  • Llama-2-7B
  • Mistral-7B-Instruct-v0.2
  • Llama-2-13B-chat
  • Llama-3.1-8B-Instruct

Metrics

  • LongBench average score
  • tokens/sec (throughput)
  • KV cache compression ratio (percent of dense)
  • kernel latency breakdown (cuBLAS normalized)

Datasets

  • LongBench
  • RULER

Benchmarks

  • LongBench
  • RULER