2 Your first program in JAX
This chapter covers
- The high-level structure of a JAX deep learning project
- Loading a dataset
- Creating a simple neural network in JAX
- Using JAX transformations for auto-vectorization, gradient calculation, and just-in-time compilation
- Saving and loading a model
- Pure and impure functions
JAX is a library for composable transformations of Python and NumPy programs. Though it is not limited to deep learning research, it is often regarded as a deep learning framework, sometimes called the third major framework after PyTorch and TensorFlow. As a result, many people start learning JAX for deep learning applications.
In this chapter, we’ll do a deep learning “hello world” exercise: building a simple neural network application that demonstrates the JAX approach to deep learning. It is an image classification model for the MNIST handwritten digit dataset, a problem you’ve likely seen or solved with PyTorch or TensorFlow. This project will introduce you to three of the main JAX transformations: grad() for taking gradients, jit() for compilation, and vmap() for auto-vectorization. With just these three transformations, you can build custom neural network solutions that do not need to be distributed across a cluster (for distributed computations, there is a separate pmap() transformation).
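To give a first taste of these transformations before we start the MNIST project, here is a minimal sketch (using a toy function, not part of the chapter’s model) showing how grad(), jit(), and vmap() each wrap an ordinary Python function and how they compose:

```python
import jax
import jax.numpy as jnp

# A toy scalar function standing in for a loss: f(x) = x^2 + 2x.
def f(x):
    return x ** 2 + 2.0 * x

# grad() transforms f into a new function that computes df/dx = 2x + 2.
df = jax.grad(f)
print(df(3.0))                                  # 8.0

# jit() compiles f with XLA; repeated calls reuse the compiled version.
fast_f = jax.jit(f)
print(fast_f(3.0))                              # 15.0

# vmap() vectorizes a single-example function across a batch dimension.
batched_df = jax.vmap(df)
print(batched_df(jnp.array([0.0, 1.0, 2.0])))   # [2. 4. 6.]

# The transformations compose: a compiled, batched gradient function.
fast_batched_df = jax.jit(jax.vmap(jax.grad(f)))
print(fast_batched_df(jnp.array([0.0, 1.0, 2.0])))  # [2. 4. 6.]
```

The key idea is that each transformation takes a function and returns another function, which is why they compose so freely.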
This chapter provides a big-picture view, highlighting essential JAX features and concepts. We will explain the details of these concepts in later chapters.