Use server-side multimodal LLMs to bootstrap federated learning on heterogeneous, long-tailed image data

September 9, 2024, 7 min read

Overview

Production Readiness: 0.6

Novelty Score: 0.6

Cost Impact Score: 0.65

Citation Count: 1

Authors

Jianyi Zhang, Hao Frank Yang, Ang Li, Xin Guo, Pu Wang, Haiming Wang, Yiran Chen, Hai Li

Why It Matters For Business

You can improve federated accuracy on skewed client data without adding client compute or uploading gradients. Heavy multimodal models run on the server and learn from public web data, which lowers device cost and privacy exposure.

Summary TLDR

The paper introduces MLLM-LLaVA-FL, a three-stage federated learning (FL) framework that keeps heavy multimodal LLMs (MLLMs) on the server. The server uses them to (1) annotate web image-text data and pretrain compact FL models on it, (2) distribute the pretrained model for client-side finetuning, and (3) perform server-side global alignment with class-balanced data. On CIFAR-10/100-LT and ImageNet-LT, the method improves top-1 accuracy over prior FL baselines. Against CLIP2FL, it gains +2.12 percentage points on CIFAR-10-LT and +1.94 on CIFAR-100-LT, both at imbalance factor IF=100. Against CReFF on ImageNet-LT, it gains +1.22 points overall and +15.29 points on 'Few' classes. The approach adds no client compute and does not upload client gradients, which is intended to reduce privacy risk.

Problem Statement

Federated learning loses accuracy when clients hold heterogeneous, long-tailed data. Existing fixes either upload gradients (a privacy risk) or require large models on devices (high compute and memory). The paper asks whether server-side multimodal LLMs can use open web image-text data to pretrain and align compact FL models, so that clients stay lightweight and private while accuracy improves.

Main Contribution

A three-stage FL framework that uses server-side multimodal LLMs for (1) global multimodal pretraining, (2) federated finetuning, and (3) server-side global alignment.

Dynamic Weighted Pretraining: gradually distill features from a large frozen visual encoder into a compact FL model using MLLM-generated web annotations.

A global alignment step that uses a small class-balanced server dataset and a combined KL-divergence and cross-entropy loss to reduce long-tail bias.

Empirical gains over prior FL baselines (including CLIP2FL and CReFF) on CIFAR-10/100-LT and ImageNet-LT while keeping client compute low and avoiding gradient uploads.
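The global alignment step pairs a cross-entropy term on the true label with a KL-divergence term. The sketch below shows one way to combine them for examples from the class-balanced server set, in plain Python with logits as lists. It is an assumption, not the paper's code. The summary does not say which distribution the KL term compares against, so `reference_logits` stands in for a hypothetical frozen reference model, for example the global model before alignment. `kl_weight` is likewise a hypothetical weighting parameter.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vector of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def alignment_loss(
    student_logits: Sequence[float],
    reference_logits: Sequence[float],
    label: int,
    kl_weight: float = 1.0,
) -> float:
    """Cross-entropy on the true label plus kl_weight * KL(p_ref || p_student).

    Assumption: p_ref comes from a hypothetical frozen reference model;
    the paper's exact KL target and weighting may differ.
    """
    log_p_student = log_softmax(student_logits)
    log_p_ref = log_softmax(reference_logits)
    ce = -log_p_student[label]
    kl = sum(math.exp(lr) * (lr - ls) for lr, ls in zip(log_p_ref, log_p_student))
    return ce + kl_weight * kl


def batch_alignment_loss(batch, kl_weight: float = 1.0) -> float:
    """Mean loss over (student_logits, reference_logits, label) triples,
    e.g. a batch drawn from the class-balanced server dataset."""
    losses = [alignment_loss(s, r, y, kl_weight) for s, r, y in batch]
    return sum(losses) / len(losses)
```

The KL term is zero when the student matches the reference. In this reading, `kl_weight` limits how far alignment can pull the model from the reference, while cross-entropy on balanced data corrects the bias toward head classes.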

Key Findings

MLLM-LLaVA-FL beats CLIP2FL on CIFAR-LT benchmarks

Numbers: CIFAR-10-LT IF=100: 75.49% vs 73.37% (+2.12 points); CIFAR-100-LT IF=100: 39.50% vs 37.56% (+1.94 points)

ImageNet-LT shows notable gains on scarce classes

Numbers: ImageNet-LT 'Few' classes: 25.58% (MLLM-LLaVA-FL) vs 10.29% (CReFF) (+15.29 points); overall +1.22 points

Server-side pretraining speeds learning in low-data regimes

Numbers: on 1% and 2% CIFAR subsets, pretrained models need fewer epochs to reach target accuracy and reach higher peak accuracy within 30 epochs.

Results

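  • Accuracy on CIFAR-10-LT (IF=100): 75.49% vs. CLIP2FL baseline 73.37%
  • Accuracy on CIFAR-100-LT (IF=100): 39.50% vs. CLIP2FL baseline 37.56%
  • Accuracy on ImageNet-LT 'Few' classes: 25.58% vs. CReFF baseline 10.29%
  • Accuracy on ImageNet-LT overall: 27.53% vs. CReFF baseline 26.31%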
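The reported gains are absolute differences in percentage points, not relative improvements. The short calculation below recomputes both from the four reported accuracy pairs. The fourth pair (27.53% vs 26.31%) is read as ImageNet-LT overall because its difference matches the +1.22 overall gain quoted in the Key Findings. That label is an inference, not stated next to the number.

```python
from typing import Dict, Tuple

# Reported top-1 accuracies in percent: (MLLM-LLaVA-FL, baseline).
REPORTED: Dict[str, Tuple[float, float]] = {
    "CIFAR-10-LT IF=100 vs CLIP2FL": (75.49, 73.37),
    "CIFAR-100-LT IF=100 vs CLIP2FL": (39.50, 37.56),
    "ImageNet-LT 'Few' vs CReFF": (25.58, 10.29),
    "ImageNet-LT overall vs CReFF": (27.53, 26.31),
}


def gains(ours: float, baseline: float) -> Tuple[float, float]:
    """Return (absolute gain in percentage points, relative gain in percent)."""
    absolute = ours - baseline
    relative = 100.0 * absolute / baseline
    return absolute, relative


for name, (ours, base) in REPORTED.items():
    pp, rel = gains(ours, base)
    print(f"{name}: +{pp:.2f} pp, {rel:.1f}% relative to the baseline")
```

Read this way, the 'Few'-class result is a 15.29-point absolute gain that brings accuracy to roughly 2.5 times the CReFF baseline. The other three gains are between about 3% and 5% relative.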

What To Try In 7 Days

Run an MLLM (e.g., LLaVA/GPT-4) on a small crawl of public images to produce captions and QA-style annotations.

Implement Dynamic Weighted Pretraining: distill a frozen CLIP encoder into your compact FL model with a rising weight schedule (alpha 0→1).

Replace client-side heavy models with the compact pretrained model and run a quick federated finetune with FedAvg on a small non-IID split to measure accuracy gains.
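For the federated finetune, FedAvg aggregates client updates by averaging parameters, weighted by each client's number of training examples. The sketch below shows that averaging step in plain Python. Storing parameters as a dict of flat float lists is a simplification for illustration; real code would use the training framework's tensors. In this framework, the server-pretrained compact model would replace the random initialization as the starting global model.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def fedavg(client_weights: List[Weights], client_sizes: List[int]) -> Weights:
    """Average each parameter across clients, weighted by local dataset size."""
    if len(client_weights) != len(client_sizes) or not client_weights:
        raise ValueError("need one size per client and at least one client")
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError("total number of client examples must be positive")
    averaged: Weights = {}
    for name, first in client_weights[0].items():
        averaged[name] = [
            sum(w[name][i] * n for w, n in zip(client_weights, client_sizes)) / total
            for i in range(len(first))
        ]
    return averaged


# Usage: two clients with 100 and 300 local examples.
global_weights = fedavg([{"fc": [1.0, 0.0]}, {"fc": [0.0, 1.0]}], [100, 300])
print(global_weights)
```

Weighting by example count means clients with more data pull the global model further. On long-tailed, non-IID splits, this is one source of the bias that the server-side alignment step targets.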

Optimization Features

Infra Optimization

  • Experiments used a single A100 80GB GPU

System Optimization

  • Shift heavy multimodal LLM compute to server to reduce client cost

Training Optimization

  • Dynamic Weighted Pretraining (distill large encoder into compact FL model)
  • Server-side pretraining using MLLM annotations
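Dynamic Weighted Pretraining ramps a weight alpha from 0 to 1 while distilling a large frozen encoder into the compact FL model. The sketch below is one possible reading, built on assumptions the summary does not confirm:

- The ramp is linear in the training step.
- alpha weights a mean-squared feature-matching term against the frozen encoder (for example CLIP).
- (1 - alpha) weights the supervised loss on MLLM-annotated web data.

Student and teacher features are assumed to have equal length. In practice, the compact model and CLIP usually produce features of different sizes, so a projection head would typically be needed.

```python
from typing import Sequence


def alpha_schedule(step: int, total_steps: int) -> float:
    """Assumed linear ramp: alpha = 0 at the first step, 1 at the last."""
    if total_steps <= 1:
        return 1.0
    return min(max(step / (total_steps - 1), 0.0), 1.0)


def feature_mse(student_feat: Sequence[float], teacher_feat: Sequence[float]) -> float:
    """Mean squared error between student and frozen-teacher feature vectors."""
    if len(student_feat) != len(teacher_feat):
        raise ValueError("project features to a common size first")
    return sum((s - t) ** 2 for s, t in zip(student_feat, teacher_feat)) / len(student_feat)


def pretraining_loss(
    task_loss: float,
    student_feat: Sequence[float],
    teacher_feat: Sequence[float],
    step: int,
    total_steps: int,
) -> float:
    """(1 - alpha) * supervised loss on MLLM annotations + alpha * distillation loss."""
    alpha = alpha_schedule(step, total_steps)
    return (1.0 - alpha) * task_loss + alpha * feature_mse(student_feat, teacher_feat)
```

Reversing which term gets alpha, or using a non-linear ramp, are equally plausible readings of "alpha 0→1".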

Reproducibility

Data Urls

  • CC-595K (LLaVA pretraining data)
  • CIFAR-10-LT
  • CIFAR-100-LT
  • ImageNet-LT
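CIFAR-10-LT and CIFAR-100-LT are commonly built by subsampling the balanced training sets with an exponential profile, n_i = n_max · IF^(-i/(C-1)). Here C is the number of classes and the imbalance factor IF is the ratio between the largest and smallest class sizes. The sketch below computes per-class counts under that common convention. It assumes, without confirmation from the summary, that the paper uses exactly this profile.

```python
from typing import List


def long_tail_counts(num_classes: int, max_per_class: int, imbalance_factor: float) -> List[int]:
    """Per-class counts n_i = n_max * IF ** (-i / (C - 1)), rounded to integers.

    Class 0 keeps n_max examples; the last class keeps about n_max / IF.
    """
    if num_classes < 2:
        return [max_per_class] * num_classes
    return [
        round(max_per_class * imbalance_factor ** (-i / (num_classes - 1)))
        for i in range(num_classes)
    ]


cifar10_lt = long_tail_counts(10, 5000, 100)    # CIFAR-10 has 5,000 training images per class
cifar100_lt = long_tail_counts(100, 500, 100)   # CIFAR-100 has 500 training images per class
print(cifar10_lt)
print(sum(cifar10_lt), sum(cifar100_lt))
```

With IF=100, the head class of CIFAR-10-LT keeps 5,000 images and the tail class keeps 50. Some implementations truncate with `int` instead of rounding, which can shift individual counts by one.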

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Relies on access to large, legally usable web image-text data and substantial server compute.
  • Experiments limited to image classification long-tailed benchmarks; not validated on other modalities or real-world deployments.
  • Low-quality or hallucinated MLLM annotations can degrade pretraining.

When Not To Use

  • You lack server GPU resources or cannot legally use web-scraped images.
  • Your FL application is non-visual or needs real-time on-device heavy inference.
  • You must avoid any use of external web data for regulatory reasons.

Failure Modes

  • MLLM-generated labels are noisy or incorrect, leading to wrong pretraining signals.
  • The server alignment dataset lacks some classes, so long-tail correction fails for those categories.
  • Model capacity mismatch: compact FL model may not absorb knowledge from large frozen encoders.
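The missing-class failure mode can be checked before alignment starts. The sketch below assumes the server holds a labeled pool of candidate examples, for instance MLLM-annotated web images. It draws up to a fixed number of examples per class, then reports classes that are absent or below the per-class target. The function name `build_balanced_set` and its return format are hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple


def build_balanced_set(
    samples: Sequence[object],
    labels: Sequence[Hashable],
    all_classes: Sequence[Hashable],
    per_class: int,
    seed: int = 0,
) -> Tuple[List[Tuple[object, Hashable]], List[Hashable], Dict[Hashable, int]]:
    """Return (balanced pairs, missing classes, under-filled class counts)."""
    rng = random.Random(seed)
    by_class: Dict[Hashable, List[object]] = defaultdict(list)
    for x, y in zip(samples, labels):
        by_class[y].append(x)

    balanced: List[Tuple[object, Hashable]] = []
    missing: List[Hashable] = []
    short: Dict[Hashable, int] = {}
    for c in all_classes:
        pool = by_class.get(c, [])
        if not pool:
            missing.append(c)  # alignment cannot correct a class it never sees
            continue
        k = min(per_class, len(pool))
        if k < per_class:
            short[c] = k  # class present but below the per-class target
        balanced.extend((x, c) for x in rng.sample(pool, k))
    return balanced, missing, short
```

A non-empty `missing` list signals that more data is needed for those classes, for example further annotated web images, before alignment can correct them.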

Core Entities

Models

  • LLaVA
  • GPT-4
  • CLIP
  • Vicuna
  • LLaMA-2
  • ResNet-8
  • ResNet-50

Metrics

  • Accuracy

Datasets

  • CC-595K (LLaVA pretraining set)
  • CIFAR-10-LT
  • CIFAR-100-LT
  • ImageNet-LT

Benchmarks

  • CIFAR-10-LT
  • CIFAR-100-LT
  • ImageNet-LT

Context Entities

Models

  • CLIP2FL
  • CReFF
  • FedAvg
  • FedAvgM
  • FedProx