Welcome to Learn Flax NNX Training
Your comprehensive guide to mastering neural network training with Flax NNX and JAX from the ground up! 🚀
What is Flax NNX?
Flax NNX is the new neural network API built on JAX, combining the best of functional and object-oriented programming. It provides a flexible, high-performance framework for machine learning research and production:
- Pythonic & Intuitive: Easy to learn, familiar OOP style with explicit state management
- Explicit RNGs: No hidden randomness, full control over reproducibility
- JIT Compilation: Lightning-fast execution with JAX's XLA compiler
- Automatic Differentiation: Effortless gradient computation for any function
- Scalability: Scale from a single GPU to TPU pods with minimal code changes
- Production Ready: Used in real-world systems at Google and beyond
Why Learn Flax NNX?
Unlike older neural network libraries that hide state management and randomness, Flax NNX gives you explicit control over every aspect of your models. This makes debugging easier, scaling simpler, and helps you truly understand what's happening in your training loop.
This guide will teach you the concepts and patterns behind Flax NNX training, not just show you code to copy. You'll learn:
- How Flax NNX manages model state and parameters
- Why explicit RNG handling makes your training reproducible
- When to use different optimization patterns
- What makes a good training loop architecture
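The reproducibility point above comes straight from JAX's explicit PRNG keys, which NNX builds on. A small sketch (the variable names are illustrative):

```python
import jax

# Explicit PRNG keys make randomness reproducible: the same key always
# produces the same values, and independent streams are derived by
# splitting keys rather than by mutating hidden global state.
key = jax.random.PRNGKey(42)
init_key, dropout_key = jax.random.split(key)  # two independent streams

a = jax.random.normal(init_key, (3,))
b = jax.random.normal(init_key, (3,))
print(bool((a == b).all()))  # True: same key, same values
```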
Documentation Structure
This documentation is organized into small, focused guides that won't overwhelm you:
🎯 Fundamentals
Start with the core concepts that apply everywhere:
- Your First Model: Build a simple neural network from scratch
- Understanding State: How NNX manages parameters and variables
These fundamentals take ~15 minutes and are essential for everything else.
🏃 Training Workflows
Learn the practical skills to train models:
- Simple Training Loop: Write your first complete training loop
- Data Loading: Build efficient data pipelines without bottlenecks
Short, focused guides that get you training quickly.
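The Simple Training Loop guide covers the full NNX version; the underlying loss-gradients-update pattern can be sketched in plain JAX (the parameter names, learning rate, and toy data here are illustrative):

```python
import jax
import jax.numpy as jnp

# The core loop structure (loss -> gradients -> update), sketched in
# plain JAX with a toy linear-regression loss.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # SGD: move each parameter against its gradient
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

params = {"w": jnp.zeros((2, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 2)), jnp.ones((8, 1))
params, loss0 = train_step(params, x, y)
params, loss1 = train_step(params, x, y)
print(float(loss1) < float(loss0))  # True: the loss is decreasing
```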
🖼️ Computer Vision
Build image models step-by-step:
- Simple CNN: Your first convolutional network for image classification
- ResNet: Deep networks with skip connections
Each guide is self-contained and builds one complete model.
📝 Natural Language Processing
Build text models from scratch:
- Simple Transformer: Understand attention and build GPT-style models
Clear explanations of how transformers actually work.
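The attention mechanism at the heart of the Simple Transformer guide fits in a few lines. A sketch of scaled dot-product attention in plain JAX (the token count and dimension are illustrative):

```python
import jax
import jax.numpy as jnp

# Scaled dot-product attention: each token attends to every other
# token via softmax(Q K^T / sqrt(d)) V.
def attention(q, k, v):
    d = q.shape[-1]
    scores = q @ k.swapaxes(-2, -1) / jnp.sqrt(d)  # (tokens, tokens)
    weights = jax.nn.softmax(scores, axis=-1)      # each row sums to 1
    return weights @ v

key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(key, (4, 8)) for _ in range(3))  # 4 tokens, dim 8
out = attention(q, k, v)
print(out.shape)  # (4, 8)
```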
📈 Scale
Take your training to production scale:
- Distributed Training: Multiple GPUs and TPUs
- Performance Optimization: Make training faster
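Much of the performance story rests on `jax.jit`, which compiles a function once per input shape/dtype with XLA and reuses the compiled executable on later calls. A tiny sketch with an illustrative loss function:

```python
import jax
import jax.numpy as jnp

# jax.jit traces the function once, compiles it with XLA, and caches
# the executable; subsequent calls with the same shapes skip tracing.
@jax.jit
def mse(pred, target):
    return jnp.mean((pred - target) ** 2)

x = jnp.arange(4.0)                  # [0., 1., 2., 3.]
print(float(mse(x, jnp.zeros(4))))   # 3.5, the mean of [0, 1, 4, 9]
```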
🔬 Research
Advanced patterns for cutting-edge research:
- Model Export: ONNX, SafeTensors, HuggingFace formats
- Observability: Track experiments with W&B
- Advanced Architectures: Building ResNets, Transformers, BERT, and GPT from scratch
How to Use This Documentation
If you're brand new:
- Start with Fundamentals → (~15 min)
- Learn Training Workflows → (~20 min)
- Choose your domain: Vision or Text
If you know the basics:
- Jump directly to Computer Vision or NLP
- Each guide is self-contained and can be worked through in isolation
If you need specific examples:
- See the `/examples` directory: 20 complete, organized examples covering all topics:
- Basics (5 examples): Model definition, checkpointing, data loading
- Training (2 examples): Vision and language model training
- Export (1 example): Model deployment formats
- Integrations (3 examples): HuggingFace, W&B, streaming data
- Advanced (5 examples): BERT, GPT, SimCLR, MAML, distillation
- Distributed (4 examples): Multi-device training strategies
- All examples use shared, tested components for consistency
- View the complete index: `python examples/index.py`
What Makes This Different?
Small, focused guides: Each page teaches ONE concept completely. No 5000-word mega-guides.
Domain-organized: Vision models in vision/, text models in text/. Find what you need quickly.
Example-driven: Every concept has working code you can run immediately.
No overwhelm: Start simple, build up gradually. You won't drown in complexity.
Reference Code
All documentation includes conceptual explanations with code snippets. For complete runnable examples, see the /examples directory in the repository:
- 20 organized examples in a modular structure:
  - `basics/` - Model definition, checkpointing, data loading
  - `training/` - End-to-end vision and language model training
  - `export/` - Model deployment (SafeTensors, ONNX)
  - `integrations/` - HuggingFace Hub, W&B, streaming datasets
  - `advanced/` - BERT, GPT, SimCLR, MAML, knowledge distillation
  - `distributed/` - pmap, SPMD, pipeline parallelism, FSDP
- Shared component library (`shared/models.py`, `shared/training_utils.py`) with tested, reusable code
- Each file is extensively commented for learning
- Run `python examples/index.py` to see all available examples with descriptions
Getting Help
- GitHub Issues: Report bugs or request features in our GitHub repository
- Flax Official Docs: Check out the official Flax documentation
- JAX Documentation: Learn more about JAX
Contributing
We welcome contributions! If you'd like to improve this documentation:
- Fork the repository
- Make your changes
- Submit a pull request
Happy training with Flax! 🎉