This chapter covers
- Calculating derivatives in different ways
- Using automatic differentiation (autodiff) in JAX to calculate gradients of your functions (and neural networks)
- Using forward and reverse modes of autodiff
In chapter 2, we trained a simple neural network for handwritten digit classification. The crucial ingredient in training any neural network is the ability to calculate the derivative (more precisely, the gradient) of a loss function with respect to the network's weights, because gradient descent uses that gradient to decide how to update each weight.
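To make the idea concrete before we look at how it works, here is a minimal sketch of what "the gradient of a loss with respect to the weights" means in JAX code. It assumes a hypothetical toy linear model with a mean squared error loss (not the chapter 2 digit classifier) and uses `jax.grad` to differentiate the loss with respect to the parameters, followed by a single gradient descent step with an arbitrarily chosen learning rate.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (not the chapter 2 network): a linear layer y = x @ w + b
def predict(params, x):
    w, b = params
    return x @ w + b

# Mean squared error loss, written as a function of the parameters
def loss_fn(params, x, y):
    preds = predict(params, x)
    return jnp.mean((preds - y) ** 2)

# jax.grad returns a new function that computes d(loss)/d(params);
# by default it differentiates with respect to the first argument
grad_fn = jax.grad(loss_fn)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
params = (jnp.zeros(2), jnp.array(0.0))  # (w, b), both starting at zero

grads = grad_fn(params, x, y)  # same structure as params: (dL/dw, dL/db)

# One gradient descent step; the learning rate 0.01 is an assumption for illustration
learning_rate = 0.01
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, params, grads
)

print(grads)
print(loss_fn(params, x, y), loss_fn(new_params, x, y))
```

Here `grads` holds one gradient per parameter, with the same tuple structure as `params`, so the update can be applied leaf by leaf. The loss after the step is lower than before, which is exactly what training relies on.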