GRIFFIN: training-free sequence-level neuron selection that cuts FF work by 50% and speeds up generation

April 1, 2024 · 7 min read

Overview

Production Readiness

0.7

Novelty Score

0.6

Cost Impact Score

0.7

Citation Count

0

Authors

Harry Dong, Beidi Chen, Yuejie Chi

Why It Matters For Business

GRIFFIN cuts active FF work by half with no retraining, offering real latency and memory wins for deploy-time generation while preserving most task quality on evaluated models and datasets.

Summary TLDR

GRIFFIN is a training-free method that uses the prompt to pick which feed-forward (FF) neurons to run for each input sequence. It exploits a phenomenon called "flocking" (tokens in a sequence share relative neuron activations) to keep about half of the FF neurons inactive during generation. On many models and tasks, GRIFFIN preserves task quality at 50% FF sparsity while speeding up generation (e.g., 1.29× on Gemma 7B, 1.25× on Llama 2 13B) and reducing active parameters (Llama 2 13B: 13B -> 8.8B). Code is public.

Problem Statement

Large transformer models spend much of their compute in feed-forward (FF) layers, even though many intermediate neurons contribute little to any given token. Existing fixes (structured pruning, MoEs) require training, fail with non-ReLU activations, or are hard to deploy. The paper asks: can we adaptively skip FF neurons per sequence, without training, across many LLMs and activation types?

Main Contribution

Identify "flocking": within a sequence, relative FF neuron activations are highly consistent across tokens.
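One way to make flocking concrete is to scale each token's FF activation vector to unit length and compare the resulting relative-magnitude patterns across tokens. The sketch below is illustrative only: the helper names, the toy activations, and the use of mean pairwise cosine similarity as a consistency score are assumptions, not the paper's measurement.

```python
import math

def relative_activations(acts):
    """Scale each token's activation vector to unit L2 norm (magnitudes only)."""
    out = []
    for row in acts:
        norm = math.sqrt(sum(v * v for v in row)) or 1.0
        out.append([abs(v) / norm for v in row])
    return out

def mean_pairwise_cosine(rel):
    """Average cosine similarity over all pairs of token rows."""
    sims = []
    for i in range(len(rel)):
        for j in range(i + 1, len(rel)):
            # Non-zero rows are unit-norm, so the dot product is the cosine.
            sims.append(sum(a * b for a, b in zip(rel[i], rel[j])))
    return sum(sims) / len(sims)

# Toy activations (assumed data): 3 tokens x 4 neurons. Tokens differ in
# overall scale, but neurons 0 and 2 dominate in every row.
acts = [
    [4.0, 0.1, -2.0, 0.0],
    [8.2, 0.3, -3.9, 0.1],
    [1.9, 0.0, -1.1, 0.05],
]
flocking_score = mean_pairwise_cosine(relative_activations(acts))
print(f"mean pairwise cosine: {flocking_score:.3f}")
```

A score close to 1 means the tokens agree on which neurons matter, which is the property GRIFFIN relies on when it reuses one neuron set for the whole sequence.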

GRIFFIN: a training-free, prompt-based top-k selector that chooses FF neurons per sequence and reuses them during generation.
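The selection step can be sketched as follows. This assumes a per-neuron statistic of the form "normalize each prompt token's FF activations, then take each neuron's L2 norm across tokens", followed by top-k. The exact statistic and the names `select_neurons`, `prompt_acts`, and `keep_ratio` are illustrative assumptions rather than the authors' code.

```python
import math

def select_neurons(prompt_acts, keep_ratio=0.5):
    """Return sorted indices of the top-k FF neurons for one sequence.

    prompt_acts: list of per-token activation lists (tokens x neurons),
    collected from one FF layer while the prompt is processed.
    """
    n_neurons = len(prompt_acts[0])
    k = max(1, int(n_neurons * keep_ratio))
    scores = [0.0] * n_neurons
    for row in prompt_acts:
        norm = math.sqrt(sum(v * v for v in row)) or 1.0
        for j, v in enumerate(row):
            scores[j] += (v / norm) ** 2  # squared relative activation
    # Ranking by the sum of squares gives the same order as the L2 norm.
    top = sorted(range(n_neurons), key=lambda j: scores[j], reverse=True)[:k]
    return sorted(top)

# Toy prompt activations (assumed data): 3 tokens x 4 neurons.
prompt_acts = [
    [4.0, 0.1, -2.0, 0.0],
    [8.2, 0.3, -3.9, 0.1],
    [1.9, 0.0, -1.1, 0.05],
]
kept = select_neurons(prompt_acts)  # -> [0, 2]
```

The selection is computed once per sequence and per layer from the prompt, and the same `kept` indices are then used for every generated token.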

Show GRIFFIN works across many models (Llama 2, Gemma, Mistral, OPT, ReluLlama), activation functions, and tasks while halving active FF neurons and improving latency.

Key Findings

GRIFFIN keeps performance near the full model at 50% FF sparsity on classification tasks.

Numbers: HellaSwag, Llama 2 7B: 57.16 -> 57.11 accuracy (full -> GRIFFIN)

GRIFFIN preserves much of generation quality at 50% FF sparsity on summarization and QA.

Numbers: XSum Rouge-1: Llama 2 7B 27.15 -> 24.75; Gemma 7B 26.86 -> 25.86

GRIFFIN reduces active parameters and improves long-generation latency.

Numbers: Llama 2 13B active params 13B -> 8.8B; speed-ups up to 1.25× (Llama 2 13B) and 1.29× (Gemma 7B)
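The 13B -> 8.8B figure is consistent with simple arithmetic on the FF weights. The check below assumes the publicly documented Llama 2 13B configuration (40 layers, hidden size 5120, FF intermediate size 13824, three FF projection matrices per layer) and uses the rounded 13B total quoted here; none of these configuration values come from this summary.

```python
# Assumed Llama 2 13B configuration (public model config, not from this summary).
N_LAYERS = 40
HIDDEN = 5120
FF_DIM = 13824
FF_MATRICES = 3          # gate, up, and down projections in the gated FF block
TOTAL_PARAMS = 13.0e9    # rounded total, as quoted in this summary

ff_params = N_LAYERS * FF_MATRICES * HIDDEN * FF_DIM
active_params = TOTAL_PARAMS - 0.5 * ff_params  # 50% FF sparsity

print(f"FF params:     {ff_params / 1e9:.2f}B")
print(f"Active params: {active_params / 1e9:.1f}B")
```

Under these assumptions the FF layers hold roughly 8.5B of the parameters, so deactivating half of them leaves about 8.8B active.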

Static neuron magnitude pruning fails for generation but GRIFFIN succeeds.

Numbers: Many magnitude-pruned models show huge quality loss (e.g., Gemma 7B XSum Rouge-1 26.86 -> 1.49)

Results

Efficient Inference (active parameters)

Value: Llama 2 13B: 13B -> 8.8B; Gemma 7B: 8.5B -> 5.4B

Baseline: full model

Latency (long generation)

Value: Gemma 7B 1.29×, Llama 2 13B 1.25× speed-up

Baseline: full model

Summarization quality (XSum Rouge-1)

Value: Llama 2 7B 27.15 -> 24.75; Gemma 7B 26.86 -> 25.86

Baseline: full model

Accuracy (HellaSwag)

Value: Llama 2 7B 57.16 -> 57.11

Baseline: full model

Who Should Care

What To Try In 7 Days

Run GRIFFIN at 50% FF sparsity on one production LLM and measure latency, memory, and task metrics.

Compare prompt lengths: check whether longer prompts reduce quality loss on long generations.

Test batched vs. single-sample throughput, and confirm whether the pruned model fits on a single device so that offloading can be avoided.
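For the latency comparison, a small timing harness is usually enough to get a first speed-up estimate. The sketch below is a generic example: `generate_full` and `generate_griffin` are hypothetical placeholders (here they only sleep) that would be replaced by real generation calls on the same prompt and output length.

```python
import statistics
import time

def median_latency(fn, runs=5, warmup=1):
    """Median wall-clock seconds of fn() over several runs, after warm-up."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

def speedup(baseline_fn, candidate_fn, runs=5):
    """Baseline latency divided by candidate latency (>1 means faster)."""
    return median_latency(baseline_fn, runs) / median_latency(candidate_fn, runs)

# Hypothetical stand-ins for real generation calls.
def generate_full():
    time.sleep(0.03)

def generate_griffin():
    time.sleep(0.005)

ratio = speedup(generate_full, generate_griffin)
print(f"speed-up: {ratio:.2f}x")
```

Using the median over several runs reduces the effect of one-off stalls. Task metrics and memory should be recorded alongside the timing, since latency alone does not show quality loss.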

Optimization Features

Token Efficiency

  • Lower compute per generated token due to fewer active neurons

Infra Optimization

  • Reduces memory footprint of FF layers during generation

Model Optimization

  • Sequence-level structured pruning (top-k neurons from prompt)
  • Adaptive per-sequence expert neuron selection (no training)

System Optimization

  • Enables fitting pruned model on single device to avoid offload
  • Best for single-sample, latency-sensitive inference

Inference Optimization

  • Reduces active FF dimensions during generation
  • Works with non-ReLU activations (SwiGLU, GEGLU, ReGLU)
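Reducing the active FF dimension in a gated block amounts to slicing weights: the selected neurons correspond to rows of the gate and up projections and to columns of the down projection. The sketch below uses a SwiGLU-style block with toy weights to show that the sliced block matches the full block with the dropped neurons zeroed out. All names, shapes, and values are illustrative assumptions.

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))

def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def gated_ff(x, W_gate, W_up, W_down):
    """SwiGLU-style FF block: W_down @ (silu(W_gate @ x) * (W_up @ x))."""
    z = [silu(g) * u for g, u in zip(matvec(W_gate, x), matvec(W_up, x))]
    return matvec(W_down, z)

def prune_ff(W_gate, W_up, W_down, kept):
    """Keep selected neurons: rows of W_gate/W_up, columns of W_down."""
    return (
        [W_gate[j] for j in kept],
        [W_up[j] for j in kept],
        [[row[j] for j in kept] for row in W_down],
    )

# Toy sizes (assumed): hidden size 2, FF size 4; keep neurons 0 and 2.
W_gate = [[0.5, -1.0], [0.1, 0.2], [1.5, 0.3], [-0.4, 0.9]]
W_up = [[1.0, 0.0], [0.3, -0.2], [-0.7, 0.8], [0.2, 0.2]]
W_down = [[0.6, -0.1, 0.4, 0.0], [0.2, 0.5, -0.3, 0.7]]
x = [1.0, -2.0]
kept = [0, 2]

pruned_out = gated_ff(x, *prune_ff(W_gate, W_up, W_down, kept))

# Reference: full-size block with the dropped neurons' outputs zeroed.
masked_down = [[w if j in kept else 0.0 for j, w in enumerate(row)]
               for row in W_down]
masked_out = gated_ff(x, W_gate, W_up, masked_down)
```

Because the sliced matrices are smaller dense matrices, the saving shows up as less compute and memory per generated token without needing sparse kernels.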

Reproducibility

Code Available

Open Source Status

  • yes

Risks & Boundaries

Limitations

  • Performance degrades for very long generations when prompt is short; longer prompts help.
  • Benefits shrink as batch size grows; best suited to batch size 1 or small batches.
  • Requires full FF execution for the prompt phase, so prompt cost is unchanged.

When Not To Use

  • Workloads with extremely long uncontrolled generation and short prompts where neuron patterns drift.
  • High-throughput large-batch serving where batch-level aggregation reduces adaptivity gains.
  • When you cannot afford the prompt-phase full FF computation or need fully static models.

Failure Modes

  • Prompt is not representative and selected neurons misalign with later generation.
  • Sampling-based neuron selection (instead of top-k) substantially reduces quality.
  • Static magnitude pruning can catastrophically fail for generation tasks.

Core Entities

Models

  • Llama 2
  • Gemma
  • Mistral
  • OPT
  • ReluLlama

Metrics

  • Rouge-1/2/L
  • F1
  • ExactMatch
  • Accuracy
  • Latency (s)
  • Active parameter count

Datasets

  • WikiText
  • XSum
  • CNN/DailyMail
  • CoQA
  • QASPER
  • HellaSwag
  • PIQA
  • COPA
  • ARC-e
  • ARC-c
  • BoolQ

Benchmarks

  • XSum
  • CNN/DailyMail
  • CoQA
  • QASPER
  • HellaSwag
  • PIQA
  • COPA
  • ARC
  • BoolQ