11 Higher-level neural network libraries

This chapter covers

  • Building an MLP for MNIST digit classification using Flax and its Linen API
  • Using the Optax gradient transformation library for model training
  • Using the TrainState dataclass for representing a training state and storing metrics
  • Building a residual neural network for image classification and working with model state variables
  • Using Hugging Face libraries with JAX/Flax support: transformers and diffusers

Core JAX is a powerful but fairly low-level library. Just as you will rarely build a complex neural network in pure NumPy or with basic TensorFlow primitives, in most cases you will not do so with pure JAX either. And just as there are higher-level neural network libraries for TensorFlow (Keras, Sonnet) and PyTorch (torch.nn, PyTorch Lightning, fast.ai), there are also such libraries for JAX.
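
To give a flavor of what such a library looks like before we dive in, here is a minimal sketch of an MLP defined with Flax's Linen API. The layer width and the MNIST-shaped dummy input are illustrative assumptions for this sketch, not the exact model built later in the chapter.

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    """A small multilayer perceptron for 28x28 grayscale images."""
    hidden_size: int = 512      # illustrative width, not the chapter's exact value
    num_classes: int = 10

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))          # flatten each image into a vector
        x = nn.relu(nn.Dense(self.hidden_size)(x))
        return nn.Dense(self.num_classes)(x)     # logits for the 10 digit classes

model = MLP()
dummy_batch = jnp.ones((8, 28, 28, 1))           # MNIST-shaped dummy input
params = model.init(jax.random.PRNGKey(0), dummy_batch)   # parameters as a pytree
logits = model.apply(params, dummy_batch)        # forward pass, shape (8, 10)

Note that the model itself holds no parameters: init returns them as a separate pytree, and apply takes them explicitly, which fits JAX's functional style. The rest of the chapter builds on this pattern.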

11.1 MNIST image classification using an MLP

11.1.1 MLP in Flax

11.1.2 Optax gradient transformations library

11.1.3 Training a neural network the Flax way

11.2 Image classification using a ResNet

11.2.1 Managing state in Flax

11.2.2 Saving and loading a model using Orbax

11.3 Using the Hugging Face ecosystem

11.3.1 Using a pretrained model from the Hugging Face Model Hub

11.3.2 Going further with fine-tuning and pretraining

11.3.3 Using the diffusers library

Summary