2 Your first program in JAX

This chapter covers

  • The high-level structure of a JAX deep learning project
  • Loading a dataset
  • Creating a simple neural network in JAX
  • Using JAX transformations for auto-vectorization, gradient calculation, and just-in-time compilation
  • Saving and loading a model
  • Pure and impure functions

JAX is a library for composable transformations of Python and NumPy programs. Though it is not limited to deep learning research, it is often regarded as a deep learning framework, sometimes called the third major one after PyTorch and TensorFlow, so many people first come to JAX through deep learning applications.

In this chapter, we’ll do a deep learning “Hello World” exercise: a simple neural network application that demonstrates the JAX approach to building a deep learning model. It is an image classification model for the MNIST handwritten digit dataset, a problem you have likely seen or solved with PyTorch or TensorFlow. The project introduces three of the main JAX transformations: grad() for taking gradients, jit() for just-in-time compilation, and vmap() for auto-vectorization. With just these three transformations, you can build custom neural network solutions that run on a single machine (for distributed computation across a cluster, there is a separate pmap() transformation).

This chapter provides a big-picture view, highlighting essential JAX features and concepts; I will explain the details in later chapters.

2.1 A toy ML problem: Classifying handwritten digits

2.2 An overview of a JAX deep learning project

2.3 Loading and preparing the dataset

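To keep the focus on JAX itself, the dataset can be loaded with an external library. Below is a minimal sketch, assuming tensorflow_datasets is available (any loader that yields NumPy arrays would do); it flattens each 28 × 28 image into a 784-element vector and rescales pixel values to [0, 1]:

import numpy as np
import tensorflow_datasets as tfds

# Load the full MNIST dataset into memory as NumPy arrays
# (batch_size=-1 returns each split as one big batch).
data, info = tfds.load(name='mnist', batch_size=-1, with_info=True)
data = tfds.as_numpy(data)
train_data, test_data = data['train'], data['test']

NUM_LABELS = info.features['label'].num_classes         # 10 digit classes
HEIGHT, WIDTH, CHANNELS = info.features['image'].shape  # 28 x 28 x 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS

# Flatten each image into a vector and rescale pixels to [0, 1]
train_images = np.reshape(train_data['image'], (-1, NUM_PIXELS)) / 255.0
train_labels = train_data['label']
test_images = np.reshape(test_data['image'], (-1, NUM_PIXELS)) / 255.0
test_labels = test_data['label']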

2.4 A simple neural network in JAX

2.4.1 Neural network initialization

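The model is a simple multilayer perceptron (MLP) whose parameters are a plain Python list of (weights, biases) pairs, one pair per layer, initialized with JAX’s explicit random keys. A minimal sketch; the layer sizes, scale, and helper names are illustrative:

from jax import random

LAYER_SIZES = [28 * 28, 512, 512, 10]  # input, two hidden layers, output
PARAM_SCALE = 0.01

def random_layer_params(m, n, key, scale=1e-2):
    """Random weights and biases for one dense layer mapping m -> n units."""
    w_key, b_key = random.split(key)
    return (scale * random.normal(w_key, (n, m)),
            scale * random.normal(b_key, (n,)))

def init_network_params(sizes, key, scale=1e-2):
    """Initialize a list of (weights, biases) pairs for all layers."""
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k, scale)
            for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

Unlike NumPy’s global random state, JAX requires you to pass a random key explicitly, which keeps initialization reproducible and pure.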

2.4.2 Neural network forward pass

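The forward pass is an ordinary Python function: it takes the parameters and one flattened image and returns the ten class logits. A minimal sketch, assuming the parameters from the previous section (the swish activation is an illustrative choice):

import jax.numpy as jnp
from jax.nn import swish

def predict(params, image):
    """Forward pass for a single flattened image (not a batch)."""
    activations = image
    for w, b in params[:-1]:
        activations = swish(jnp.dot(w, activations) + b)
    final_w, final_b = params[-1]
    # Return raw logits; the softmax is folded into the loss function
    return jnp.dot(final_w, activations) + final_b

Note that predict() is deliberately written for a single example; the next section shows how to get a batched version for free.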

2.5 vmap: Auto-vectorizing calculations to work with batches

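Instead of rewriting predict() to handle batches, vmap() derives a batched version automatically. A minimal sketch:

from jax import vmap

# in_axes=(None, 0) says: do not map over params, but map over the
# leading (batch) axis of the images argument.
batched_predict = vmap(predict, in_axes=(None, 0))

logits = batched_predict(params, train_images[:32])  # shape (32, 10)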

2.6 Autodiff: How to calculate gradients without knowing about derivatives

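Before applying it to the network, here is grad() on its own: it transforms a function computing a value into a function computing the gradient of that value. A tiny self-contained example:

from jax import grad

def f(x):
    return x ** 2

df = grad(f)    # df is a new function computing f'(x) = 2x
print(df(3.0))  # 6.0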

2.6.1 Loss function

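Training requires a scalar loss to differentiate. A minimal sketch using categorical cross-entropy over one-hot targets, assuming batched_predict() from above (the helper names are illustrative):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def one_hot(x, num_classes, dtype=jnp.float32):
    """Encode integer labels as one-hot vectors."""
    return jnp.array(x[:, None] == jnp.arange(num_classes), dtype)

def loss(params, images, targets):
    """Mean categorical cross-entropy over a batch."""
    logits = batched_predict(params, images)
    log_probs = logits - logsumexp(logits, axis=1, keepdims=True)  # log-softmax
    return -jnp.mean(jnp.sum(log_probs * targets, axis=1))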

2.6.2 Obtaining gradients

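Because loss() returns a scalar, grad() can differentiate it directly; the result has exactly the same structure as params. value_and_grad() additionally returns the loss value, avoiding a second forward pass. A sketch, assuming the helpers above:

from jax import grad, value_and_grad

batch_images = train_images[:32]
batch_labels = one_hot(train_labels[:32], NUM_LABELS)

# By default, grad() differentiates with respect to the first argument
grads = grad(loss)(params, batch_images, batch_labels)

# value_and_grad() returns both the loss value and the gradients
loss_value, grads = value_and_grad(loss)(params, batch_images, batch_labels)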

2.6.3 Gradient update step

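A gradient descent step subtracts a fraction of each gradient from the corresponding parameter. A minimal sketch of plain SGD (the learning rate is illustrative):

from jax import value_and_grad

LEARNING_RATE = 0.01

def update(params, x, y, lr=LEARNING_RATE):
    """One gradient descent step: w <- w - lr * dw for every parameter."""
    loss_value, grads = value_and_grad(loss)(params, x, y)
    new_params = [(w - lr * dw, b - lr * db)
                  for (w, b), (dw, db) in zip(params, grads)]
    return new_params, loss_value

The function returns new parameters rather than modifying the old ones in place, in keeping with JAX’s functional style.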

2.6.4 Training loop

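The training loop repeatedly draws minibatches and applies the update step. A minimal sketch, assuming the data arrays and helpers above (epoch count and batch size are illustrative):

import time
import numpy as np

NUM_EPOCHS = 10
BATCH_SIZE = 32

def get_batches():
    """Yield (images, one-hot labels) minibatches from the training set."""
    for i in range(0, len(train_images), BATCH_SIZE):
        yield (train_images[i:i + BATCH_SIZE],
               one_hot(train_labels[i:i + BATCH_SIZE], NUM_LABELS))

for epoch in range(NUM_EPOCHS):
    start = time.time()
    epoch_losses = []
    for x, y in get_batches():
        params, loss_value = update(params, x, y)
        epoch_losses.append(loss_value)
    print(f'Epoch {epoch}: {time.time() - start:.2f} sec, '
          f'mean loss {np.mean(epoch_losses):.4f}')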

2.7 JIT: Compiling your code to make it faster

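jit() compiles a function with XLA the first time it is called; later calls with the same input shapes reuse the compiled code and skip the Python interpreter overhead. Applying it to the update step is a one-liner:

from jax import jit

update = jit(update)

# The first call triggers tracing and compilation; subsequent calls
# with the same shapes run the already-compiled code.
params, loss_value = update(params, batch_images, batch_labels)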

2.8 Saving and deploying the model

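Since the trained model is nothing more than a pytree of arrays, a standard serializer is sufficient. A minimal sketch using pickle (an assumption, as is the file name; any format that round-trips arrays works):

import pickle
import numpy as np

with open('mnist_mlp.pickle', 'wb') as f:
    pickle.dump(params, f)       # params is just a list of (w, b) pairs

with open('mnist_mlp.pickle', 'rb') as f:
    restored_params = pickle.load(f)

# The restored parameters plug straight back into the forward pass
test_logits = batched_predict(restored_params, test_images)
accuracy = np.mean(np.argmax(test_logits, axis=1) == test_labels)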

2.9 Pure functions and composable transformations: Why are they important?

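JAX transformations assume functionally pure functions: the result must depend only on the inputs, with no side effects. An illustrative example of what goes wrong when that contract is broken:

from jax import jit

counter = 0

def impure_example(x):
    """Impure: reads and mutates global state."""
    global counter
    counter += 1
    return x * counter

print(impure_example(1.0), impure_example(1.0))  # 1.0 2.0 in plain Python

counter = 0
impure_jit = jit(impure_example)
print(impure_jit(1.0))  # 1.0: counter was read once, during tracing
print(impure_jit(1.0))  # still 1.0: the cached compiled code never
                        # re-reads counter, and the increment is lost

Because jit() traces the function once and caches the result, the hidden dependence on counter is silently frozen, which is why purity matters for every transformation used in this chapter.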