This chapter covers
- The high-level structure of a JAX deep learning project
- Loading a dataset
- Creating a simple neural network in JAX
- Using JAX transformations for auto-vectorization, gradient calculation, and just-in-time compilation
- Saving and loading a model
- Pure and impure functions
JAX is a library for composable transformations of Python and NumPy programs. Though it is not limited to deep learning research, it is often regarded as a deep learning framework, sometimes called the third major one after PyTorch and TensorFlow. Many people therefore start learning JAX through deep learning applications.
In this chapter, we’ll do a deep learning Hello World exercise. We will build a simple neural network application demonstrating the JAX approach to building a deep learning model. It’s an image classification model that works on the MNIST handwritten digit dataset—a problem you’ve likely seen or addressed with PyTorch or TensorFlow. This project will introduce you to three of the main JAX transformations: grad() for taking gradients, jit() for compilation, and vmap() for auto-vectorization. With just these three transformations, you can build custom neural network solutions that do not need to be distributed on a cluster (for distributed computations, there is a separate pmap() transformation).
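To give you a first taste of these three transformations before we apply them to the MNIST model, here is a minimal sketch. The toy loss function, the weight vector, and the batch shapes are made up for illustration only; they are not part of the chapter's project:

```python
import jax
import jax.numpy as jnp

# A toy scalar loss for a single example: squared error of a linear model.
# (This loss is a made-up example, not the MNIST model from this chapter.)
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

# grad() turns loss() into a function that returns the gradient
# with respect to its first argument (the parameters w).
grad_loss = jax.grad(loss)

# vmap() auto-vectorizes the per-example gradient over a batch:
# w is shared (None), while x and y are mapped over their leading axis (0).
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit() compiles the whole pipeline with XLA for fast repeated execution.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.ones((8, 3))   # a batch of 8 examples with 3 features each
ys = jnp.zeros(8)
print(fast_batched_grad(w, xs, ys).shape)  # (8, 3): one gradient per example
```

The same pattern, composing grad(), vmap(), and jit() around an ordinary Python function, is what we will use for the neural network later in the chapter, just with a bigger model and real data.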
The chapter provides a big-picture view, highlighting JAX features and essential concepts. I will explain the details of these concepts in later chapters.