Experiment Tracking and Observability

Learn how to track experiments, monitor training, and debug models using Weights & Biases (W&B). Understand why observability is critical for machine learning research and production.

Why Observability Matters

Training neural networks is an empirical science; you need data to understand what's happening:

Problems Without Observability

"My model isn't learning"

  • Is the loss decreasing?
  • Are gradients flowing?
  • Is the learning rate appropriate?
  • → Without metrics, you're flying blind

"Which hyperparameters work best?"

  • Tried 10 learning rates, which was best?
  • Can't remember what you ran yesterday
  • → Lost experiments waste time and compute

"Model works on my laptop, fails on cluster"

  • Different batch sizes? Optimizers? Data ordering?
  • Can't reproduce results
  • → Unreproducible research is worthless

What Good Observability Provides

Real-time monitoring: See training progress live
Experiment comparison: Compare 100 runs at once
Reproducibility: Track every hyperparameter
Debugging: Diagnose issues with visualizations
Collaboration: Share results with team
Publication: Document experiments for papers

Weights & Biases Overview

W&B is the de facto standard for ML experiment tracking. It provides:

  • Automatic logging: Capture metrics with minimal code
  • Dashboard: Beautiful visualizations
  • Artifacts: Version datasets, models, predictions
  • Sweeps: Automated hyperparameter search
  • Reports: Collaborative experiment documentation

Installation and Setup

pip install wandb

Login (one-time):

import wandb

# Get API key from wandb.ai
wandb.login()

Basic Experiment Tracking

Minimal Working Example

import wandb
from flax import nnx
import optax

# Assumes CNN, train_step, evaluate, train_loader, and val_loader
# are defined elsewhere.

# 1. Initialize tracking
wandb.init(
    project="mnist-classification",
    name="baseline-run",
    config={
        "learning_rate": 1e-3,
        "batch_size": 128,
        "epochs": 10,
        "architecture": "CNN",
    }
)

# 2. Create model
model = CNN(rngs=nnx.Rngs(params=0))
optimizer = nnx.Optimizer(model, optax.adam(wandb.config.learning_rate))

# 3. Training loop
for epoch in range(wandb.config.epochs):
    for batch in train_loader:
        loss, metrics = train_step(model, optimizer, batch)

        # 4. Log metrics
        wandb.log({
            "train/loss": loss,
            "train/accuracy": metrics['accuracy'],
        })

    # Validation
    val_loss, val_acc = evaluate(model, val_loader)
    wandb.log({
        "val/loss": val_loss,
        "val/accuracy": val_acc,
        "epoch": epoch
    })

# 5. Finish tracking
wandb.finish()

Result: Automatic dashboard with loss/accuracy curves, system metrics, and hyperparameters.

Understanding the API

wandb.init(): Starts a new run

  • project: Group related experiments
  • name: Human-readable run identifier
  • config: Hyperparameters to track

wandb.log(): Log metrics

  • Call after each training step
  • Metrics grouped by prefix (train/, val/)
  • Automatically plots time series

wandb.finish(): Mark run complete

  • Uploads final data
  • Releases resources
  • Always call at end!
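By default, wandb.log() advances an internal step counter on every call, so logging from several places (per-batch and per-epoch) can misalign your charts. Passing an explicit step keeps all metrics on one x-axis; a minimal sketch, assuming global_step is a counter your training loop maintains:

# Keep train and val metrics aligned on the same x-axis
wandb.log({"train/loss": loss}, step=global_step)

# ... later, at validation time, log against the same counter
wandb.log({"val/loss": val_loss}, step=global_step)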

What to Track

Essential Metrics

import jax.numpy as jnp

# During training step
wandb.log({
    # Loss values
    "train/loss": loss,
    "train/perplexity": jnp.exp(loss),  # For language models

    # Optimization info
    "train/learning_rate": current_lr,
    "train/gradient_norm": grad_norm,

    # Performance
    "train/accuracy": accuracy,
    "train/tokens_per_second": throughput,

    # Step tracking
    "step": global_step,
})

# After each epoch
wandb.log({
    # Validation metrics
    "val/loss": val_loss,
    "val/accuracy": val_acc,

    # Best model tracking
    "val/best_accuracy": best_acc,

    # Epoch info
    "epoch": epoch,
    "epoch_time": epoch_duration,
})

Gradient Statistics

Monitor gradient health:

import jax
import jax.numpy as jnp

def log_gradient_stats(grads):
    """Log gradient statistics for debugging"""

    # Flatten all gradients
    flat_grads = jax.tree_util.tree_leaves(grads)

    # Compute statistics
    grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in flat_grads))
    grad_max = max(jnp.max(jnp.abs(g)) for g in flat_grads)
    grad_mean = jnp.mean(jnp.array([jnp.mean(g) for g in flat_grads]))

    wandb.log({
        "gradients/norm": grad_norm,
        "gradients/max": grad_max,
        "gradients/mean": grad_mean,
    })

# In training loop
grads = compute_gradients(model, batch)
log_gradient_stats(grads)
optimizer.update(grads)

Why this matters:

  • Exploding gradients: norm spiking or growing steadily → clip gradients (see the sketch below) or lower the LR
  • Vanishing gradients: norm ≪ 1 (e.g. < 0.01) → network too deep or bad init
  • Dead neurons: max near zero → change activation or init
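If you do see exploding gradients, one common remedy is global-norm clipping, which optax provides via optax.clip_by_global_norm. A minimal sketch chaining it with Adam (the 1.0 threshold is illustrative; model is the model from the examples above):

import optax
from flax import nnx

# Clip the global gradient norm to 1.0 before applying the Adam update
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(1e-3),
)
optimizer = nnx.Optimizer(model, tx)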

Parameter Statistics

Track parameter evolution:

def log_parameter_stats(model):
    """Log parameter statistics"""

    state = nnx.state(model)
    params = jax.tree_util.tree_leaves(state)

    # Statistics
    param_norm = jnp.sqrt(sum(jnp.sum(p**2) for p in params))
    param_max = max(jnp.max(jnp.abs(p)) for p in params)
    param_mean = jnp.mean(jnp.array([jnp.mean(p) for p in params]))

    wandb.log({
        "parameters/norm": param_norm,
        "parameters/max": param_max,
        "parameters/mean": param_mean,
    })

# Log every N steps
if step % 100 == 0:
    log_parameter_stats(model)

Visualizations

Custom Plots

# Confusion matrix
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

def log_confusion_matrix(model, val_loader, class_names):
    """Log confusion matrix"""

    # Compute predictions
    all_preds = []
    all_targets = []

    for batch in val_loader:
        logits = model(batch['images'])
        preds = jnp.argmax(logits, axis=-1)
        targets = jnp.argmax(batch['labels'], axis=-1)  # Assumes one-hot labels

        all_preds.extend(preds.tolist())
        all_targets.extend(targets.tolist())

    # Create confusion matrix
    cm = confusion_matrix(all_targets, all_preds)

    # Plot
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')

    # Log to W&B
    wandb.log({"confusion_matrix": wandb.Image(fig)})
    plt.close(fig)

# Run at end of training
log_confusion_matrix(model, val_loader, ['cat', 'dog', ...])

Image Logging

# Log example predictions
def log_predictions(model, val_batch):
    """Log model predictions on images"""

    images = val_batch['images'][:8]  # First 8 images
    labels = val_batch['labels'][:8]  # Assumes integer class labels

    # Predict
    logits = model(images)
    preds = jnp.argmax(logits, axis=-1)
    probs = jax.nn.softmax(logits, axis=-1)

    # Create wandb images with predictions
    wandb_images = []
    for img, true_label, pred_label, prob in zip(images, labels, preds, probs):
        caption = f"True: {true_label}, Pred: {pred_label} ({prob[pred_label]:.2f})"
        wandb_images.append(wandb.Image(img, caption=caption))

    wandb.log({"predictions": wandb_images})

# Log every few epochs
if epoch % 5 == 0:
    log_predictions(model, next(iter(val_loader)))

Histograms

# Log weight distributions
def log_weight_histograms(model):
    """Log parameter distributions"""

    state = nnx.state(model)

    # Log each layer's weights
    for path, param in jax.tree_util.tree_leaves_with_path(state):
        name = '.'.join(str(p.key) for p in path if hasattr(p, 'key'))

        # Flax NNX names the weight matrix 'kernel'
        if 'kernel' in name:
            wandb.log({
                f"weights/{name}": wandb.Histogram(param)
            })

# Log periodically
if step % 1000 == 0:
    log_weight_histograms(model)

Hyperparameter Sweeps

Automatically search hyperparameter space:

Defining a Sweep

# sweep_config.yaml or in code
sweep_config = {
    'method': 'random',  # or 'grid', 'bayes'
    'metric': {
        'name': 'val/accuracy',
        'goal': 'maximize'
    },
    'parameters': {
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-2
        },
        'batch_size': {
            'values': [32, 64, 128, 256]
        },
        'num_layers': {
            'values': [2, 3, 4, 5]
        },
        'hidden_size': {
            'distribution': 'q_uniform',
            'min': 128,
            'max': 512,
            'q': 64  # Step size
        },
        'dropout': {
            'distribution': 'uniform',
            'min': 0.1,
            'max': 0.5
        },
        'epochs': {
            'value': 10  # Fixed value, still exposed via wandb.config
        },
    }
}

# Create sweep
sweep_id = wandb.sweep(sweep_config, project="mnist-classification")

Running Sweep Agents

def train_sweep():
    """Training function for sweep"""

    # Initialize with sweep config
    run = wandb.init()
    config = wandb.config

    # Create model with sweep hyperparameters
    model = MLP(
        in_features=784,
        hidden_features=int(config.hidden_size),  # q_uniform yields floats
        out_features=10,
        num_layers=config.num_layers,
        dropout_rate=config.dropout,
        rngs=nnx.Rngs(params=0)
    )

    optimizer = nnx.Optimizer(
        model,
        optax.adam(learning_rate=config.learning_rate)
    )

    # Train
    for epoch in range(config.epochs):
        for batch in get_dataloader(batch_size=config.batch_size):
            loss, metrics = train_step(model, optimizer, batch)
            wandb.log({"train/loss": loss, "train/accuracy": metrics['accuracy']})

        # Validation
        val_loss, val_acc = evaluate(model, val_loader)
        wandb.log({"val/loss": val_loss, "val/accuracy": val_acc})

    wandb.finish()

# Run sweep agent
wandb.agent(sweep_id, function=train_sweep, count=50)  # Run 50 trials
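Agents can also be launched from the command line, which makes it easy to run several in parallel across machines (substitute your own entity, project, and sweep ID):

wandb agent my-entity/mnist-classification/SWEEP_ID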

Sweep Strategies

Random search:

  • Samples hyperparameters randomly
  • Good for exploring large spaces
  • Easy to parallelize

Grid search:

  • Tries all combinations
  • Exhaustive but expensive
  • Best for small spaces

Bayesian optimization:

  • Uses previous results to guide search
  • Most sample-efficient
  • Works best with sequential runs (each trial informs the next)
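As a sketch, Bayesian search can be combined with W&B's Hyperband early termination to stop unpromising trials early; only the keys shown below differ from the random-search config above:

sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 3,  # Minimum iterations before a trial can be stopped
    },
    # 'parameters': same as the random-search config above
}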

Model Artifacts

Version models and datasets:

Saving Models as Artifacts

# After training
artifact = wandb.Artifact(
    name='mnist-cnn',
    type='model',
    description='CNN trained on MNIST',
    metadata={
        'accuracy': best_acc,
        'architecture': 'CNN',
        'params': count_parameters(model)
    }
)

# Add model files
artifact.add_file('model.safetensors')
artifact.add_file('config.json')

# Log artifact
wandb.log_artifact(artifact)

Using Artifacts

# Load artifact in new run
run = wandb.init(project="mnist-classification")

artifact = run.use_artifact('mnist-cnn:latest')  # Or a specific version, e.g. 'mnist-cnn:v2'
artifact_dir = artifact.download()

# Load model
model = load_model_from_checkpoint(f"{artifact_dir}/model.safetensors")

Best Practices

Structuring Experiments

# Good: Organized logging
wandb.log({
    # Training metrics
    "train/loss": loss,
    "train/accuracy": acc,

    # Validation metrics
    "val/loss": val_loss,
    "val/accuracy": val_acc,

    # Optimization
    "opt/learning_rate": lr,
    "opt/gradient_norm": grad_norm,

    # System
    "system/gpu_memory": gpu_mem,
    "system/throughput": samples_per_sec,
})

# Bad: Flat namespace
wandb.log({
    "loss": loss,
    "loss2": val_loss,   # Confusing!
    "acc": acc,
    "valacc": val_acc,   # Inconsistent naming
})

Reproducibility Checklist

Track everything needed to reproduce:

import jax
import flax

config = {
    # Model
    "architecture": "ResNet-18",
    "num_layers": 18,
    "hidden_size": 512,

    # Optimization
    "optimizer": "AdamW",
    "learning_rate": 1e-3,
    "weight_decay": 1e-4,
    "lr_schedule": "cosine",
    "warmup_steps": 1000,

    # Data
    "dataset": "ImageNet",
    "batch_size": 256,
    "augmentation": "standard",

    # Training
    "epochs": 100,
    "seed": 42,

    # System
    "jax_version": jax.__version__,
    "flax_version": flax.__version__,  # nnx has no __version__ of its own
    "device": str(jax.devices()[0]),   # Stringify for serialization
}

wandb.init(project="my-project", config=config)

Offline Mode

Train without internet:

# Set offline mode
import os
os.environ['WANDB_MODE'] = 'offline'

# Train normally
wandb.init(project="my-project")
# ... training ...
wandb.finish()

# Later: Sync offline runs
# wandb sync /path/to/offline/run

Common Patterns

Early Stopping

best_val_loss = float('inf')
patience = 5
patience_counter = 0

for epoch in range(max_epochs):
    # Training...

    # Validation
    val_loss = evaluate(model, val_loader)
    wandb.log({"val/loss": val_loss})

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0

        # Save best model
        save_checkpoint(model, "best_model.safetensors")
        wandb.log({"val/best_loss": best_val_loss})
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}")
        break

# Log final best metric
wandb.run.summary["best_val_loss"] = best_val_loss

Multi-Run Comparison

# Run multiple seeds
for seed in [42, 123, 456, 789, 999]:
    run = wandb.init(
        project="mnist-comparison",
        name=f"seed-{seed}",
        config={"seed": seed}
    )

    # Set seed
    rngs = nnx.Rngs(params=seed)

    # Train
    model = train_model(rngs=rngs)

    # Log results
    val_acc = evaluate(model, val_loader)
    wandb.log({"final_val_accuracy": val_acc})

    wandb.finish()

# In W&B UI: Compare all runs to see variance
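To make runs like these easier to aggregate in the W&B UI, wandb.init() also accepts a group argument; a minimal sketch (the group name is arbitrary):

run = wandb.init(
    project="mnist-comparison",
    group="baseline-seeds",  # Runs sharing a group can be aggregated together
    name=f"seed-{seed}",
    config={"seed": seed},
)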

Debugging with W&B

Detecting Issues

Symptoms:

  • Loss explodes → Check gradient norms
  • Loss plateaus → Check learning rate schedule
  • Accuracy stuck → Visualize predictions

Debug dashboard:

# Comprehensive debugging logs
wandb.log({
    "debug/loss": loss,
    "debug/loss_is_nan": jnp.isnan(loss),
    "debug/loss_is_inf": jnp.isinf(loss),

    "debug/grad_norm": grad_norm,
    "debug/grad_norm_too_large": grad_norm > 10,

    "debug/param_norm": param_norm,
    "debug/param_max": param_max,

    "debug/learning_rate": current_lr,
    "debug/batch_mean": batch['images'].mean(),
    "debug/batch_std": batch['images'].std(),
})

Next Steps

You now know how to track experiments professionally!
