chapter four

4 Calculating gradients

This chapter covers

  • Calculating derivatives in different ways
  • Using autodiff in JAX to calculate gradients of your functions (and neural networks)
  • Using forward and reverse modes of automatic differentiation

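Back in chapter 2, we trained a simple neural network for handwritten digit classification. The crucial ingredient for training any neural network is the ability to calculate the derivative (the gradient) of a loss function with respect to the network's weights. The gradient tells an optimizer such as gradient descent how to change each weight to reduce the loss.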
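
As a minimal sketch of what "the derivative of a loss with respect to the weights" looks like in JAX, the snippet below uses a hypothetical linear model with a mean squared error loss (not the chapter 2 network). The names `loss_fn`, `params`, and `learning_rate` are illustrative assumptions. `jax.grad` returns a new function that computes the gradient with respect to the first argument, and the gradient has the same structure as `params`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model for illustration: predictions = x @ w + b
def loss_fn(params, x, y):
    w, b = params
    preds = x @ w + b
    return jnp.mean((preds - y) ** 2)  # mean squared error

# jax.grad differentiates with respect to the first argument (params) by default
grad_fn = jax.grad(loss_fn)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
params = (jnp.zeros(2), 0.0)  # (weights, bias), both start at zero

grads = grad_fn(params, x, y)  # a tuple shaped like params: (dL/dw, dL/db)

# One plain gradient-descent step applied leaf by leaf over the params tuple
learning_rate = 0.1
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, params, grads
)
```

Because the gradient mirrors the structure of `params`, the same `tree_map` update works unchanged for larger models whose parameters are nested tuples or dicts.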

4.1 Different ways of getting derivatives

4.1.1 Manual differentiation

4.1.2 Symbolic differentiation

4.1.3 Numerical differentiation
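
A short sketch of numerical differentiation, assuming the common central-difference formula (f(x + h) - f(x - h)) / 2h and a hypothetical test function `f`. It compares the approximation with the exact derivative from `jax.grad`. The step size `h` is a trade-off: too large gives truncation error, too small amplifies floating-point round-off, which matters here because JAX computes in float32 by default.

```python
import jax
import jax.numpy as jnp

# Hypothetical test function: f(x) = x^2 * sin(x)
def f(x):
    return jnp.sin(x) * x ** 2

def numerical_derivative(func, x, h=1e-3):
    # Central difference: truncation error shrinks roughly like h^2
    return (func(x + h) - func(x - h)) / (2 * h)

x0 = 1.5
approx = numerical_derivative(f, x0)  # approximate, needs two evaluations of f
exact = jax.grad(f)(x0)               # exact up to float32 rounding
```

Printing `approx` and `exact` shows them agreeing only to a few decimal places, and each additional input variable would need another pair of function evaluations.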

4.1.4 Automatic differentiation

4.2 Calculating gradients with autodiff

4.2.1 Working with gradients in TensorFlow

4.2.2 Working with gradients in PyTorch

4.2.3 Working with gradients in JAX
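
A brief sketch of the core JAX gradient transformations on a hypothetical two-argument function `f`. The `argnums` parameter of `jax.grad` selects which positional arguments to differentiate, and `jax.value_and_grad` returns the function value together with the gradient, which is convenient when you also want to log the loss.

```python
import jax
import jax.numpy as jnp

# Hypothetical function of two scalar arguments
def f(x, y):
    return x ** 2 * y + jnp.sin(y)

# Partial derivative with respect to x (argnums=0 is the default): 2*x*y
df_dx = jax.grad(f)(3.0, 2.0)
# Partial derivative with respect to y: x**2 + cos(y)
df_dy = jax.grad(f, argnums=1)(3.0, 2.0)
# Both partial derivatives at once, returned as a tuple
dx, dy = jax.grad(f, argnums=(0, 1))(3.0, 2.0)
# Function value and gradient (with respect to x) in a single call
value, g = jax.value_and_grad(f)(3.0, 2.0)
```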

4.2.4 Higher-order derivatives
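
Since `jax.grad` takes a function and returns a function, higher-order derivatives can be obtained by composing it with itself. A minimal sketch, using a hypothetical polynomial `f` whose derivatives are easy to check by hand:

```python
import jax

# Hypothetical polynomial: f(x) = x^3 + 2x^2
def f(x):
    return x ** 3 + 2 * x ** 2

df = jax.grad(f)     # f'(x)   = 3x^2 + 4x
d2f = jax.grad(df)   # f''(x)  = 6x + 4
d3f = jax.grad(d2f)  # f'''(x) = 6

x = 2.0
first, second, third = df(x), d2f(x), d3f(x)
```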

4.2.5 Multivariable case
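
A sketch of the multivariable case with two hypothetical functions: `g` maps a vector to a scalar, so its derivative is a gradient vector, while `h` maps a vector to a vector, so its derivative is a Jacobian matrix. `jax.jacfwd` and `jax.jacrev` compute the same Jacobian using forward and reverse mode respectively, and `jax.hessian` gives second derivatives.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function of a vector: its gradient is a vector
def g(v):
    return jnp.sum(v ** 2)

# Hypothetical vector-valued function R^2 -> R^3: its derivative is a 3x2 Jacobian
def h(v):
    return jnp.stack([v[0] * v[1], jnp.sin(v[0]), v[1] ** 2])

v = jnp.array([1.0, 2.0])
grad_g = jax.grad(g)(v)    # 2 * v
J_fwd = jax.jacfwd(h)(v)   # rows = outputs, columns = inputs
J_rev = jax.jacrev(h)(v)   # same matrix, computed in reverse mode
H = jax.hessian(g)(v)      # matrix of second derivatives: 2 * identity
```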

4.3 Forward and reverse mode autodiff

4.3.1 Evaluation trace

4.3.2 Forward mode and jvp()
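
A minimal sketch of `jax.jvp`, assuming a hypothetical function `h` from R^2 to R^3. Given a point `v` and a tangent (direction) vector `u`, one forward pass returns both h(v) and the Jacobian-vector product J(v) @ u, without ever building J. Choosing `u` as a unit vector recovers one column of the Jacobian, so a full Jacobian in forward mode costs one pass per input.

```python
import jax
import jax.numpy as jnp

# Hypothetical function R^2 -> R^3
def h(v):
    return jnp.stack([v[0] * v[1], jnp.sin(v[0]), v[1] ** 2])

v = jnp.array([1.0, 2.0])   # point at which we differentiate (primal)
u = jnp.array([1.0, 0.0])   # direction to push forward (tangent)

# jax.jvp takes tuples of primals and tangents and returns (h(v), J(v) @ u)
y, jvp_out = jax.jvp(h, (v,), (u,))
# With u = e_0, jvp_out is the first column of the Jacobian
```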

4.3.3 Reverse mode and vjp()
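
A matching sketch of `jax.vjp` on the same hypothetical `h`. `jax.vjp` runs the function forward and returns its output plus a pullback function; calling the pullback with a cotangent vector `w` gives the vector-Jacobian product w @ J(v), which is one row of the Jacobian when `w` is a unit vector. For a scalar loss, a single reverse pass with cotangent 1 yields the whole gradient, which is why reverse mode suits neural network training. The scalar function `g` is also an illustrative assumption.

```python
import jax
import jax.numpy as jnp

# Hypothetical function R^2 -> R^3
def h(v):
    return jnp.stack([v[0] * v[1], jnp.sin(v[0]), v[1] ** 2])

v = jnp.array([1.0, 2.0])
y, vjp_fun = jax.vjp(h, v)        # forward pass, keeps what the backward pass needs
w = jnp.array([0.0, 1.0, 0.0])    # cotangent, same shape and dtype as h(v)
(vjp_out,) = vjp_fun(w)           # w @ J(v): here the second row of the Jacobian

# For a hypothetical scalar "loss", one pullback with cotangent 1 gives the gradient
def g(v):
    return jnp.sum(v ** 2)

loss, g_vjp = jax.vjp(g, v)
(grad_via_vjp,) = g_vjp(jnp.ones_like(loss))
```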

4.3.4 Going deeper

4.4 Summary