4 Calculating gradients
This chapter covers
- Calculating derivatives in different ways
- Using autodiff in JAX to calculate gradients of your functions (and neural networks)
- Using forward and reverse modes of automatic differentiation
Back in chapter 2, we trained a simple neural network for handwritten digit classification. The crucial ingredient for training any neural network is the ability to calculate the derivative (the gradient) of a loss function with respect to the network's weights.
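As a small preview of what this looks like in JAX, here is a minimal sketch that uses `jax.grad` to compute the gradient of a loss with respect to model parameters. It assumes a toy linear model with a mean squared error loss rather than the chapter 2 network; the names `loss_fn`, `w`, `b`, and the data are illustrative only. The autodiff result is checked against the gradient derived by hand.

```python
import jax
import jax.numpy as jnp

# Illustrative loss: mean squared error of a linear model y_hat = x @ w + b
def loss_fn(w, b, x, y):
    y_hat = x @ w + b
    return jnp.mean((y_hat - y) ** 2)

# Differentiate with respect to the first two arguments (the parameters w and b)
grad_fn = jax.grad(loss_fn, argnums=(0, 1))

# Toy data and parameter values (assumed for illustration)
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([0.1, -0.2])
b = 0.0

dw, db = grad_fn(w, b, x, y)

# Hand-derived gradients of the MSE loss for comparison:
# dL/dw = (2/N) * x^T (y_hat - y),  dL/db = (2/N) * sum(y_hat - y)
n = x.shape[0]
residual = x @ w + b - y
dw_manual = 2.0 / n * (x.T @ residual)
db_manual = 2.0 / n * jnp.sum(residual)

print(dw, dw_manual)
print(db, db_manual)
```

The gradient has the same shape as each parameter it is taken with respect to, which is exactly what a gradient descent update `w - learning_rate * dw` needs.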