
2 Your first program in JAX

 

This chapter covers

  • The MNIST handwritten digit classification problem
  • Loading a dataset in JAX
  • Creating a simple neural network in JAX
  • Auto-vectorizing code with the vmap() function
  • Calculating gradients with the grad() function
  • Just-in-Time compilation with the jit() function
  • Pure and impure functions
  • The high-level structure of a JAX deep learning project

In the previous chapter, we learned what JAX is and why it matters, and we described the features that make it so powerful. This chapter will give you a practical understanding of JAX.

JAX is a library for composable transformations of Python and NumPy programs, and it is not technically limited to deep learning research. Nevertheless, JAX is widely regarded as a deep learning framework, sometimes called the third most popular after PyTorch and TensorFlow, and many people come to it specifically for deep learning. A simple neural network application that demonstrates the JAX approach is therefore a valuable place to start.
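To give you an early taste of the style this chapter builds toward, here is a minimal sketch that composes the three transformations listed above on a toy function. The function loss and the arrays are illustrative placeholders, not part of the MNIST example developed later; the sketch assumes a standard JAX installation.

import jax
import jax.numpy as jnp

def loss(w, x):
    # A toy scalar function of the parameters w and a single input x
    return jnp.sum((w * x) ** 2)

grad_loss = jax.grad(loss)                             # differentiate with respect to w
fast_grad = jax.jit(grad_loss)                         # compile the gradient with XLA
batched_grad = jax.vmap(fast_grad, in_axes=(None, 0))  # map over a batch of inputs

w = jnp.array([1.0, 2.0])
xs = jnp.ones((4, 2))                                  # a batch of four inputs
print(batched_grad(w, xs).shape)                       # (4, 2): one gradient per example

Each transformation takes a function and returns a new function, which is why they compose so naturally. The rest of the chapter applies this same pattern to a real neural network.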

2.1 A toy ML problem: classifying handwritten digits

2.2 Loading and preparing the dataset

2.3 A simple neural network in JAX

2.3.1 Neural network initialization

2.3.2 Neural network forward pass

2.4 vmap: auto-vectorizing calculations to work with batches

2.5 Autodiff: how to calculate gradients without knowing about derivatives

2.6 JIT: compiling your code to make it faster

2.7 Pure functions and composable transformations: why are they important?

2.8 An overview of a JAX deep learning project

2.9 Exercises

2.10 Summary