Use attention-equipped diffusion models to learn coordinated multi-agent policies and predict joint trajectories from offline logs

May 27, 2023 | 7 min

Overview

Production Readiness

0.6

Novelty Score

0.7

Cost Impact Score

0.5

Citation Count

5

Authors

Zhengbang Zhu, Minghuan Liu, Liyuan Mao, Bingyi Kang, Minkai Xu, Yong Yu, Stefano Ermon, Weinan Zhang

Why It Matters For Business

MADiff can learn coordinated policies and joint trajectory predictions from logged data alone, enabling product features where online trials are costly or unsafe. It is best suited to small agent teams (experiments cover up to 8 agents) and environments with low stochasticity.

Summary TLDR

MADiff is presented as the first diffusion-based framework for offline multi-agent problems. It trains a return-conditioned diffusion model with inter-agent attention so a single learned model can act as (a) a centralized controller that samples joint trajectories (MADiff-C), (b) a decentralized policy that predicts teammates and plans locally (MADiff-D), and (c) a multi-agent trajectory predictor. Experiments across MPE, MA-Mujoco, SMAC, and an NBA dataset show large gains in trajectory prediction and improved or competitive performance on many offline MARL tasks. The method scales poorly past small teams (experiments go up to 8 agents) and loses ground to conservative Q-learning baselines in highly stochastic environments.

Problem Statement

Offline multi-agent learning must learn coordinated policies from static logs. Single-agent diffusion methods transfer poorly to this setting: training an independent model per agent loses coordination, while concatenating all agents into one model input breaks agent symmetry and is sample-inefficient. The paper asks: can a diffusion model with inter-agent attention learn joint behavior and support both centralized planning and decentralized execution?

Main Contribution

MADiff: a CTDE (centralized training, decentralized execution) diffusion framework for multi-agent problems that also works as a centralized controller and a trajectory predictor.

An attention-based diffusion architecture that inserts multi-agent attention in U-Net decoder layers so agents exchange information during denoising steps.

Empirical evidence that MADiff improves trajectory prediction and yields strong or competitive performance on diverse offline multi-agent benchmarks.
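
To make the inter-agent attention idea concrete, here is a minimal sketch in plain Python. It assumes single-head scaled dot-product attention in which queries, keys, and values are all the raw per-agent feature vectors (no learned projections, no U-Net around it). The function name `inter_agent_attention` and the feature layout are illustrative assumptions, not the paper's implementation.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def inter_agent_attention(features):
    """features: one feature vector (list of floats) per agent.

    Returns one mixed vector per agent, each a softmax-weighted
    average of all agents' feature vectors (hypothetical helper).
    """
    dim = len(features[0])
    scale = math.sqrt(dim)
    mixed = []
    for query in features:
        # Score this agent's features against every agent's features.
        scores = [sum(q * k for q, k in zip(query, key)) / scale
                  for key in features]
        weights = softmax(scores)
        mixed.append([sum(w * value[d] for w, value in zip(weights, features))
                      for d in range(dim)])
    return mixed
```

Because every agent is processed by the same rule, reordering the agents only reorders the outputs, which is one way to preserve the agent symmetry that plain concatenation breaks.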

Key Findings

MADiff greatly improves multi-agent trajectory prediction on the NBA dataset.

Numbers: ADE 7.92 ± 0.86 (MADiff-C) vs 15.15 ± 0.38 (Baller2Vec++), trajectory length 20

MADiff-D (decentralized) achieves the best or competitive returns on most offline MARL benchmarks evaluated.

Numbers: Outperforms baselines across many Table 1 tasks (see per-task scores in Table 1)

MADiff is sensitive to environmental stochasticity and can be outperformed by conservative Q-learning in highly stochastic settings.

Numbers: SMACv2 terran_5_vs_5 original: MADiff-D 10.1 ± 0.8 vs MA-ICQ 13.7 ± 1.7

Results

ADE (NBA dataset, traj len 20)

Value: 7.92 ± 0.86 (MADiff-C)

Baseline: 15.15 ± 0.38 (Baller2Vec++)

Normalized episodic return (MPE Spread, Expert)

Value: 116.7 ± 3.0 (MADiff-C)

Baseline: 114.9 ± 2.6 (OMAR)

Average score (SMACv2 terran_5_vs_5, original)

Value: 10.1 ± 0.8 (MADiff-D)

Baseline: 13.7 ± 1.7 (MA-ICQ)

Who Should Care

What To Try In 7 Days

Run MADiff-C on one trajectory-prediction dataset and compare ADE/FDE with your current model.

Train MADiff-D on one offline multi-agent log (≤8 agents) to test decentralized policy quality.

Ablate teammate modeling: mask other agents during training to measure coordination gains quickly.
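
One way to run the teammate-masking ablation is to restrict attention so each agent attends only to itself. The sketch below is an assumption about how such a mask could be wired in, not the paper's ablation code. `attention_weights` and `self_only_mask` are hypothetical helpers that operate on a matrix of raw attention scores.

```python
import math

def attention_weights(scores, allow):
    """scores[i][j]: raw attention score of agent i toward agent j.
    allow[i][j]: True if agent i may attend to agent j.
    Each row must allow at least one entry.
    """
    weights = []
    for row_scores, row_allow in zip(scores, allow):
        # Disallowed entries get -inf, so they receive zero weight.
        masked = [s if a else -math.inf for s, a in zip(row_scores, row_allow)]
        m = max(masked)
        exps = [math.exp(s - m) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

def self_only_mask(n_agents):
    # Ablation mask: every agent sees only its own features.
    return [[i == j for j in range(n_agents)] for i in range(n_agents)]

def full_mask(n_agents):
    # Default mask: every agent may attend to every agent.
    return [[True] * n_agents for _ in range(n_agents)]
```

Training once with `full_mask` and once with `self_only_mask`, then comparing returns on the same log, gives a quick estimate of how much the coordination signal contributes.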

Agent Features

Memory

  • short history conditioning (C steps) for planning

Planning

  • return-conditioned trajectory planning
  • history-conditioned planning (optional)

Tool Use

  • classifier-free guidance
  • inverse dynamics model
  • DDIM low-step sampling
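
As a rough illustration of classifier-free guidance at sampling time, the sketch below combines a conditional and an unconditional noise prediction. It assumes the common form in which the guided estimate is the unconditional prediction plus a weighted difference. The name `guided_noise` and the `guidance_weight` parameter are illustrative, and the paper's exact weighting convention may differ.

```python
def guided_noise(eps_cond, eps_uncond, guidance_weight):
    """Classifier-free guidance mix of two noise predictions.

    eps_cond: noise predicted with the return condition.
    eps_uncond: noise predicted with the condition dropped.
    guidance_weight: 0 ignores the condition, 1 uses the conditional
    prediction as-is, values above 1 extrapolate toward it (assumed convention).
    """
    return [u + guidance_weight * (c - u) for c, u in zip(eps_cond, eps_uncond)]

# Example: a weight of 2 pushes past the conditional prediction.
print(guided_noise([1.0, 2.0], [0.0, 1.0], 2.0))  # [2.0, 3.0]
```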

Frameworks

  • diffusion models (denoising reverse process)
  • return-conditioning

Is Agentic

true

Architectures

  • attention-based diffusion model
  • U-Net backbone per agent (weights either shared across agents or separate per agent)
  • multi-head inter-agent attention

Collaboration

  • teammate modeling (predict other agents)
  • centralized training, decentralized execution (CTDE)
  • centralized controller variant

Optimization Features

Model Optimization

  • shared U-Net across agents to reduce parameters
  • batched agent trajectories for GPU efficiency
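
The batching idea can be sketched as folding the agent dimension into the batch dimension, so a single shared network sees every agent's trajectory in one forward pass. The helpers below are hypothetical and use nested Python lists in place of tensors. They show only the reshaping, not the network.

```python
def fold_agents_into_batch(batch):
    """batch[b][a] is agent a's trajectory in sample b.

    Returns a flat list of trajectories plus the agent count
    needed to undo the fold.
    """
    n_agents = len(batch[0])
    flat = [trajectory for sample in batch for trajectory in sample]
    return flat, n_agents

def unfold_agents(flat, n_agents):
    # Regroup consecutive trajectories back into per-sample lists.
    return [flat[i:i + n_agents] for i in range(0, len(flat), n_agents)]
```

With real tensors this is usually a reshape from `(batch, agents, ...)` to `(batch * agents, ...)` and back, with attention layers operating on the unfolded view.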

System Optimization

  • GPU acceleration (RTX 3090 used) to keep sampling fast

Training Optimization

  • centralized training on joint trajectories
  • classifier-free guidance training mix (β sampling)
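
The training-time mix for classifier-free guidance can be sketched as randomly dropping the return condition for a fraction of training samples, so the same network learns both conditional and unconditional predictions. The sketch assumes `drop_prob` is the probability of dropping the condition and represents a dropped condition as `None`. Both choices, and the function name, are illustrative rather than taken from the paper.

```python
import random

def maybe_drop_condition(target_return, drop_prob, rng):
    """Return the condition to feed the model for one training sample.

    With probability drop_prob the condition is replaced by None,
    standing in for the null (unconditional) token.
    """
    if rng.random() < drop_prob:
        return None
    return target_return

# Example: roughly a quarter of samples train the unconditional branch.
rng = random.Random(0)
conditions = [maybe_drop_condition(1.0, 0.25, rng) for _ in range(8)]
```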

Inference Optimization

  • low-temperature initial noise scaling
  • DDIM sampling (15 steps used in examples)
  • shared model batching keeps sampling time nearly constant with agent count
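
For readers unfamiliar with the two sampling tricks above, here is a minimal sketch of a deterministic DDIM update (eta = 0) and of low-temperature initial noise. It follows the standard DDIM formulation; the helper names, the `temperature` parameter, and the use of plain lists are assumptions for illustration, and the paper's noise schedule and step selection are not reproduced.

```python
import math
import random

def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update from step t to an earlier step.

    x_t: current noisy sample; eps: the model's noise prediction.
    alpha_bar_t, alpha_bar_prev: cumulative signal levels in (0, 1].
    """
    # Estimate the clean sample implied by the noise prediction.
    x0 = [(x - math.sqrt(1.0 - alpha_bar_t) * e) / math.sqrt(alpha_bar_t)
          for x, e in zip(x_t, eps)]
    # Re-noise that estimate to the earlier step's noise level.
    return [math.sqrt(alpha_bar_prev) * c + math.sqrt(1.0 - alpha_bar_prev) * e
            for c, e in zip(x0, eps)]

def initial_noise(dim, temperature, rng):
    # Low-temperature sampling: shrink the starting Gaussian noise.
    return [temperature * rng.gauss(0.0, 1.0) for _ in range(dim)]
```

Because each update can jump across many training timesteps, a sampler built this way can run with a small number of steps, such as the 15 mentioned above.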

Reproducibility

Code Urls

  • supplementary materials (anonymous code and instructions provided)

Data Urls

  • OMAR fork and off-the-grid datasets referenced in appendix; supplementary materials for dataset links

Code Available

Data Available

Open Source Status

  • partial

Risks & Boundaries

Limitations

  • Scales poorly to many agents because each agent infers all teammates' futures; experiments cover at most 8 agents.
  • Performance degrades in highly stochastic environments (SMACv2 original setting).
  • Sequence-modeling approach can overfit to high-return trajectories achieved by chance in offline data.

When Not To Use

  • Environments with tens or hundreds of agents where inferring all teammates is impractical.
  • Datasets with high transition stochasticity where Q-learning methods perform better.
  • Real-time systems with severe latency constraints if many diffusion steps are required.

Failure Modes

  • Incorrect teammate predictions can cause coordinated failures.
  • Model may prefer misleading high-return but rare trajectories in offline logs.
  • Fixed shared parameterization may underperform when agents are highly heterogeneous.

Core Entities

Models

  • MADiff
  • attention-based diffusion model
  • per-agent U-Net backbone
  • inverse dynamics model

Metrics

  • episodic return (normalized)
  • ADE (average displacement error)
  • FDE (final displacement error)
  • minADE20
  • minFDE20
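
The displacement metrics can be computed as in the sketch below, which assumes each trajectory is a list of (x, y) positions for a single agent over a shared horizon. ADE averages the per-step Euclidean error, FDE takes the error at the final step, and the "min" variants take the best score over a set of sampled predictions (20 samples for minADE20 and minFDE20). The function names are illustrative, and per-agent averaging conventions may differ from the paper's evaluation code.

```python
import math

def displacement_errors(pred, truth):
    """Return (ADE, FDE) for one predicted trajectory against ground truth."""
    dists = [math.dist(p, t) for p, t in zip(pred, truth)]
    return sum(dists) / len(dists), dists[-1]

def min_displacement_errors(samples, truth):
    """Return (minADE, minFDE) over several sampled predictions."""
    errors = [displacement_errors(s, truth) for s in samples]
    return min(e[0] for e in errors), min(e[1] for e in errors)

truth = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
pred = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
print(displacement_errors(pred, truth))  # (1.0, 2.0)
```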

Datasets

  • MPE (Spread, Tag, World) - OMAR fork
  • MA-Mujoco (2halfcheetah, 2ant, 4ant) - off-the-grid
  • SMAC (3m, 2s3z, 5m_vs_6m, 8m) - off-the-grid
  • NBA player trajectories (MATP)

Benchmarks

  • MA-ICQ
  • MA-CQL
  • OMAR
  • MA-TD3+BC
  • MADT
  • Baller2Vec++