This chapter covers
- Building an MLP for MNIST digit classification using Flax and its Linen API
- Using the Optax gradient transformation library for model training
- Using the TrainState dataclass for representing a training state and storing metrics
- Building a residual neural network for image classification and working with model state variables
- Using the Hugging Face Transformers and Diffusers libraries with JAX/Flax models
Core JAX is a powerful but pretty low-level library. Just as you will rarely build a complex neural network in pure NumPy or with basic TensorFlow primitives, in most cases you will not do so with pure JAX either. And just as there are higher-level neural network libraries for TensorFlow (Keras, Sonnet) and PyTorch (torch.nn, PyTorch Lightning, fast.ai), there are also such libraries for JAX.
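To give a flavor of what "higher-level" means here, the following is a minimal sketch of an MLP declared with Flax's Linen API; the layer sizes and dummy input shape are illustrative choices, not taken from the chapter. Instead of managing parameter pytrees and layer math by hand, you declare the model as a module and let the library initialize and apply it.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    """A small multilayer perceptron (sizes chosen for illustration)."""
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=512)(x)   # hidden layer
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)    # e.g. 10 MNIST digit classes
        return x

model = MLP()
dummy_input = jnp.ones((1, 28 * 28))                        # a flattened 28x28 image
params = model.init(jax.random.PRNGKey(0), dummy_input)     # create the parameter pytree
logits = model.apply(params, dummy_input)                    # run the forward pass
```

The parameters stay an explicit pytree, so everything still composes with core JAX transformations such as jit and grad; the library only takes the bookkeeping off your hands.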