Block-wise Adam that lets you full-finetune 8B+ LLMs on a single 24GB GPU

April 3, 2024 · 8 min

Overview

Production Readiness: 0.7

Novelty Score: 0.45

Cost Impact Score: 0.65

Citation Count: 1

Authors

Qijun Luo, Hengxu Yu, Xiao Li

Why It Matters For Business

BAdam lets teams do full-parameter finetuning of 8B+ LLMs on single 24GB GPUs, cutting infrastructure cost and widening access to higher-quality fine-tuned models.

Summary TLDR

BAdam is a block-coordinate-descent optimizer that runs Adam inside blocks, so full-precision optimizer state is kept only for the active block. This cuts gradient and optimizer memory dramatically, enabling full-parameter finetuning of Llama-scale models on much smaller GPUs. Experiments show that BAdam reduces memory relative to Adam and backward time relative to LoRA and LOMO, achieves similar or better downstream scores (MT-bench, math benchmarks), and keeps high-rank model updates. Code is on GitHub for PyTorch integration.

Problem Statement

Full-parameter finetuning with Adam needs a lot of GPU memory: roughly 18 GB per billion parameters once weights, gradients, and optimizer states are counted, or about 144 GB for an 8B model. Practitioners with limited GPUs must therefore fall back on low-rank PEFT methods such as LoRA, which can limit performance. The paper asks: can we do true full-parameter finetuning with much less optimizer and gradient memory?
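The Adam memory estimate above can be reproduced with back-of-the-envelope arithmetic. The sketch rests on several assumptions:

- It uses a standard mixed-precision breakdown: 2 bytes per parameter for FP16 weights, 4 for FP32 gradients, and 12 for the FP32 master weights plus Adam's two moments.
- It ignores activations.
- The BAdam line assumes D equal-sized blocks. D = 32 (one block per Llama 3-8B decoder layer) is only an illustration.

```python
# Back-of-the-envelope GPU memory for mixed-precision finetuning, in GB,
# for a model with M billion parameters. Assumed byte counts (not from the paper):
#   FP16 weights: 2 bytes/param, FP32 gradients: 4 bytes/param,
#   FP32 master weights + Adam first and second moments: 12 bytes/param.
# Activations are ignored, so real usage is higher.

def adam_memory_gb(m_billion: float) -> float:
    """Full-parameter Adam: every parameter carries gradient and optimizer state."""
    return 2 * m_billion + 4 * m_billion + 12 * m_billion  # = 18M GB


def badam_memory_gb(m_billion: float, num_blocks: int) -> float:
    """Block-wise Adam (assumed D equal blocks): FP16 weights for the whole model,
    FP32 gradient and optimizer state only for the active block of M / D params."""
    active = m_billion / num_blocks
    return 2 * m_billion + (4 + 12) * active  # = 2M + 16M/D GB


if __name__ == "__main__":
    m = 8.0  # Llama 3-8B, rounded to 8 billion parameters
    print(f"Adam:         {adam_memory_gb(m):.1f} GB")
    print(f"BAdam (D=32): {badam_memory_gb(m, 32):.1f} GB")
```

For an 8B model this gives 144 GB for Adam, close to the 144.8GB+ reported. The block-wise case comes to 20 GB before activations, which leaves room for activations within the ~23.5GB total reported for BAdam. Embedding and output layers are larger than a decoder layer, so equal blocks are a simplification.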

Main Contribution

BAdam: a block coordinate descent (BCD) optimizer that runs K Adam steps on one parameter block at a time and clears optimizer states per block to save memory.

Memory and backward-pass (BP) time analysis showing that BAdam stores FP32 optimizer state only for the active block and, with module-wise partitions, reduces backward computation.

Extensive experiments (Llama 2-7B, Llama 3-8B, Llama 3-70B, and RoBERTa-large) showing large memory and time savings, with downstream performance comparable to or better than LoRA, LOMO, and GaLore, and sometimes Adam.
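The backward-pass saving from module-wise partitions can be illustrated with a crude cost model. This is an assumption made here, not the paper's analysis:

- Each layer costs one unit to backpropagate activation gradients and one unit to compute its weight gradients.
- LoRA's adapter weight gradients are treated as free.
- There are 32 layers, as in Llama 3-8B, with one block per layer as an assumed partition.

The model only shows the direction of the effect; it does not reproduce the measured times.

```python
# Crude backward-pass cost model (an assumption, not the paper's analysis):
# each layer costs 1 unit to propagate activation gradients and 1 unit to
# compute its weight gradients. Layer 0 is the input side.

def adam_backward_units(num_layers: int) -> int:
    """Full finetuning: every layer computes activation and weight gradients."""
    return 2 * num_layers


def lora_backward_units(num_layers: int) -> int:
    """Adapters in every layer: activation gradients must reach the first layer;
    the small adapter weight gradients are treated as free."""
    return num_layers


def badam_backward_units(num_layers: int, num_blocks: int) -> float:
    """Average over D equal blocks of consecutive layers.

    Backprop runs from the output down to the bottom layer of the active block,
    and only the active block's layers compute weight gradients.
    """
    if num_layers % num_blocks != 0:
        raise ValueError("this sketch assumes equal-sized blocks")
    block_size = num_layers // num_blocks
    total = 0
    for j in range(num_blocks):
        bottom = j * block_size
        activation_units = num_layers - bottom  # output layer down to block bottom
        weight_units = block_size               # weight grads only in active block
        total += activation_units + weight_units
    return total / num_blocks


if __name__ == "__main__":
    L, D = 32, 32
    print("Adam :", adam_backward_units(L))
    print("LoRA :", lora_backward_units(L))
    print("BAdam:", badam_backward_units(L, D))
```

Under these assumptions, BAdam's average backward cost (17.5 units) is a little over half of LoRA's (32 units). Larger blocks (smaller D) push it up, because more layers compute weight gradients on every step.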

Key Findings

BAdam reduces total GPU memory needed to finetune Llama 3-8B to ~23.5GB vs ~144.8GB+ for Adam.

Numbers: 23.5GB (BAdam) vs 144.8GB+ (Adam); Table 2

BAdam cuts backward pass time roughly in half vs LoRA/LOMO for Llama 3-8B.

Numbers: backward time per epoch of 1.74h (BAdam) vs 3.20h (LoRA) and 3.70h (LOMO); Table 3

Instruction-following performance (MT-bench) of BAdam is on par with or better than LoRA and often matches Adam.

Numbers: average MT-bench for Llama 3-8B, 6.30 (BAdam) vs 6.27 (LoRA); for Llama 2-7B, 4.96 (BAdam) vs 4.54 (LoRA); Table 5

On math benchmarks, BAdam matches or beats baselines including Adam and LoRA.

Numbers: Llama 3-8B average, 44.4 (BAdam) vs 44.1 (Adam) and 43.3 (LoRA); Llama 3-70B average, 62.7 (BAdam) vs 59.0 (LoRA); Table 6

BAdam's learned parameter updates are high-rank and similar to Adam's, not low-rank like LoRA.

Numbers: effective-rank and cumulative-explained-variance plots show BAdam ≈ Adam across layers; Figure 2
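This summary does not define effective rank. A common choice, assumed here, is the smallest k whose top-k squared singular values capture a fixed fraction (90% below) of the total. The cumulative explained variance is the running version of that fraction.

The sketch takes singular values as input. For a PyTorch weight update `delta_w`, `torch.linalg.svdvals(delta_w)` would supply them. The spectra used are made up for illustration, not data from the paper, and the function names are hypothetical.

```python
from typing import List, Sequence


def cumulative_explained_variance(singular_values: Sequence[float]) -> List[float]:
    """For each k, the fraction of total squared singular value in the top k."""
    energies = sorted((s * s for s in singular_values), reverse=True)
    total = sum(energies)
    if total == 0:
        raise ValueError("all-zero update has no explained variance")
    fractions, running = [], 0.0
    for e in energies:
        running += e
        fractions.append(running / total)
    return fractions


def effective_rank(singular_values: Sequence[float], threshold: float = 0.9) -> int:
    """Smallest k whose top-k singular values explain >= threshold (assumed definition)."""
    for k, fraction in enumerate(cumulative_explained_variance(singular_values), start=1):
        if fraction >= threshold:
            return k
    return len(singular_values)


if __name__ == "__main__":
    low_rank_like = [10.0, 5.0] + [0.1] * 6  # energy concentrated in two directions
    flat = [1.0] * 8                          # energy spread evenly
    print(effective_rank(low_rank_like), effective_rank(flat))
```

A LoRA update of rank r can never exceed effective rank r under this definition. A spread-out spectrum like `flat` is what "high-rank, similar to Adam" looks like.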

Results

Total GPU memory for Llama 3-8B finetuning

Value: 23.5GB (BAdam)

Baseline: 144.8GB+ (Adam)

Backward time per epoch (Llama 3-8B)

Value: 1.74 hours (BAdam)

Baseline: 3.20 hours (LoRA)

MT-bench average (Llama 3-8B)

Value: 6.30 (BAdam, lr=1e-5; average score)

Baseline: 6.27 (LoRA)

Math benchmarks average (Llama 3-8B)

Value: 44.4 (BAdam)

Baseline: 44.1 (Adam), 43.3 (LoRA)

Effective rank of learned updates

Value: High (BAdam ≈ Adam)

Baseline: Low-rank (LoRA)

What To Try In 7 Days

Clone the BAdam repo and run the provided Llama-3-8B Alpaca-GPT4 script on a 24GB GPU to validate memory/time claims.

Replace LoRA in one instruction-tuning pipeline with BAdam and compare MT-bench and training time.

If you use mixed precision, test the consecutive-module block partition to speed backward passes.

Optimization Features

Infra Optimization

  • Enables single-GPU finetuning of 8B models (24GB)
  • Reduces optimizer/gradient memory vs Adam; avoids expensive CPU/GPU offload

Model Optimization

  • Full-parameter updates preserved (no low-rank constraint)
  • Learned updates retain high effective rank similar to Adam

System Optimization

  • Mixed precision: global FP16 weights, FP32 for active block
  • Supports gradient accumulation
  • Module-based partition reduces backward computation: backpropagation can stop at the active block, and only that block computes weight gradients
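The FP32 copy for the active block matters because small updates vanish when they are added directly to FP16 weights. The scalar simulation below illustrates this; it is not BAdam's code.

It uses Python's `struct` half-precision (`'e'`) and single-precision (`'f'`) formats to round after every step. The weight value, update size, step count, and function names are arbitrary, hypothetical choices.

```python
import struct
from typing import Tuple


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def simulate(w0: float, update: float, steps: int) -> Tuple[float, float]:
    # (a) Updates applied directly to an FP16 weight, rounding after every step.
    w_half = to_fp16(w0)
    for _ in range(steps):
        w_half = to_fp16(w_half + update)
    # (b) Updates accumulated in an FP32 master copy; only the final value is
    #     cast to FP16, as a forward pass with FP16 weights would see it.
    master = to_fp32(w0)
    for _ in range(steps):
        master = to_fp32(master + update)
    return w_half, to_fp16(master)


if __name__ == "__main__":
    print(simulate(1.0, 1e-4, 100))
```

With a weight of 1.0 and an update of 1e-4, the FP16-only weight never moves, because the FP16 spacing near 1.0 is about 0.001. The FP32 master copy accumulates to about 1.01 and casts back to FP16 as 1.009765625.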

Training Optimization

  • Block coordinate descent: update one block at a time (D blocks)
  • Run K inner Adam steps per active block (K is a new hyperparameter)
  • Clears optimizer states after each block to save memory
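The training loop these bullets describe can be sketched in plain Python on a toy problem. This is a minimal sketch with the following assumptions:

- The objective is a small coupled quadratic, and the coordinate blocks are fixed.
- Adam state, including the bias-correction step counter, starts fresh each time a block becomes active. Whether the reference implementation resets the step counter is itself an assumption.
- The function `block_adam`, the toy objective, and all hyperparameter values are hypothetical.
- Unlike a real implementation, it computes the full gradient and then uses only the active block's entries.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def block_adam(
    x0: Sequence[float],
    blocks: Sequence[Sequence[int]],
    grad_fn: Callable[[Vector], Vector],
    K: int = 20,
    epochs: int = 100,
    lr: float = 0.02,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> Vector:
    """Block coordinate descent with K inner Adam steps per active block (sketch)."""
    x = list(x0)
    for _ in range(epochs):
        for block in blocks:
            # Optimizer state exists only for the active block and starts fresh.
            m = [0.0] * len(block)
            v = [0.0] * len(block)
            for t in range(1, K + 1):
                g = grad_fn(x)  # toy shortcut: full gradient, only block entries used
                for j, i in enumerate(block):
                    m[j] = beta1 * m[j] + (1 - beta1) * g[i]
                    v[j] = beta2 * v[j] + (1 - beta2) * g[i] ** 2
                    m_hat = m[j] / (1 - beta1 ** t)
                    v_hat = v[j] / (1 - beta2 ** t)
                    x[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
            # m and v are rebuilt for the next block, so this block's state is dropped.
    return x


TARGET = [1.0, 2.0, 3.0, 4.0]
COUPLING = 0.1


def toy_loss(x: Vector) -> float:
    """f(x) = ||x - target||^2 + c * (sum x)^2, so the blocks interact."""
    s = sum(x)
    return sum((xi - ti) ** 2 for xi, ti in zip(x, TARGET)) + COUPLING * s * s


def toy_grad(x: Vector) -> Vector:
    s = sum(x)
    return [2 * (xi - ti) + 2 * COUPLING * s for xi, ti in zip(x, TARGET)]


if __name__ == "__main__":
    x = block_adam([0.0] * 4, blocks=[[0, 1], [2, 3]], grad_fn=toy_grad)
    print([round(xi, 3) for xi in x], round(toy_loss(x), 4))
```

Here D = 2 blocks and K = 20. Raising K spends more steps on one block before the others can react to the change, which is the trade-off behind the "K too large" failure mode.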

Reproducibility

Data URLs

  • Alpaca-GPT4 (public dataset)
  • MathInstruct (public dataset)
  • StarCoder-Python (public dataset)
  • SuperGLUE (public dataset)

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Theory covers deterministic gradients; convergence with stochastic gradients is left to future work.
  • Best gains rely on module-consecutive block partition; arbitrary partitions may help less.
  • Focused on supervised finetuning; RLHF and preference tuning were not evaluated.
  • CPT (continued pretraining) experiments are preliminary and limited to one epoch.

When Not To Use

  • If you already have enough multi-GPU memory for standard Adam, the extra code complexity may not be worth it.
  • When your training regime requires optimizer state continuity across all parameters (e.g., some specialized adaptive schemes).
  • If your block partition cannot align with module boundaries, so backward-pass savings are small.

Failure Modes

  • Choosing K too large may over-optimize a block and harm generalization or slow whole-model progress.
  • Poor block ordering may slow early convergence (ordering matters empirically).
  • Misconfiguring mixed-precision conversions for active-block FP32 copies can introduce numerical issues.
  • Clearing optimizer state per block removes cross-block momentum information, possibly destabilizing some tasks.
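Block ordering and the module-consecutive partition both come up above. One way to build such a partition and an epoch-by-epoch schedule is sketched below. The module names, the three ordering modes, and the functions `consecutive_blocks`, `block_order`, and `schedule` are hypothetical illustrations, not the BAdam repository's API.

```python
import random
from typing import Iterator, List, Tuple


def consecutive_blocks(names: List[str], num_blocks: int) -> List[List[str]]:
    """Split an ordered list of module names into num_blocks consecutive chunks."""
    n = len(names)
    if not 1 <= num_blocks <= n:
        raise ValueError("num_blocks must be between 1 and the number of modules")
    base, extra = divmod(n, num_blocks)
    blocks: List[List[str]] = []
    start = 0
    for j in range(num_blocks):
        size = base + (1 if j < extra else 0)  # spread the remainder over early blocks
        blocks.append(names[start:start + size])
        start += size
    return blocks


def block_order(num_blocks: int, mode: str, rng: random.Random) -> List[int]:
    """Order in which blocks become active during one epoch (assumed modes)."""
    order = list(range(num_blocks))
    if mode == "ascending":
        return order
    if mode == "descending":
        return order[::-1]
    if mode == "random":
        rng.shuffle(order)  # reshuffle every epoch
        return order
    raise ValueError(f"unknown mode: {mode}")


def schedule(num_blocks: int, K: int, epochs: int,
             mode: str = "random", seed: int = 0) -> Iterator[Tuple[int, int]]:
    """Yield (active_block, inner_step) pairs: K inner steps per active block."""
    rng = random.Random(seed)
    for _ in range(epochs):
        for b in block_order(num_blocks, mode, rng):
            for k in range(K):
                yield b, k


if __name__ == "__main__":
    names = [f"layers.{i}" for i in range(10)]  # hypothetical module names
    print(consecutive_blocks(names, 4))
    print(list(schedule(3, K=2, epochs=1, mode="descending")))
```

In a full training loop, `schedule` decides which block is active at each step. A change of block index is also the point where the previous block's optimizer state would be freed.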

Core Entities

Models

  • Llama 3-8B
  • Llama 3-70B
  • Llama 2-7B
  • Llama 3.1-8B-Instruct
  • RoBERTa-large

Metrics

  • MT-bench score
  • Accuracy
  • GPU memory (GB)
  • Wall-clock time per epoch (hours)

Datasets

  • Alpaca-GPT4
  • MathInstruct
  • StarCoder-Python
  • SuperGLUE
  • GSM8K
  • MATH
  • Aqua
  • MMLU-Math
  • SAT-Math
  • NumGLUE

Benchmarks

  • MT-bench
  • GSM8K
  • Aqua
  • MMLU-Math
  • SAT-Math
  • MATH
  • NumGLUE
  • SuperGLUE