Use attention-equipped diffusion models to learn coordinated multi-agent policies and predict joint trajectories from offline logs

May 27, 2023 | 7 min

Overview

Decision Snapshot: Ready For Pilot

Strong empirical results across standard offline MARL and trajectory prediction benchmarks, with ablations and stochasticity tests; constrained to small teams and offline regimes.

Citations: 5

Evidence Strength: 0.80

Confidence: 0.85

Risk Signals: 9

Trust Signals

Findings with numeric evidence: 3/3

Findings with evidence refs: 3/3

Results with explicit delta: 3/3

Reproducibility

Status: Code + data available

Open source: Partial

At A Glance

Cost impact: 50%

Production readiness: 60%

Novelty: 70%

Authors

Zhengbang Zhu, Minghuan Liu, Liyuan Mao, Bingyi Kang, Minkai Xu, Yong Yu, Stefano Ermon, Weinan Zhang


Why It Matters For Business

MADiff can learn coordinated policies and reliable joint trajectory predictions from logs, enabling product features where online trials are costly or unsafe; it's best for small teams and stable environments.

Summary TLDR

MADiff is presented as the first diffusion-based framework for offline multi-agent problems. It trains a return-conditioned diffusion model with inter-agent attention so that a single learned model can act as (a) a centralized controller that samples joint trajectories, (b) a decentralized policy in which each agent predicts its teammates' trajectories and plans locally, and (c) a multi-agent trajectory predictor. Experiments on MPE, MA-Mujoco, SMAC, and an NBA dataset show large gains in trajectory prediction (ADE 7.92 vs 15.15 on NBA) and best or competitive returns on most offline MARL tasks. The method scales poorly beyond small teams (experiments stop at 8 agents) and loses ground in highly stochastic environments.

Problem Statement

Offline multi-agent learning has to recover coordinated policies from static logs, with no online interaction. Single-agent diffusion methods do not model coordination well: training an independent model per agent breaks coordination, while concatenating all agents into one input breaks agent symmetry and is sample-inefficient. The paper asks whether a diffusion model with inter-agent attention can learn joint behavior and support both centralized planning and decentralized execution.

Main Contribution

MADiff: a CTDE (centralized training, decentralized execution) diffusion framework for multi-agent problems that also works as a centralized controller and a trajectory predictor.

An attention-based diffusion architecture that inserts multi-agent attention in U-Net decoder layers so agents exchange information during denoising steps.
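To make the inter-agent attention idea concrete, here is a minimal NumPy sketch of self-attention applied across the agent axis of per-agent feature maps. It is an illustration under stated assumptions, not the paper's implementation: the single head, the per-timestep attention, the residual connection, and the names `inter_agent_attention`, `w_q`, `w_k`, and `w_v` are all hypothetical choices made for this example.

```python
import numpy as np


def inter_agent_attention(feats, w_q, w_k, w_v):
    """Single-head self-attention across the agent axis (illustrative only).

    feats: (n_agents, horizon, dim) per-agent features from a shared backbone.
    w_q, w_k, w_v: (dim, dim) projection matrices (hypothetical parameters).
    Returns an array with the same shape as feats.
    """
    q = feats @ w_q  # (A, T, D)
    k = feats @ w_k
    v = feats @ w_v
    dim = feats.shape[-1]

    # scores[t, i, j]: how strongly agent i attends to agent j at timestep t.
    scores = np.einsum("itd,jtd->tij", q, k) / np.sqrt(dim)

    # Numerically stable softmax over the attended agent j.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # Each agent receives a weighted mix of all agents' values.
    mixed = np.einsum("tij,jtd->itd", weights, v)
    return feats + mixed  # residual connection (an assumption of this sketch)


rng = np.random.default_rng(0)
feats = rng.normal(size=(3, 5, 4))  # 3 agents, horizon 5, feature dim 4
w_q, w_k, w_v = (rng.normal(size=(4, 4)) for _ in range(3))
out = inter_agent_attention(feats, w_q, w_k, w_v)
print(out.shape)
```

Because the projection weights are shared across agents, permuting the agents in the input permutes the output in the same way. That is the agent-symmetry property that full concatenation loses.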

Key Findings

MADiff greatly improves multi-agent trajectory prediction on the NBA dataset.

Numbers: ADE 7.92 ± 0.86 (MADiff-C) vs 15.15 ± 0.38 (Baller2Vec++), trajectory length 20

Practical Use: Use MADiff-C for stable, more accurate multi-agent trajectory forecasts in sports and similar tracking datasets.

Evidence Ref: Table 2

MADiff-D (decentralized) achieves the best or competitive returns on most offline MARL benchmarks evaluated.

Numbers: Outperforms baselines across many Table 1 tasks (see per-task scores in Table 1)

Practical Use: For offline datasets of small teams (≤8 agents), prefer MADiff-D to learn coordinated decentralized policies without online interaction.

Evidence Ref: Table 1

Results

Metric | Value | Baseline | Delta | Split / Dataset | Evidence | Evidence Ref
ADE (NBA dataset, traj len 20) | 7.92 ± 0.86 (MADiff-C) | 15.15 ± 0.38 (Baller2Vec++) | -7.23 | NBA test set, first-step conditioned | Table 2 reports ADE/FDE/minADE20/minFDE20 for length 20 | Table 2
Normalized episodic return (MPE Spread, Expert) | 116.7 ± 3.0 (MADiff-C) | 114.9 ± 2.6 (OMAR) | +1.8 | MPE Spread, Expert dataset | Table 1 shows normalized scores for MPE tasks | Table 1

What To Try In 7 Days

Run MADiff-C on one trajectory-prediction dataset and compare ADE/FDE with your current model.

Train MADiff-D on one offline multi-agent log (≤8 agents) to test decentralized policy quality.

Ablate teammate modeling: mask other agents during training to measure coordination gains quickly.
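For the ADE/FDE comparison in the first item, a minimal NumPy sketch of the metrics may help. It assumes trajectories are stored as arrays of shape (agents, timesteps, 2) and that the best-of-N variants (such as minADE20) take the minimum over whole sampled joint trajectories. Conventions differ between papers, so check the evaluation protocol of your baseline before comparing numbers. The function names are hypothetical.

```python
import numpy as np


def displacement_errors(pred, truth):
    """Return (ADE, FDE) for predicted vs. ground-truth positions.

    pred, truth: arrays of shape (n_agents, horizon, 2).
    ADE averages the Euclidean error over all agents and timesteps;
    FDE averages the error at the final timestep over agents.
    """
    dist = np.linalg.norm(pred - truth, axis=-1)  # (n_agents, horizon)
    return float(dist.mean()), float(dist[:, -1].mean())


def min_displacement_errors(samples, truth):
    """Best-of-N (minADE_N, minFDE_N) over sampled joint trajectories.

    samples: array of shape (n_samples, n_agents, horizon, 2).
    Assumption: the minimum is taken per joint sample, not per agent.
    """
    errs = [displacement_errors(s, truth) for s in samples]
    return min(e[0] for e in errs), min(e[1] for e in errs)


truth = np.zeros((2, 4, 2))           # 2 agents standing still for 4 steps
pred = truth + np.array([3.0, 4.0])   # every prediction is off by (3, 4)
ade, fde = displacement_errors(pred, truth)
print(ade, fde)
```

With 20 samples per scene, `min_displacement_errors` corresponds to the minADE20/minFDE20 style of metric listed in the results table, under the per-joint-sample assumption stated above.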

Agent Features

Memory
short history conditioning (C steps) for planning
Planning
return-conditioned trajectory planning; history-conditioned planning (optional)
Tool Use
classifier-free guidance; inverse dynamics model; DDIM low-step sampling
Frameworks
diffusion models (denoising reverse process); return-conditioning
Is Agentic

Yes

Architectures
attention-based diffusion model; U-Net per agent (shared or per-agent); multi-head inter-agent attention
Collaboration
teammate modeling (predict other agents); centralized training, decentralized execution (CTDE); centralized controller variant

Optimization Features

Model Optimization
shared U-Net across agents to reduce parameters; batched agent trajectories for GPU efficiency
System Optimization
GPU acceleration (RTX 3090 used) to keep sampling fast
Training Optimization
centralized training on joint trajectories; classifier-free guidance training mix (β sampling)
Inference Optimization
low-temperature initial noise scaling; DDIM sampling (15 steps used in examples); shared model batching keeps sampling time nearly constant with agent count
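The classifier-free guidance and low-temperature items above can be sketched in a few lines of NumPy. This is a generic illustration of the standard techniques, not code from the paper. It assumes β is the probability of dropping the return condition during training, that guidance uses the usual weight ω to combine conditional and unconditional noise predictions, and that low-temperature sampling scales the initial Gaussian noise by a factor below 1. All function and parameter names are hypothetical.

```python
import numpy as np


def drop_condition_mask(batch_size, beta, rng):
    """Training-time mask: True where the return condition is dropped.

    Assumption: each sample drops its condition independently with
    probability beta, so the model also learns an unconditional prediction.
    """
    return rng.random(batch_size) < beta


def guided_noise(eps_cond, eps_uncond, omega):
    """Standard classifier-free guidance combination.

    omega = 0 gives the unconditional prediction, omega = 1 the conditional
    one, and omega > 1 pushes samples further toward the condition.
    """
    return eps_uncond + omega * (eps_cond - eps_uncond)


def low_temperature_noise(shape, temperature, rng):
    """Initial sampling noise scaled by a temperature in [0, 1]."""
    return temperature * rng.standard_normal(shape)


rng = np.random.default_rng(0)
mask = drop_condition_mask(batch_size=8, beta=0.25, rng=rng)
eps_c = rng.standard_normal((3, 5, 4))  # conditional prediction (placeholder)
eps_u = rng.standard_normal((3, 5, 4))  # unconditional prediction (placeholder)
eps_hat = guided_noise(eps_c, eps_u, omega=1.2)
x_init = low_temperature_noise((3, 5, 4), temperature=0.5, rng=rng)
print(mask.shape, eps_hat.shape, x_init.shape)
```

In a real sampler, `eps_c` and `eps_u` would come from two forward passes of the denoiser (with and without the return condition) at each of the DDIM steps.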

Reproducibility

Code Available: Yes
Data Available: Yes
Open Source Status: Partial
License: Unknown

Code URLs

supplementary materials (anonymous code and instructions provided)

Data URLs

OMAR fork and off-the-grid datasets referenced in appendix; supplementary materials for dataset links

Risks & Boundaries

Limitations

Scales poorly to many agents because each agent infers the futures of all its teammates; experiments cover at most 8 agents.

Performance degrades in highly stochastic environments (SMACv2 original setting).

When Not To Use

Environments with tens or hundreds of agents where inferring all teammates is impractical.

Datasets with high transition stochasticity where Q-learning methods perform better.

Failure Modes

Incorrect teammate predictions can cause coordinated failures.

Model may prefer misleading high-return but rare trajectories in offline logs.

Core Entities

Models

MADiff; attention-based diffusion model; per-agent U-Net backbone; inverse dynamics model

Metrics

episodic return (normalized); ADE (average displacement error); FDE (final displacement error); minADE20; minFDE20

Datasets

MPE (Spread, Tag, World) - OMAR fork; MA-Mujoco (2halfcheetah, 2ant, 4ant) - off-the-grid; SMAC (3m, 2s3z, 5m_vs_6m, 8m) - off-the-grid; NBA player trajectories (MATP)

Benchmarks

MA-ICQ; MA-CQL; OMAR; MA-TD3+BC; MADT; Baller2Vec++