Distributed Training Overview
Scale your Flax NNX models across multiple devices with JAX's powerful parallelism primitives. This guide covers everything from simple data parallelism to advanced sharding strategies for training models at any scale.
Quick Decision Guide
Choose your parallelism strategy:
┌──────────────────────────────────────────┐
│ Does your model fit on a single device?  │
└────────────────────┬─────────────────────┘
                     │
             ┌───────┴───────┐
             │               │
            YES              NO
             │               │
             ▼               ▼
        ┌─────────┐  ┌──────────────────┐
        │  Data   │  │ Need sequential  │
        │Parallel │  │  architecture?   │
        └─────────┘  └────────┬─────────┘
                              │
                      ┌───────┴───────┐
                      │               │
                     YES              NO
                      │               │
                      ▼               ▼
                ┌──────────┐    ┌──────────┐
                │ Pipeline │    │ FSDP or  │
                │ Parallel │    │  Tensor  │
                └──────────┘    │ Parallel │
                                └──────────┘
Parallelism Strategies
1. Data Parallelism
Replicate the model, split the data
- ✅ Simplest to implement
- ✅ Perfect scaling for throughput
- ✅ No model changes needed
- ❌ Model must fit on a single device
import functools

@functools.partial(jax.pmap, axis_name='devices')
def train_step(state, batch):
    # Each device processes a different slice of the batch
    grads = compute_gradients(state, batch)
    grads = jax.lax.pmean(grads, axis_name='devices')  # Sync gradients
    return state.apply_gradients(grads=grads)
Learn Data Parallelism →
2. SPMD / Automatic Sharding
Flexible sharding with compiler optimization
- ✅ Very flexible (any sharding pattern)
- ✅ Modern JAX best practice
- ✅ Automatic optimization
- ⚠️ Requires understanding sharding
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# `devices` is a 2-D array of jax.Device objects, e.g. shape (4, 2)
mesh = Mesh(devices, axis_names=('data', 'model'))
sharding = NamedSharding(mesh, P('data', 'model'))

@jax.jit
def train_step(state, batch):
    # The compiler inserts the required communication automatically
    return update(state, batch)
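The snippet above is schematic (`devices`, `state`, and `update` are placeholders). A self-contained sketch that runs on any host, building a mesh from however many devices JAX finds, might look like:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all available devices into an (N, 1) grid so the snippet
# runs anywhere; on a real cluster you would pick the grid shape.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=('data', 'model'))

# Shard a batch along 'data'; the compiler handles any communication.
batch = jnp.ones((8, 4))
sharded_batch = jax.device_put(batch, NamedSharding(mesh, P('data', None)))

@jax.jit
def scale(x):
    return x * 2.0

out = scale(sharded_batch)
print(out.shape)  # (8, 4)
```

The same `train_step` pattern applies unchanged: once inputs and parameters carry shardings, `jax.jit` propagates them through the computation.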
3. Pipeline Parallelism
Split the model into stages
- ✅ Train very large models
- ✅ Works with sequential architectures
- ❌ Pipeline bubbles (70-90% efficiency)
- ❌ Complex implementation
# Stage 1 on Device 0
# Stage 2 on Device 1
# Stage 3 on Device 2
# Stage 4 on Device 3
# Microbatches flow through pipeline
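The microbatch flow can be visualized with a tiny schedule simulator (plain Python; `schedule` is an illustrative helper, not a library function). Each row is a stage, each column a time step, digits are microbatch ids, and dots are idle "bubble" slots:

```python
def schedule(stages: int, microbatches: int):
    """Ideal GPipe-style forward schedule: stage s starts microbatch m at t = m + s."""
    ticks = microbatches + stages - 1
    rows = []
    for s in range(stages):
        row = ['.'] * ticks
        for m in range(microbatches):
            row[m + s] = str(m)
        rows.append(''.join(row))
    return rows

for line in schedule(4, 4):
    print(line)
# 0123...
# .0123..
# ..0123.
# ...0123
```

The triangles of dots at the start and end are the pipeline bubbles; adding more microbatches shrinks their relative cost.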
Learn Pipeline Parallelism →
4. FSDP (Fully Sharded Data Parallel)
Shard everything to save memory
- ✅ Massive memory savings (N× reduction)
- ✅ Train N× larger models
- ❌ More communication overhead
- ⚠️ Needs fast interconnect
# Shard parameters across all devices
mesh = Mesh(devices, axis_names=('fsdp',))
params = shard_fsdp(params, mesh)
# Automatic all-gather and reduce-scatter
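`shard_fsdp` above is pseudocode. A minimal single-host sketch of what such a helper might do, sharding each parameter's leading axis across an `'fsdp'` mesh axis and replicating anything that doesn't divide evenly, could look like:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('fsdp',))

params = {'w': jnp.ones((8, 4)), 'b': jnp.zeros((4,))}

def shard_fsdp(params, mesh):
    # Shard dim 0 across 'fsdp' when it divides evenly; otherwise replicate.
    def put(x):
        n = mesh.shape['fsdp']
        spec = P('fsdp') if x.ndim > 0 and x.shape[0] % n == 0 else P()
        return jax.device_put(x, NamedSharding(mesh, spec))
    return jax.tree.map(put, params)

sharded = shard_fsdp(params, mesh)
```

Real FSDP implementations also shard gradients and optimizer state the same way; the compiler then emits the all-gathers and reduce-scatters mentioned above.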
Learn FSDP →
Strategy Comparison
| Strategy | Memory/Device | Throughput | Communication | Best For |
|---|---|---|---|---|
| Data Parallel | Full model (P) | Excellent | O(P) once/step | Standard training |
| SPMD | Configurable | Excellent | Optimized | Flexible needs |
| Pipeline | P/N | Good (70-90%) | O(activations) | Very large models |
| FSDP | P/N | Good | O(2P) per layer | Memory constrained |
P = Model size, N = Number of devices
Real-World Examples
Example 1: ResNet-50 Training (25M params)
Model fits easily on a single GPU
# ✅ Best: Data Parallelism
# - Simple implementation
# - Perfect scaling
# - No memory concerns
# Configuration:
# - 8× A100 GPUs
# - Batch size: 32/device = 256 total
# - Scaling efficiency: ~100%
Example 2: GPT-2 Medium (355M params)
Model fits, but memory is tight
# ✅ Best: SPMD with data parallelism
# - Flexible for future growth
# - Modern approach
# - Can add model parallelism if needed
# Configuration:
# - 8× A100-40GB
# - Pure data parallel: P('data', None)
# - Or light tensor parallel: P('data', 'model') with mesh (4, 2)
Example 3: GPT-3 Scale (175B params)
Model is far too large for a single device
# ✅ Best: Combination strategy
# - FSDP for memory: 1024 devices
# - Pipeline: 8 stages
# - Tensor parallel: 4-way per stage
# Configuration:
# - 1024× A100-80GB
# - Mesh: (8, 4, 32) = (pipeline, tensor, fsdp)
# - Per device: ~600MB
Example 4: Vision Transformer (1B params)
Sequential architecture, moderately large
# ✅ Good options:
# Option A: FSDP (if 16+ GPUs)
# - Memory: 1GB per device (16 GPUs)
# - Clean implementation
# Option B: Pipeline (if 4-8 GPUs)
# - 4 stages, 8 microbatches
# - Efficiency: 73%
# Configuration:
# - 8× A100-40GB
# - Choose based on interconnect speed
Combining Strategies
Many large-scale training runs combine multiple strategies:
FSDP + Data Parallelism
# Shard model parameters (FSDP)
# Each shard replica uses data parallelism
mesh = Mesh(devices, axis_names=('fsdp', 'data'))
# Shape: (8, 16) = 128 devices total
# 8-way FSDP, 16-way data parallel per shard
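The (8, 16) grid assumes 128 physical devices. A portable way to build the same kind of 2-D mesh is to reshape whatever devices are present (here falling back to an (N, 1) grid so the snippet runs anywhere):

```python
import numpy as np
import jax
from jax.sharding import Mesh

# On the 128-device cluster described above this would be .reshape(8, 16);
# (num_devices, 1) is a placeholder shape that works on any host.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=('fsdp', 'data'))
print(mesh.axis_names)  # ('fsdp', 'data')
```

Parameters then get `PartitionSpec('fsdp', ...)` while batches get `PartitionSpec('data', ...)`, so each FSDP shard is replicated across the data-parallel axis.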
Pipeline + Tensor Parallelism
# Split model into pipeline stages
# Each stage uses tensor parallelism
# Stage 1: Devices 0-7 (8-way tensor parallel)
# Stage 2: Devices 8-15 (8-way tensor parallel)
# Stage 3: Devices 16-23 (8-way tensor parallel)
# Stage 4: Devices 24-31 (8-way tensor parallel)
3D Parallelism (Pipeline + Tensor + Data)
# The ultimate combination for massive models
mesh = Mesh(devices, axis_names=('pipeline', 'tensor', 'data'))
# Shape: (8, 8, 16) = 1024 devices
# - 8 pipeline stages
# - 8-way tensor parallel per stage
# - 16-way data parallel
# Used by: GPT-3, PaLM, LLaMA-2
Getting Started
Step 1: Profile Your Model
import jax
import jax.numpy as jnp
from flax import nnx
# Initialize model
model = YourModel(...)
# Check size
graphdef, params = nnx.split(model)
total_params = sum(p.size for p in jax.tree.leaves(params))
model_size_gb = total_params * 4 / 1e9 # float32
print(f"Model: {total_params/1e9:.2f}B parameters ({model_size_gb:.2f} GB)")
# Profile one training step
@jax.jit
def profile_step(state, batch):
    # Your training step here (placeholder: returns inputs unchanged)
    return state, {}

# Run once to compile
state, metrics = profile_step(state, batch)

# Time it
import time
start = time.time()
for _ in range(10):
    state, metrics = profile_step(state, batch)
jax.block_until_ready(state)  # JAX dispatches asynchronously; wait before stopping the clock
elapsed = (time.time() - start) / 10
print(f"Step time: {elapsed*1000:.1f}ms")
Step 2: Choose Strategy
Use the decision guide above based on:
- Model size vs device memory
- Number of available devices
- Interconnect speed
- Architecture (sequential or not)
Step 3: Implement
Follow the detailed guides for your chosen strategy:
- Starting out? → Data Parallelism
- Want flexibility? → SPMD Sharding
- Need memory savings? → FSDP
- Very large model? → Pipeline Parallelism
Step 4: Optimize
After the basic implementation works:
- Profile with jax.profiler.trace()
- Check device utilization (should be >80%)
- Adjust batch size (larger = better efficiency)
- Enable mixed precision (bfloat16)
- Tune communication (see Best Practices)
Common Pitfalls
❌ Wrong: Using pmap for large models
# Model: 10B parameters, won't fit on a single GPU!
@jax.pmap
def train_step(state, batch):
    # Each device needs the full 10B model = OOM!
    pass
✅ Use FSDP or Pipeline Parallelism instead
❌ Wrong: Too few microbatches with pipeline
# 4 pipeline stages, only 2 microbatches
# Efficiency: 2/(2+4-1) = 40% (terrible!)
✅ Use at least 4× as many microbatches as stages (16 for 4 stages)
❌ Wrong: FSDP with slow interconnect
# Using FSDP over 10Gb Ethernet
# Communication time > compute time!
✅ FSDP needs NVLink or InfiniBand (100+ GB/s)
❌ Wrong: Not accounting for optimizer state
# Model: 10GB
# Fits on an A100-40GB?
# NO! Adam needs: 10GB params + 10GB grads + 20GB optimizer state = 40GB
✅ Budget 4× model size for training (params + grads + optimizer)
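The 4× budget is easy to compute. A tiny helper (illustrative, not a library function) for the Adam case; the example numbers match the 10GB model from the pitfall above (2.5B float32 params):

```python
def adam_training_gb(param_count: float, bytes_per_param: int = 4) -> float:
    """Rough training-memory estimate for Adam (excludes activations)."""
    params_gb = param_count * bytes_per_param / 1e9
    grads_gb = params_gb        # one gradient per parameter
    opt_gb = 2 * params_gb      # Adam keeps first and second moments
    return params_gb + grads_gb + opt_gb

print(adam_training_gb(2.5e9))  # 40.0
```

Note this excludes activation memory, which also scales with batch size and sequence length, so the real budget is higher still.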
Scaling Laws
Data Parallelism Scaling
Throughput = single_device_throughput × num_devices × efficiency
# Efficiency typically:
# - 2-4 devices: 95-98%
# - 8 devices: 90-95%
# - 16+ devices: 85-92%
# (Depends on model size and interconnect)
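As a worked instance of the formula (the per-device rate of 100 samples/sec is an arbitrary example value):

```python
def dp_throughput(single_device_throughput: float,
                  num_devices: int, efficiency: float) -> float:
    """Data-parallel throughput = per-device rate x devices x efficiency."""
    return single_device_throughput * num_devices * efficiency

# 8 devices at 92% efficiency:
print(round(dp_throughput(100.0, 8, 0.92)))  # 736
```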
FSDP Scaling
Max_model_size = device_memory × num_devices / 4
# Examples (A100-40GB, float32, dividing by 4 for the training overhead above):
# - 8 devices: 80GB model (20B params)
# - 64 devices: 640GB model (160B params)
# - 1024 devices: 10TB model (2.5T params)
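The same arithmetic as a helper function (illustrative; the 4× overhead and 4 bytes/param assume float32 Adam training as above):

```python
def max_model_params(device_mem_gb: float, num_devices: int,
                     bytes_per_param: int = 4, overhead: int = 4) -> float:
    """Largest trainable parameter count under the 4x-overhead rule."""
    total_bytes = device_mem_gb * 1e9 * num_devices
    return total_bytes / overhead / bytes_per_param

print(max_model_params(40, 8) / 1e9)     # 20.0 (billion params)
print(max_model_params(40, 1024) / 1e9)  # 2560.0 (~2.5T params)
```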
Pipeline Efficiency
Efficiency = M / (M + S - 1)
where M = microbatches, S = stages
# To reach 90% efficiency:
M ≥ 9 × (S - 1)
# Examples:
# 4 stages: Need 27+ microbatches for 90%
# 8 stages: Need 63+ microbatches for 90%
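The formula is a one-liner, which makes it easy to check the examples (and the 40%-efficiency pitfall) directly:

```python
def pipeline_efficiency(microbatches: int, stages: int) -> float:
    """Ideal GPipe-style pipeline utilization: M / (M + S - 1)."""
    return microbatches / (microbatches + stages - 1)

print(round(pipeline_efficiency(2, 4), 2))   # 0.4  (the pitfall case)
print(round(pipeline_efficiency(27, 4), 2))  # 0.9
print(round(pipeline_efficiency(63, 8), 2))  # 0.9
```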
Hardware Considerations
Interconnect Speed
| Interconnect | Bandwidth | Good For |
|---|---|---|
| NVLink (V100) | 300 GB/s | All strategies ⭐ |
| NVLink (A100) | 600 GB/s | All strategies ⭐⭐ |
| NVLink (H100) | 900 GB/s | All strategies ⭐⭐⭐ |
| PCIe 4.0 | 64 GB/s | Data parallel only |
| 10Gb Ethernet | 1.25 GB/s | Single device only |
| InfiniBand | 200 GB/s | All strategies ⭐⭐ |
Memory Hierarchy
Device Memory (fast, small):
├─ L2 Cache: ~40-80MB (fastest)
├─ HBM: 40-80GB (fast)
└─ When full → OOM!
Host Memory (slow, large):
└─ RAM: 100s of GB (for data loading)
Storage (slowest, largest):
└─ Disk: TBs (for the dataset)
Key insight: Training happens in device memory. The model + gradients + optimizer state + activations must all fit.
Monitoring Training
Essential Metrics
# 1. Loss / Accuracy (correctness)
# 2. Step time (efficiency)
# 3. Device utilization (>80% ideal)
# 4. Memory usage (should be high but not OOM)
# 5. Communication time (should be <30% of step time)

# Log these every N steps:
if step % 100 == 0:
    metrics = {
        'loss': float(loss),
        'accuracy': float(accuracy),
        'step_time_ms': step_time * 1000,
        'device_utilization': utilization,
    }
    print(metrics)
Profiling
# Profile to find bottlenecks
from jax import profiler

with profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        state = train_step(state, batch)

# View in TensorBoard:
# tensorboard --logdir=/tmp/jax-trace
# Look for:
# - Computation time (should be high)
# - Communication time (minimize)
# - Idle time (minimize)
Next Steps
- Start simple: Data Parallelism Guide
- Go modern: SPMD Sharding Guide
- Scale up: FSDP Guide or Pipeline Guide
Example Code
Check out our complete, runnable examples:
- examples/16_data_parallel_pmap.py - Data parallelism with pmap
- examples/17_sharding_spmd.py - SPMD automatic sharding
- examples/18_pipeline_parallelism.py - Pipeline parallelism
- examples/19_fsdp_sharding.py - FSDP fully sharded training
Each example is self-contained and includes detailed comments explaining what's happening under the hood.
Further Reading
- JAX Documentation on Parallelism
- Google Cloud TPU Guide
- Megatron-LM Paper (Pipeline + Tensor Parallelism)
- ZeRO Paper (FSDP inspiration)
- GPipe Paper (Pipeline Parallelism)