4 Calculating gradients

This chapter covers

  • Calculating derivatives in different ways
  • Using automatic differentiation (autodiff) in JAX to calculate gradients of your functions (and neural networks)
  • Using forward and reverse modes of autodiff

In chapter 2, we trained a simple neural network for handwritten digit classification. The crucial ingredient in training any neural network is the ability to calculate the derivative (gradient) of the loss function with respect to the network's weights. Gradient descent uses this gradient to decide how to update each weight.
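To make this concrete, here is a minimal sketch of getting such a gradient with `jax.grad()`. It assumes a toy linear model with a mean squared error loss. The model, the data, and the names `loss_fn`, `params`, and `learning_rate` are hypothetical illustrations, not the chapter 2 digit classifier.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (not the chapter 2 network): predictions = x @ w + b
def loss_fn(params, x, y):
    w, b = params
    preds = x @ w + b
    return jnp.mean((preds - y) ** 2)  # mean squared error

# Hypothetical tiny dataset and initial weights
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
params = (jnp.array([0.5, -0.5]), jnp.array(0.0))

# jax.grad differentiates with respect to the first argument (params) by default
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params, x, y)  # same structure as params: (dL/dw, dL/db)

# One plain gradient descent step (learning_rate is an assumed value)
learning_rate = 0.01
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
```

Here `grads` mirrors the structure of `params`, so the update can be applied leaf by leaf with `jax.tree_util.tree_map`. The same pattern scales to the weights of a full neural network.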

4.1 Different ways of getting derivatives

4.1.1 Manual differentiation

4.1.2 Symbolic differentiation

4.1.3 Numerical differentiation

4.1.4 Automatic differentiation

4.2 Calculating gradients with autodiff

4.2.1 Working with gradients in TensorFlow

4.2.2 Working with gradients in PyTorch

4.2.3 Working with gradients in JAX

4.2.4 Higher-order derivatives

4.2.5 Multivariable case

4.3 Forward- and reverse-mode autodiff

4.3.1 Evaluation trace

4.3.2 Forward mode and jvp()

4.3.3 Reverse mode and vjp()

4.3.4 Going deeper

Summary