chapter four

4 Autodiff

 

This chapter covers

  • Calculating derivatives in different ways
  • Calculating gradients of your functions with the grad() transformation
  • Using forward-mode and reverse-mode autodiff with the jvp() and vjp() transformations
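
To preview the last item in the list, here is a minimal sketch of jax.jvp() and jax.vjp(). It uses a small function f from R^2 to R^2 that we made up for illustration. Forward mode (jvp) computes a Jacobian-vector product J·v, and reverse mode (vjp) computes a vector-Jacobian product u·J. The sketch checks both against the full Jacobian built with jax.jacfwd().

```python
import jax
import jax.numpy as jnp

# A toy function from R^2 to R^2, chosen only for illustration.
def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([2.0, 3.0])

# Forward mode: jvp() pushes a tangent vector v through f,
# returning f(x) and the Jacobian-vector product J(x) @ v.
v = jnp.array([1.0, 0.0])
y, jv = jax.jvp(f, (x,), (v,))

# Reverse mode: vjp() returns f(x) and a function that pulls a
# cotangent vector u back through f, computing u @ J(x).
y2, f_vjp = jax.vjp(f, x)
u = jnp.array([1.0, 0.0])
(uj,) = f_vjp(u)  # one cotangent per input argument of f

# Compare both products with the full Jacobian.
J = jax.jacfwd(f)(x)
print(jv, J @ v)
print(uj, u @ J)
```

Each jvp() call yields one column-like slice of the Jacobian, and each vjp() call yields one row-like slice, which is why the two modes suit different input and output sizes.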

Chapter 3 showed us how to work with tensors, which is essential for almost any deep learning or scientific computing application. In Chapter 2, we also trained a simple neural network for handwritten digit classification. The crucial ingredient for training a neural network is the ability to calculate the derivatives of a loss function with respect to the neural network's weights.
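
As a minimal sketch of what this looks like in JAX, assume a hypothetical one-layer linear model with a mean squared error loss. The names loss_fn and params and the toy data are ours, not from Chapter 2. jax.grad() returns the derivatives of the loss with respect to the weights, and a single gradient descent step then uses them.

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny linear model: predictions = x @ w + b.
def loss_fn(params, x, y):
    w, b = params
    preds = x @ w + b
    return jnp.mean((preds - y) ** 2)  # mean squared error

# Toy data, made up for illustration.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
params = (jnp.zeros(2), jnp.array(0.0))

# grad() differentiates with respect to the first argument (params)
# and returns gradients with the same structure: (dL/dw, dL/db).
grads = jax.grad(loss_fn)(params, x, y)

# One step of plain gradient descent over every leaf of params.
learning_rate = 0.01
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, params, grads
)
print(grads)
```

Because the gradients mirror the structure of params, the same update line works unchanged for a model with many layers.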

There are several ways to obtain the derivatives of your functions (that is, to differentiate them), and automatic differentiation (autodiff for short) is the main one in modern deep learning frameworks. Autodiff is one of the pillars of the JAX framework. It lets you write ordinary Python and NumPy-like code while leaving the hard and tricky part, getting the derivatives, to the framework. In this chapter, we cover the essentials of autodiff and dive into how autodiff works in JAX.
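
The following minimal sketch contrasts three of these approaches on a toy function of our own. The names f, df_manual, and df_numerical are hypothetical. It compares a derivative worked out by hand, a central finite-difference approximation, and jax.grad() applied directly to ordinary Python code that even contains an if statement. The step size h=1e-3 is an arbitrary choice, and the numerical result is only approximate, especially in JAX's default float32 precision.

```python
import jax
import jax.numpy as jnp

# An ordinary Python function with NumPy-like code and a Python branch.
def f(x):
    if x > 0:
        return x ** 3 + jnp.sin(x)
    return x ** 2

# Manual differentiation: the derivative worked out by hand for x > 0.
def df_manual(x):
    return 3 * x ** 2 + jnp.cos(x)

# Numerical differentiation: a central finite-difference approximation.
def df_numerical(f, x, h=1e-3):
    return (f(x + h) - f(x - h)) / (2 * h)

# Automatic differentiation: jax.grad builds the derivative function for us.
df_auto = jax.grad(f)

x = 1.5
print(df_manual(x), df_numerical(f, x), df_auto(x))
```

The manual derivative only covers the x > 0 branch, whereas jax.grad() follows whichever branch the input actually takes, so df_auto(-2.0) returns the derivative of x ** 2 at that point.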

4.1 Different ways of getting derivatives

4.1.1 Manual differentiation

4.1.2 Symbolic differentiation

4.1.3 Numerical differentiation

4.1.4 Automatic differentiation

4.2 Calculating gradients with autodiff

4.2.1 Working with gradients in TensorFlow

4.2.2 Working with gradients in PyTorch

4.2.3 Working with gradients in JAX

4.2.4 Higher-order derivatives

4.2.5 Multivariable case

4.3 Forward and reverse mode autodiff

4.3.1 Evaluation trace

4.3.2 Forward mode and jvp()

4.3.3 Reverse mode and vjp()

4.3.4 Going deeper

4.4 Summary