Training Workflows

Learn the practical skills to train neural networks - from simple training loops to data loading and optimization.

What You'll Learn

This section covers the essential workflows for training models:

Simple Training Loop - Start here!
Write your first complete training loop from scratch. Learn gradient computation, parameter updates, and validation.

Data Loading
Load and preprocess data efficiently. Learn batching, shuffling, and building data pipelines with TensorFlow Datasets.

Streaming Large Datasets
Train on datasets larger than memory. Learn to stream data on-demand from HuggingFace and handle massive datasets.

Experiment Tracking
Track experiments with Weights & Biases. Monitor training, compare runs, and ensure reproducibility.

Model Export
Export trained models to SafeTensors, ONNX, and HuggingFace Hub for deployment and sharing.

Why This Section?

You can define perfect models, but if you can't train them properly, they're useless. This section teaches you:

  • How to compute gradients and update parameters
  • How to load data without bottlenecking your GPU
  • How to validate and avoid overfitting
  • How to make training fast with JIT compilation

Prerequisites

Before starting, you should know how to define a model with Flax NNX.

The Training Workflow

Every training project follows this pattern:

1. Load data → 2. Define model → 3. Train → 4. Evaluate → 5. Save

This section focuses on steps 1, 3, and 4.

Quick Example

Here's the complete workflow you'll master:

from flax import nnx
import optax

# 1. Load data
train_loader = create_data_pipeline(split='train', batch_size=128)
val_loader = create_data_pipeline(split='test', batch_size=128)

# 2. Create model and optimizer
model = MyModel(rngs=nnx.Rngs(params=0))
optimizer = nnx.Optimizer(model, optax.adam(3e-4))

# 3. Training loop
@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        images, labels = batch
        logits = model(images)
        # labels are one-hot here; use softmax_cross_entropy_with_integer_labels
        # if your pipeline yields integer class labels instead
        return optax.softmax_cross_entropy(logits, labels).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss

# 4. Train and evaluate
for epoch in range(10):
    for batch in train_loader:
        loss = train_step(model, optimizer, batch)

    accuracy = evaluate(model, val_loader)
    print(f"Epoch {epoch}: Accuracy = {accuracy:.2%}")

# 5. Save (see checkpointing guide)

Just 20 lines of actual code!
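The example assumes an evaluate helper. Here is a minimal sketch of one, reusing the one-hot labels that the loss above expects:

import jax.numpy as jnp
from flax import nnx

@nnx.jit
def eval_step(model, batch):
    images, labels = batch
    logits = model(images)
    # labels are one-hot, matching optax.softmax_cross_entropy above
    return jnp.argmax(logits, axis=-1) == jnp.argmax(labels, axis=-1)

def evaluate(model, loader):
    # Concatenate per-batch hit masks, then average for overall accuracy
    hits = [eval_step(model, batch) for batch in loader]
    return jnp.concatenate(hits).mean()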

Common Challenges

Challenge 1: Slow Training

Solution: Use JIT compilation (@nnx.jit), larger batch sizes, and efficient data loading
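To see the effect, time a step after the first call; this sketch reuses train_step, model, optimizer, and batch from the quick example. Remember that JAX dispatches asynchronously, so block on the result before reading the clock:

import time
import jax

# First call traces and compiles the step; later calls reuse the cached executable
loss = train_step(model, optimizer, batch)   # slow: compilation happens here
jax.block_until_ready(loss)

start = time.perf_counter()
loss = train_step(model, optimizer, batch)   # fast: runs the compiled program
jax.block_until_ready(loss)                  # wait for async dispatch to finish
print(f"Compiled step: {time.perf_counter() - start:.4f}s")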

Challenge 2: GPU Sitting Idle

Solution: Build a data pipeline with prefetching and parallel preprocessing
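A sketch of such a pipeline with TensorFlow Datasets; the dataset name ('mnist') is a stand-in, and create_data_pipeline matches the helper name assumed in the quick example:

import tensorflow as tf
import tensorflow_datasets as tfds

def create_data_pipeline(split, batch_size):
    ds = tfds.load('mnist', split=split, as_supervised=True)
    ds = ds.shuffle(10_000)                          # new order each epoch
    ds = ds.batch(batch_size, drop_remainder=True)   # fixed shapes keep jit happy
    ds = ds.map(
        lambda x, y: (tf.cast(x, tf.float32) / 255.0, tf.one_hot(y, 10)),
        num_parallel_calls=tf.data.AUTOTUNE,         # preprocess in parallel
    )
    ds = ds.prefetch(tf.data.AUTOTUNE)               # prepare next batch while the GPU trains
    return tfds.as_numpy(ds)

The prefetch call is what keeps the GPU busy: the CPU assembles the next batch while the accelerator runs the current step.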

Challenge 3: Overfitting

Solution: Hold out a validation set, use early stopping, and add regularization
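One way to combine a validation set with early stopping, reusing train_step and the evaluate helper sketched above; a patience of 3 epochs is an arbitrary illustrative choice:

# Track the best validation accuracy; stop after `patience` epochs without improvement
best_acc, bad_epochs, patience = 0.0, 0, 3
for epoch in range(100):
    for batch in train_loader:
        train_step(model, optimizer, batch)
    acc = evaluate(model, val_loader)
    if acc > best_acc:
        best_acc, bad_epochs = acc, 0  # improvement: reset the counter
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print(f"Early stopping at epoch {epoch} (best accuracy {best_acc:.2%})")
            break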

Challenge 4: NaN Losses

Solution: Lower the learning rate, clip gradients, and check your data normalization
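With Optax, gradient clipping composes with the optimizer via optax.chain. A sketch, reusing model from the quick example (the 1.0 max norm and 1e-4 learning rate are illustrative, not prescriptive):

from flax import nnx
import optax

# Clip the global gradient norm before the Adam update
tx = optax.chain(
    optax.clip_by_global_norm(1.0),  # illustrative max norm
    optax.adam(1e-4),                # lower than the 3e-4 used above
)
optimizer = nnx.Optimizer(model, tx)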

Training Best Practices

Always validate - Catch overfitting early
Use JIT - Often a 10-100x speedup
Shuffle data - Each epoch in random order
Normalize inputs - Scale to [0,1] or [-1,1]
Monitor metrics - Print loss, accuracy each epoch
Save checkpoints - Don't lose trained models
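The checkpointing guide covers saving in depth; as a teaser, a minimal sketch with Orbax (the /tmp path is illustrative, and Orbax requires absolute paths):

import orbax.checkpoint as ocp
from flax import nnx

# Split the model into its static graph definition and the state to save
_, state = nnx.split(model)

checkpointer = ocp.StandardCheckpointer()
checkpointer.save('/tmp/my-model/state', state)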

What's Next?

After mastering basic training:

Complete Examples