11 Higher-level neural network libraries

This chapter covers

  • Building an MLP for MNIST digit classification using Flax and its Linen API
  • Using the Optax gradient transformation library for model training
  • Using the TrainState dataclass for representing a training state and storing metrics
  • Building a residual neural network for image classification and working with model state variables
  • Using the Hugging Face transformers and diffusers libraries with JAX/Flax

Core JAX is a powerful but fairly low-level library. Just as you rarely build a complex neural network in pure NumPy or with basic TensorFlow primitives, in most cases you will not do so with pure JAX either. And just as there are higher-level neural network libraries for TensorFlow (Keras, Sonnet) and PyTorch (torch.nn, PyTorch Lightning, fast.ai), there are such libraries for JAX as well.

11.1 MNIST image classification using an MLP
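
As a running example, this chapter classifies 28×28 grayscale MNIST digits. Below is a minimal sketch of loading the dataset as NumPy arrays; using tensorflow_datasets here is one common choice and an assumption on my part, not necessarily the chapter's exact data pipeline.

import tensorflow_datasets as tfds

# Load the full MNIST splits as NumPy arrays (batch_size=-1 returns everything at once).
train_ds = tfds.as_numpy(tfds.load('mnist', split='train', batch_size=-1))
test_ds = tfds.as_numpy(tfds.load('mnist', split='test', batch_size=-1))

train_images = train_ds['image'].astype('float32') / 255.0   # shape (60000, 28, 28, 1)
train_labels = train_ds['label']                             # shape (60000,)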

 
 
 

11.1.1 MLP in Flax
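
A minimal sketch of how an MLP looks with Flax's Linen API; the layer widths below are illustrative, not necessarily the chapter's exact architecture.

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    """A simple multilayer perceptron for 10-class digit classification."""

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))     # flatten each 28x28 image into a vector
        x = nn.Dense(features=512)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)        # logits for the ten digit classes
        return x

model = MLP()
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.ones((1, 28, 28, 1)))     # build the parameter pytree
logits = model.apply(variables, jnp.ones((32, 28, 28, 1)))  # pure function of params and data

Unlike a stateful PyTorch module, a Linen module holds no parameters itself: init returns them as a separate pytree, and apply is a pure function of those parameters and the input.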

 

11.1.2 Optax gradient transformations library
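
A minimal sketch of the basic Optax workflow, reusing the model and variables names from the previous listing; the choice of Adam and the learning rate are illustrative.

import jax
import optax

optimizer = optax.adam(learning_rate=1e-3)        # a composable gradient transformation
opt_state = optimizer.init(variables['params'])   # optimizer state (Adam moment estimates)

def loss_fn(params, images, labels):
    logits = model.apply({'params': params}, images)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

@jax.jit
def train_step(params, opt_state, images, labels):
    loss, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)  # transform raw gradients
    params = optax.apply_updates(params, updates)                    # apply the resulting updates
    return params, opt_state, loss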

 
 
 

11.1.3 Training a neural network the Flax way
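
A minimal sketch of the same training step expressed with flax.training.train_state.TrainState, which bundles the apply function, the parameters, and the Optax optimizer state into one pytree (metrics can be carried along by subclassing TrainState); names again reuse the earlier listings.

import jax
import optax
from flax.training import train_state

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.adam(learning_rate=1e-3),
)

@jax.jit
def train_step(state, images, labels):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)   # updates params and optimizer state together
    return state, loss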

 
 

11.2 Image classification using a ResNet
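
The second example moves to convolutional residual networks. A minimal sketch of a single residual block in Linen follows; a real ResNet stacks many such blocks and adds downsampling, so the details here are illustrative only.

import flax.linen as nn

class ResidualBlock(nn.Module):
    """One basic residual block: two 3x3 convolutions with a skip connection."""
    features: int

    @nn.compact
    def __call__(self, x, train: bool = True):
        residual = x
        y = nn.Conv(self.features, kernel_size=(3, 3))(x)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = nn.relu(y)
        y = nn.Conv(self.features, kernel_size=(3, 3))(y)
        y = nn.BatchNorm(use_running_average=not train)(y)
        return nn.relu(y + residual)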

 
 

11.2.1 Managing state in Flax
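
BatchNorm layers keep running statistics that are not trained by gradient descent, so Flax stores them in a separate batch_stats collection. A minimal sketch of the pattern, assuming a model built from blocks like the one above and an illustrative 32×32 RGB input shape:

import jax
import jax.numpy as jnp

variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
params, batch_stats = variables['params'], variables['batch_stats']
images = jnp.ones((8, 32, 32, 3))   # a dummy image batch

# Training: batch_stats is declared mutable, so apply() returns the updated
# running statistics alongside the model output.
logits, updates = model.apply(
    {'params': params, 'batch_stats': batch_stats},
    images, train=True, mutable=['batch_stats'])
batch_stats = updates['batch_stats']

# Evaluation: use the stored statistics; nothing is mutated.
logits = model.apply({'params': params, 'batch_stats': batch_stats},
                     images, train=False)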

 
 
 

11.2.2 Saving and loading a model using Orbax
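
A minimal sketch of checkpointing a parameter pytree with Orbax, assuming the TrainState from section 11.1.3; the path is illustrative, and the Orbax API has evolved across versions, so check the version you have installed.

import orbax.checkpoint as ocp

checkpointer = ocp.PyTreeCheckpointer()

# Save any JAX pytree (parameters, a whole TrainState, ...) to a directory.
checkpointer.save('/tmp/mnist_checkpoint', state.params)

# Restore it later; without extra arguments the pytree is restored as saved.
restored_params = checkpointer.restore('/tmp/mnist_checkpoint')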

 
 

11.3 Using the Hugging Face ecosystem

 
 

11.3.1 Using a pretrained model from the Hugging Face Model Hub
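
A minimal sketch of pulling a pretrained Flax checkpoint from the Hugging Face Model Hub and generating text with it; GPT-2 is an illustrative choice of checkpoint, not necessarily the one used in this chapter.

from transformers import AutoTokenizer, FlaxGPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")   # downloads Flax weights from the Hub

# Flax models consume NumPy/JAX arrays, hence return_tensors="np".
input_ids = tokenizer("JAX is a library for", return_tensors="np").input_ids
output = model.generate(input_ids, max_length=20)
print(tokenizer.decode(output.sequences[0], skip_special_tokens=True))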

 

11.3.2 Going further with fine-tuning and pretraining
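
A hedged sketch of how a pretrained Flax transformer plugs into the same TrainState/Optax machinery used earlier, here for a two-class sequence classification task; the checkpoint, learning rate, and batch layout are illustrative assumptions.

import jax
import optax
from flax.training import train_state
from transformers import FlaxAutoModelForSequenceClassification

model = FlaxAutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2)

state = train_state.TrainState.create(
    apply_fn=model.__call__,          # Hugging Face Flax models are callable modules
    params=model.params,              # pretrained weights as an ordinary pytree
    tx=optax.adamw(learning_rate=2e-5),
)

@jax.jit
def train_step(state, batch, dropout_rng):
    def loss_fn(params):
        logits = state.apply_fn(batch["input_ids"], batch["attention_mask"],
                                params=params, dropout_rng=dropout_rng,
                                train=True).logits
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["labels"]).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss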

 
 
 
 

11.3.3 Using the diffusers library
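
The diffusers library also ships Flax versions of its Stable Diffusion pipeline. Below is a hedged sketch of single-device text-to-image generation; the checkpoint name, revision, dtype, and prompt are illustrative assumptions, and in practice the pipeline is usually replicated and pmapped across all available devices.

import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Load Flax weights of Stable Diffusion from the Hub (revision and dtype are assumptions).
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16)

# Tokenize the text prompt into ids the pipeline understands.
prompt_ids = pipeline.prepare_inputs(["An astronaut riding a horse on the Moon"])
rng = jax.random.PRNGKey(0)

# Run the full denoising loop; .images is a NumPy array of generated images.
images = pipeline(prompt_ids, params, rng, num_inference_steps=50).images
pil_images = pipeline.numpy_to_pil(images)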

 
 
 

Summary

 
 