chapter two

2 Your first program in JAX

 

This chapter covers

  • The high-level structure of a JAX deep learning project
  • Loading a dataset
  • Creating a simple neural network in JAX
  • Using JAX transformations for auto-vectorization, gradient calculation, and just-in-time compilation
  • Saving and loading a model
  • Pure and impure functions

JAX is a library for composable transformations of Python and NumPy programs. Although it is not limited to deep learning research, it is often regarded as a deep learning framework, sometimes called the third major one after PyTorch and TensorFlow. For this reason, many people first learn JAX for deep learning applications.
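The phrase “composable transformations” can be made concrete with a minimal sketch. The function cube below is a hypothetical example, not part of the chapter’s project. jax.numpy mirrors much of the NumPy API, and a transformation such as jax.grad takes a Python function and returns a new Python function, so transformations can be stacked on top of each other:

```python
import numpy as np
import jax
import jax.numpy as jnp

def cube(x):
    # A plain numerical function (hypothetical example).
    return x ** 3

# jax.numpy follows the NumPy API for many operations.
points = np.linspace(0.0, 1.0, 5)
print(np.allclose(np.sin(points), jnp.sin(points)))  # True

# jax.grad turns cube into a new function computing 3 * x**2.
d_cube = jax.grad(cube)

# Transformations compose: the gradient of the gradient computes 6 * x,
# and jax.jit can wrap a transformed function to compile it.
d2_cube = jax.grad(jax.grad(cube))
fast_d_cube = jax.jit(d_cube)

print(d_cube(1.5))       # 6.75
print(d2_cube(1.5))      # 9.0
print(fast_d_cube(1.5))  # 6.75
```

Because each transformation returns an ordinary function, the result of grad() or jit() can itself be transformed again.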

In this chapter, we’ll work through a deep learning “hello world” exercise: a simple neural network that demonstrates the JAX approach to building a deep learning model. It is an image classification model for the MNIST handwritten digit dataset, a problem you have likely seen or solved with PyTorch or TensorFlow. The project introduces three of the main JAX transformations: grad() for taking gradients, jit() for just-in-time compilation, and vmap() for auto-vectorization. These three transformations alone are enough to build custom neural network solutions that do not need to be distributed across a cluster (for distributed computation, there is a separate pmap() transformation).
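To preview how grad(), jit(), and vmap() work together, here is a minimal sketch under stated assumptions: it uses a hypothetical one-layer linear model (predict), a mean-squared-error loss, an illustrative learning rate, and small synthetic data in place of MNIST, so it is not the chapter’s actual classifier.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value, not tuned

def predict(params, x):
    # Hypothetical one-layer linear model for a single, unbatched input x.
    w, b = params
    return jnp.dot(x, w) + b

# vmap: map predict over axis 0 of x while sharing params (in_axes=None).
batched_predict = jax.vmap(predict, in_axes=(None, 0))

def loss(params, xs, ys):
    # Mean squared error over the batch.
    preds = batched_predict(params, xs)
    return jnp.mean((preds - ys) ** 2)

@jax.jit  # jit: compile the whole update step with XLA
def update(params, xs, ys):
    # grad: gradient of loss with respect to its first argument, params.
    grads = jax.grad(loss)(params, xs, ys)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Synthetic data standing in for a real dataset.
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (32, 3))          # batch of 32 three-feature inputs
ys = xs @ jnp.array([1.0, -2.0, 0.5]) + 0.3   # noiseless linear targets

params = (jnp.zeros(3), jnp.array(0.0))       # (weights, bias)
for _ in range(500):
    params = update(params, xs, ys)

print(loss(params, xs, ys))  # should be close to 0.0
```

The division of labor is the point: predict is written for a single example and vmap() lifts it to a batch, grad() derives the parameter gradients, and jit() compiles the whole update step.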

This chapter gives a big-picture view of a JAX project, highlighting its main features and essential concepts. We will explain these concepts in detail in later chapters.

2.1 A toy ML problem: classifying handwritten digits

2.2 An overview of a JAX deep learning project

2.3 Loading and preparing the dataset

2.4 A simple neural network in JAX

2.4.1 Neural network initialization

2.4.2 Neural network forward pass

2.5 vmap: auto-vectorizing calculations to work with batches

2.6 Autodiff: how to calculate gradients without knowing about derivatives

2.7 JIT: compiling your code to make it faster

2.8 Saving and deploying the model

2.9 Pure functions and composable transformations: why are they important?

2.10 Exercises

2.11 Summary