Skip to main content

Welcome to Learn Flax NNX Training

Your comprehensive guide to mastering neural network training with Flax NNX and JAX from the ground up! 🚀

What is Flax NNX?

Flax NNX is the new neural network API built on JAX, combining the best of functional and object-oriented programming. It provides a flexible, high-performance framework for machine learning research and production:

  • Pythonic & Intuitive: Easy to learn, familiar OOP style with explicit state management
  • Explicit RNGs: No hidden randomness, full control over reproducibility
  • JIT Compilation: Lightning-fast execution with JAX's XLA compiler
  • Automatic Differentiation: Effortless gradient computation for any function
  • Scalability: Easy scaling from single GPU to TPU pods with minimal code changes
  • Production Ready: Used in real-world systems at Google and beyond

Why Learn Flax NNX?

Unlike older neural network libraries that hide state management and randomness, Flax NNX gives you explicit control over every aspect of your models. This makes debugging easier, scaling simpler, and helps you truly understand what's happening in your training loop.

This guide will teach you the concepts and patterns behind Flax NNX training, not just show you code to copy. You'll learn:

  • How Flax NNX manages model state and parameters
  • Why explicit RNG handling makes your training reproducible
  • When to use different optimization patterns
  • What makes a good training loop architecture

Documentation Structure

This documentation is organized into small, focused guides that won't overwhelm you:

🎯 Fundamentals

Start with the core concepts that apply everywhere:

  • Your First Model: Build a simple neural network from scratch
  • Understanding State: How NNX manages parameters and variables

These fundamentals take ~15 minutes and are essential for everything else.

Start with Fundamentals →

🏃 Training Workflows

Learn the practical skills to train models:

  • Simple Training Loop: Write your first complete training loop
  • Data Loading: Build efficient data pipelines without bottlenecks

Short, focused guides that get you training quickly.

Learn Training Workflows →

🖼️ Computer Vision

Build image models step-by-step:

  • Simple CNN: Your first convolutional network for image classification
  • ResNet: Deep networks with skip connections

Each guide is self-contained and builds one complete model.

Explore Computer Vision →

📝 Natural Language Processing

Build text models from scratch:

  • Simple Transformer: Understand attention and build GPT-style models

Clear explanations of how transformers actually work.

Explore NLP →

📈 Scale

Take your training to production scale:

  • Distributed Training: Multiple GPUs and TPUs
  • Performance Optimization: Make training faster

Learn about Scaling →

🔬 Research

Advanced patterns for cutting-edge research:

  • Model Export: ONNX, SafeTensors, HuggingFace formats
  • Observability: Track experiments with W&B
  • Advanced Architectures: Building ResNets, Transformers, BERT, and GPT from scratch

Explore Research Topics →

How to Use This Documentation

If you're brand new:

  1. Start with Fundamentals → (~15 min)
  2. Learn Training Workflows → (~20 min)
  3. Choose your domain: Vision or Text

If you know the basics:

  • Jump directly to Computer Vision or NLP
  • Each guide is self-contained and buildable in isolation

If you need specific examples:

  • See the /examples directory
  • 20 complete, organized examples covering all topics:
    • Basics (5 examples): Model definition, checkpointing, data loading
    • Training (2 examples): Vision and language model training
    • Export (1 example): Model deployment formats
    • Integrations (3 examples): HuggingFace, W&B, streaming data
    • Advanced (5 examples): BERT, GPT, SimCLR, MAML, distillation
    • Distributed (4 examples): Multi-device training strategies
  • All examples use shared, tested components for consistency
  • View the complete index: python examples/index.py

What Makes This Different?

Small, focused guides: Each page teaches ONE concept completely. No 5000-word mega-guides.

Domain-organized: Vision models in vision/, text models in text/. Find what you need quickly.

Example-driven: Every concept has working code you can run immediately.

No overwhelm: Start simple, build up gradually. You won't drown in complexity.

Reference Code

All documentation includes conceptual explanations with code snippets. For complete runnable examples, see the /examples directory in the repository:

  • 20 organized examples in modular structure:
    • basics/ - Model definition, checkpointing, data loading
    • training/ - End-to-end vision and language model training
    • export/ - Model deployment (SafeTensors, ONNX)
    • integrations/ - HuggingFace Hub, W&B, streaming datasets
    • advanced/ - BERT, GPT, SimCLR, MAML, knowledge distillation
    • distributed/ - pmap, SPMD, pipeline parallelism, FSDP
  • Shared component library (shared/models.py, shared/training_utils.py) with tested, reusable code
  • Each file is extensively commented for learning
  • Run python examples/index.py to see all available examples with descriptions

Getting Help

Contributing

We welcome contributions! If you'd like to improve this documentation:

  1. Fork the repository
  2. Make your changes
  3. Submit a pull request

Happy training with Flax! 🎉