Distributed Training Overview

Scale your Flax NNX models across multiple devices with JAX's powerful parallelism primitives. This guide covers everything from simple data parallelism to advanced sharding strategies for training models at any scale.

Quick Decision Guide​

Choose your parallelism strategy:

Does your model fit on a single device?

  β€’ Yes β†’ use Data Parallelism.
  β€’ No β†’ does the model need a sequential, stage-wise split?
      β€’ Yes β†’ use Pipeline Parallelism.
      β€’ No β†’ use FSDP or Tensor Parallelism.
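
The branching above can be captured in a tiny helper (the function and return values are illustrative, not part of any API):

```python
def choose_strategy(fits_on_one_device: bool, sequential_architecture: bool) -> str:
    """Encode the decision guide above (illustrative helper)."""
    if fits_on_one_device:
        return "data_parallel"
    if sequential_architecture:
        return "pipeline_parallel"
    return "fsdp_or_tensor_parallel"

print(choose_strategy(True, False))  # -> data_parallel
```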

Parallelism Strategies​

1. Data Parallelism​

Replicate model, split data

  • βœ… Simplest to implement
  • βœ… Perfect scaling for throughput
  • βœ… No model changes needed
  • ❌ Model must fit on single device
import functools
import jax

@functools.partial(jax.pmap, axis_name='devices')
def train_step(state, batch):
    # Each device processes a different slice of the data
    grads = compute_gradients(state, batch)
    grads = jax.lax.pmean(grads, axis_name='devices')  # sync across devices
    return state.apply_gradients(grads=grads)

πŸ“š Learn Data Parallelism β†’

2. SPMD / Automatic Sharding​

Flexible sharding with compiler optimization

  • βœ… Very flexible (any sharding pattern)
  • βœ… Modern JAX best practice
  • βœ… Automatic optimization
  • ⚠️ Requires understanding sharding
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(devices, axis_names=('data', 'model'))
sharding = NamedSharding(mesh, P('data', 'model'))

@jax.jit
def train_step(state, batch):
    # The compiler inserts communication automatically
    return update(state, batch)

πŸ“š Learn SPMD Sharding β†’
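
As a runnable sketch of the idea, the snippet below shards an array over a `(data, model)` mesh built from however many devices are visible (the `XLA_FLAGS` line is only a CPU-simulation convenience and is not required on real accelerators):

```python
import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a (data, model) mesh over whatever devices exist
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(n, 1), axis_names=('data', 'model'))

# Place an array sharded along the 'data' axis
x = jnp.arange(2.0 * n).reshape(n, 2)
x = jax.device_put(x, NamedSharding(mesh, P('data', 'model')))

@jax.jit
def double(x):
    return x * 2  # the compiler handles any needed communication

y = double(x)
print(y.shape)
```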

3. Pipeline Parallelism​

Split model into stages

  • βœ… Train very large models
  • βœ… Works with sequential architectures
  • ❌ Pipeline bubbles (70-90% efficiency)
  • ❌ Complex implementation
# Stage 1 on Device 0
# Stage 2 on Device 1
# Stage 3 on Device 2
# Stage 4 on Device 3

# Microbatches flow through pipeline

πŸ“š Learn Pipeline Parallelism β†’
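
The bubble overhead can be estimated with the standard fill-and-drain formula (a simplification that ignores communication cost):

```python
def pipeline_efficiency(stages: int, microbatches: int) -> float:
    # With S stages and M microbatches, the last microbatch finishes
    # after M + S - 1 time slots, of which only M do useful work per stage.
    return microbatches / (microbatches + stages - 1)

print(round(pipeline_efficiency(4, 8), 2))   # 0.73
print(round(pipeline_efficiency(4, 32), 2))  # 0.91
```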

4. FSDP (Fully Sharded Data Parallel)​

Shard everything to save memory

  • βœ… Massive memory savings (NΓ— reduction)
  • βœ… Train NΓ— larger models
  • ❌ More communication overhead
  • ⚠️ Needs fast interconnect
# Shard parameters across all devices
mesh = Mesh(devices, axis_names=('fsdp',))
params = shard_fsdp(params, mesh)  # shard_fsdp: your helper that applies a NamedSharding

# Automatic all-gather and reduce-scatter

πŸ“š Learn FSDP β†’
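
A minimal sketch of sharding one weight matrix along a one-dimensional `fsdp` mesh axis (again, the `XLA_FLAGS` line only simulates devices on CPU for illustration):

```python
import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('fsdp',))
n = len(jax.devices())

# Shard a weight matrix row-wise: each device holds rows_total / n rows
w = jnp.ones((n * 2, 4))
w = jax.device_put(w, NamedSharding(mesh, P('fsdp', None)))
print(w.sharding)  # each addressable shard is a (2, 4) slice
```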

Strategy Comparison​

| Strategy | Memory/Device | Throughput | Communication | Best For |
|---|---|---|---|---|
| Data Parallel | Full model (P) | Excellent | O(P) once/step | Standard training |
| SPMD | Configurable | Excellent | Optimized | Flexible needs |
| Pipeline | P/N | Good (70-90%) | O(activations) | Very large models |
| FSDP | P/N | Good | O(2P) per layer | Memory constrained |

P = Model size, N = Number of devices

Real-World Examples​

Example 1: ResNet-50 Training (25M params)​

Model fits easily on single GPU

# βœ… Best: Data Parallelism
# - Simple implementation
# - Perfect scaling
# - No memory concerns

# Configuration:
# - 8Γ— A100 GPUs
# - Batch size: 32/device = 256 total
# - Scaling efficiency: ~100%

Example 2: GPT-2 Medium (355M params)​

Model fits but tight on memory

# βœ… Best: SPMD with data parallelism
# - Flexible for future growth
# - Modern approach
# - Can add model parallelism if needed

# Configuration:
# - 8Γ— A100-40GB
# - Pure data parallel: P('data', None)
# - Or light tensor parallel: P('data', 'model') with mesh (4, 2)

Example 3: GPT-3 Scale (175B params)​

Model way too large for single device

# βœ… Best: Combination strategy
# - FSDP for memory: 1024 devices
# - Pipeline: 8 stages
# - Tensor parallel: 4-way per stage

# Configuration:
# - 1024Γ— A100-80GB
# - Mesh: (8, 4, 32) = (pipeline, tensor, fsdp)
# - Per device: ~600MB

Example 4: Vision Transformer (1B params)​

Sequential architecture, moderately large

# βœ… Good options:
# Option A: FSDP (if 16+ GPUs)
# - Memory: 1GB per device (16 GPUs)
# - Clean implementation

# Option B: Pipeline (if 4-8 GPUs)
# - 4 stages, 8 microbatches
# - Efficiency: 73%

# Configuration:
# - 8Γ— A100-40GB
# - Choose based on interconnect speed

Combining Strategies​

Many large-scale training runs combine multiple strategies:

FSDP + Data Parallelism​

# Shard model parameters (FSDP)
# Each shard replica uses data parallelism

mesh = Mesh(devices, axis_names=('fsdp', 'data'))
# Shape: (8, 16) = 128 devices total
# 8-way FSDP, 16-way data parallel per shard

Pipeline + Tensor Parallelism​

# Split model into pipeline stages
# Each stage uses tensor parallelism

# Stage 1: Devices 0-7 (8-way tensor parallel)
# Stage 2: Devices 8-15 (8-way tensor parallel)
# Stage 3: Devices 16-23 (8-way tensor parallel)
# Stage 4: Devices 24-31 (8-way tensor parallel)

3D Parallelism (Pipeline + Tensor + Data)​

# The ultimate combination for massive models

mesh = Mesh(devices, axis_names=('pipeline', 'tensor', 'data'))
# Shape: (8, 8, 16) = 1024 devices
# - 8 pipeline stages
# - 8-way tensor parallel per stage
# - 16-way data parallel

# Used by: GPT-3, PaLM, LLaMA-2
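
With a row-major `(pipeline, tensor, data)` layout like the mesh above, a flat device id maps to mesh coordinates by simple integer division (a bookkeeping sketch, not a JAX API):

```python
def mesh_coords(device_id: int, pipeline=8, tensor=8, data=16):
    # Row-major layout: 'data' varies fastest, 'pipeline' slowest
    d = device_id % data
    t = (device_id // data) % tensor
    p = device_id // (data * tensor)
    return p, t, d

print(mesh_coords(0))     # (0, 0, 0)
print(mesh_coords(1023))  # (7, 7, 15)
```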

Getting Started​

Step 1: Profile Your Model​

import jax
import jax.numpy as jnp
from flax import nnx

# Initialize model
model = YourModel(...)

# Check size
graphdef, params = nnx.split(model)
total_params = sum(p.size for p in jax.tree.leaves(params))
model_size_gb = total_params * 4 / 1e9  # float32

print(f"Model: {total_params/1e9:.2f}B parameters ({model_size_gb:.2f} GB)")

# Profile one training step
@jax.jit
def profile_step(state, batch):
    # Your training step goes here
    ...

# Run once to compile
state, metrics = profile_step(state, batch)

# Time it (block on async dispatch so we measure real work)
import time
start = time.time()
for _ in range(10):
    state, metrics = profile_step(state, batch)
jax.block_until_ready(state)
elapsed = (time.time() - start) / 10

print(f"Step time: {elapsed*1000:.1f}ms")

Step 2: Choose Strategy​

Use the decision guide above based on:

  • Model size vs device memory
  • Number of available devices
  • Interconnect speed
  • Architecture (sequential or not)

Step 3: Implement​

Follow the detailed guide for your chosen strategy, linked from each section above.

Step 4: Optimize​

After basic implementation works:

  1. Profile with jax.profiler.trace()
  2. Check device utilization (should be >80%)
  3. Adjust batch size (larger = better efficiency)
  4. Enable mixed precision (bfloat16)
  5. Tune communication (see Best Practices)
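
For step 4, a common pattern is to cast floating-point parameters to bfloat16 while leaving integer state (step counters, RNG keys) untouched. A sketch; whether to keep float32 master weights for the optimizer is a separate choice:

```python
import jax
import jax.numpy as jnp

def cast_to_bf16(params):
    # Cast only floating-point leaves; integer leaves stay as-is
    return jax.tree.map(
        lambda p: p.astype(jnp.bfloat16)
        if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params,
    )

params = {'w': jnp.ones((2, 2)), 'step': jnp.array(0)}
params = cast_to_bf16(params)
print(params['w'].dtype)  # bfloat16
```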

Common Pitfalls​

❌ Wrong: Using pmap for large models​

# Model: 10B parameters, won't fit on single GPU!
@jax.pmap
def train_step(state, batch):
    # Each device needs the full 10B-parameter model = OOM!
    pass

βœ… Use FSDP or Pipeline Parallelism instead

❌ Wrong: Too few microbatches with pipeline​

# 4 pipeline stages, only 2 microbatches
# Efficiency: 2/(2+4-1) = 40% (terrible!)

βœ… Use at least 4Γ— as many microbatches as stages (e.g. 16 microbatches for 4 stages)

❌ Wrong: FSDP with slow interconnect​

# Using FSDP over 10Gb Ethernet
# Communication time > compute time!

βœ… FSDP needs NVLink or InfiniBand (100+ GB/s)

❌ Wrong: Not accounting for optimizer state​

# Model: 10GB
# Fits on A100-40GB?
# NO! Adam needs: 10GB params + 10GB grads + 20GB optimizer = 40GB

βœ… Budget 4Γ— model size for training (params + grads + optimizer)
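
The 4Γ— rule of thumb for Adam in float32 is easy to encode (activation memory comes on top and depends on batch size, so it is excluded here):

```python
def adam_training_footprint_gb(model_gb: float) -> float:
    params = model_gb
    grads = model_gb
    adam_moments = 2 * model_gb  # first + second moment buffers
    return params + grads + adam_moments

print(adam_training_footprint_gb(10))  # 40 -> does NOT fit on an A100-40GB
```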

Scaling Laws​

Data Parallelism Scaling​

Throughput = single_device_throughput Γ— num_devices Γ— efficiency

# Efficiency typically:
# - 2-4 devices: 95-98%
# - 8 devices: 90-95%
# - 16+ devices: 85-92%
# (Depends on model size and interconnect)
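
In numbers (the 1000 images/s figure is hypothetical, chosen for easy arithmetic):

```python
def scaled_throughput(per_device, num_devices, efficiency):
    # Throughput = single_device_throughput x num_devices x efficiency
    return per_device * num_devices * efficiency

# Hypothetical: 1000 img/s on one device, 8 devices at 92% efficiency
print(round(scaled_throughput(1000, 8, 0.92)))  # 7360
```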

FSDP Scaling​

Max_model_size = device_memory Γ— num_devices / 4

# Examples (A100-40GB):
# - 8 devices: 80GB model (20B params)
# - 64 devices: 640GB model (160B params)
# - 1024 devices: 10TB model (2.5T params)
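
Plugging in the formula (float32 parameters, with the division by 4 budgeting for gradients and optimizer state):

```python
def fsdp_max_params_billions(device_mem_gb, num_devices, bytes_per_param=4):
    usable_gb = device_mem_gb * num_devices / 4  # leave room for grads + optimizer
    return usable_gb / bytes_per_param           # 1 GB holds 0.25B float32 params

print(fsdp_max_params_billions(40, 8))   # 20.0
print(fsdp_max_params_billions(40, 64))  # 160.0
```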

Pipeline Efficiency​

Efficiency = M / (M + S - 1)
where M = microbatches, S = stages

# To reach 90% efficiency:
M β‰₯ 9 Γ— S - 9

# Examples:
# 4 stages: Need 27+ microbatches for 90%
# 8 stages: Need 63+ microbatches for 90%
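
Rearranging M/(M+S-1) β‰₯ 0.9 gives M β‰₯ 9(S-1), which reproduces the figures above:

```python
def min_microbatches_for_90pct(stages: int) -> int:
    # M/(M+S-1) >= 0.9  =>  0.1*M >= 0.9*(S-1)  =>  M >= 9*(S-1)
    return 9 * (stages - 1)

print(min_microbatches_for_90pct(4))  # 27
print(min_microbatches_for_90pct(8))  # 63
```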

Hardware Considerations​

Interconnect Speed​

| Interconnect | Bandwidth | Good For |
|---|---|---|
| NVLink (V100) | 300 GB/s | All strategies βœ… |
| NVLink (A100) | 600 GB/s | All strategies βœ…βœ… |
| NVLink (H100) | 900 GB/s | All strategies βœ…βœ…βœ… |
| PCIe 4.0 | 64 GB/s | Data parallel only |
| 10Gb Ethernet | 1.25 GB/s | Single device only |
| InfiniBand | 200 GB/s | All strategies βœ…βœ… |

Memory Hierarchy​

Device Memory (fast, small):
β”œβ”€ L2 Cache: ~40-80MB (fastest)
β”œβ”€ HBM: 40-80GB (fast)
└─ When full β†’ OOM!

Host Memory (slow, large):
└─ RAM: 100s of GB (for data loading)

Storage (slowest, largest):
└─ Disk: TBs (for dataset)

Key insight: Training happens in device memory. Must fit model + gradients + optimizer + activations.

Monitoring Training​

Essential Metrics​

# 1. Loss / Accuracy (correctness)
# 2. Step time (efficiency)
# 3. Device utilization (>80% ideal)
# 4. Memory usage (should be high but not OOM)
# 5. Communication time (should be <30% of step time)

# Log these every N steps:
if step % 100 == 0:
    metrics = {
        'loss': float(loss),
        'accuracy': float(accuracy),
        'step_time_ms': step_time * 1000,
        'device_utilization': utilization,
    }
    print(metrics)

Profiling​

# Profile to find bottlenecks
from jax import profiler

with profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        state = train_step(state, batch)

# View in TensorBoard:
# tensorboard --logdir=/tmp/jax-trace

# Look for:
# - Computation time (should be high)
# - Communication time (minimize)
# - Idle time (minimize)

Next Steps​

  1. Start simple: Data Parallelism Guide
  2. Go modern: SPMD Sharding Guide
  3. Scale up: FSDP Guide or Pipeline Guide

Example Code​

Check out our complete, runnable examples:

  • examples/16_data_parallel_pmap.py - Data parallelism with pmap
  • examples/17_sharding_spmd.py - SPMD automatic sharding
  • examples/18_pipeline_parallelism.py - Pipeline parallelism
  • examples/19_fsdp_sharding.py - FSDP fully sharded training

Each example is self-contained and includes detailed comments explaining what's happening under the hood.

Further Reading​