Backward in PyTorch
In PyTorch, the backward function is a crucial component of the autograd system, which is responsible for automatic differentiation. Calling it computes the gradient of a scalar value, typically a loss, with respect to the tensors that were used to produce it. These gradients are then used to update the model parameters during training.
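A minimal sketch of the idea: create a tensor that requires gradients, build a scalar from it, and call backward to populate its .grad attribute.

import torch

w = torch.tensor(2.0, requires_grad=True)
loss = (w * 3.0 - 1.0) ** 2   # a scalar "loss" built from w
loss.backward()               # compute d(loss)/dw
print(w.grad)                 # tensor(30.) since d(loss)/dw = 2 * (3w - 1) * 3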
How backward Works
When you call backward on a loss tensor, PyTorch computes the gradients of the loss with respect to all tensors with requires_grad=True that were used to compute the loss. This is done by traversing the autograd graph in reverse, starting from the loss node. The autograd graph is constructed during the forward pass, where each operation is recorded as a node in the graph.
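The following short sketch illustrates that filtering: only leaf tensors created with requires_grad=True end up with a populated .grad once backward has traversed the recorded graph.

import torch

x = torch.tensor([1.0, 2.0, 3.0])          # input, not tracked
w = torch.tensor(0.5, requires_grad=True)   # parameter, tracked
b = torch.tensor(0.1, requires_grad=True)   # parameter, tracked

y = w * x + b        # forward pass records the graph
loss = y.sum()
loss.backward()      # traverse the graph in reverse from loss

print(w.grad)   # tensor(6.) -> d(loss)/dw = sum(x)
print(b.grad)   # tensor(3.) -> d(loss)/db = number of elements
print(x.grad)   # None      -> x was not tracked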
Example of Using backward
Here is an example of a training loop that uses backward to compute gradients:
def training_loop(n_epochs, learning_rate, params, t_u, t_c):
    for epoch in range(1, n_epochs + 1):
        if params.grad is not None:   # zero gradients left over from the
            params.grad.zero_()       # previous iteration

        t_p = model(t_u, *params)     # forward pass
        loss = loss_fn(t_p, t_c)
        loss.backward()               # backward pass: populates params.grad

        with torch.no_grad():         # update params without tracking the operation
            params -= learning_rate * params.grad

        if epoch % 500 == 0:
            print('Epoch %d, Loss %f' % (epoch, float(loss)))

    return params
In this code, loss.backward() computes the gradients of the loss with respect to params. The gradients are stored in params.grad, which is then used to update the parameters.
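For completeness, here is one way the loop might be driven. The linear model, the mean-squared-error loss_fn, and the toy data below are illustrative assumptions, not taken from the text above.

import torch

def model(t_u, w, b):
    return w * t_u + b               # assumed simple linear model

def loss_fn(t_p, t_c):
    return ((t_p - t_c) ** 2).mean()  # assumed mean squared error

t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0])   # illustrative targets
t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3])  # illustrative inputs

params = training_loop(
    n_epochs=5000,
    learning_rate=1e-2,
    params=torch.tensor([1.0, 0.0], requires_grad=True),
    t_u=t_u * 0.1,    # crude normalization so the updates stay stable
    t_c=t_c,
)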
Accumulating Gradients
When backward is called, the gradients are accumulated into the .grad attribute of each leaf node. This means that if you call backward multiple times, the gradients will be summed. Therefore, it is important to zero the gradients at the start of each iteration, as shown in the example with params.grad.zero_().
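A small sketch of this accumulation behavior: calling backward twice without zeroing sums the two gradients into .grad.

import torch

w = torch.tensor(1.0, requires_grad=True)

loss = (w * 2.0) ** 2   # d(loss)/dw = 8w = 8
loss.backward()
print(w.grad)           # tensor(8.)

loss = (w * 2.0) ** 2   # fresh forward pass rebuilds the graph
loss.backward()
print(w.grad)           # tensor(16.) -> accumulated, not replaced

w.grad.zero_()          # what params.grad.zero_() does in the loop above
print(w.grad)           # tensor(0.)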
Visualizing Gradient Propagation
The process of gradient computation can be visualized using graphs. Figure 5.11 illustrates the forward and backward graphs of a model as computed with autograd. The forward graph is constructed during the forward pass, and the backward graph is traversed during the backward pass to compute gradients.
Figure 5.11 The forward graph and backward graph of the model as computed with autograd.
Additionally, Figure 5.16 shows how gradients propagate through a graph with two losses when backward is called on one of them. This demonstrates the selective computation of gradients in a complex graph.
Figure 5.16 Diagram showing how gradients propagate through a graph with two losses when .backward is called on one of them.
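The behavior Figure 5.16 describes can be sketched with two independent losses: calling backward on one of them populates gradients only for the tensors on that loss's path.

import torch

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)

loss1 = (a * 2.0) ** 2   # depends only on a
loss2 = (b * 3.0) ** 2   # depends only on b

loss1.backward()         # only the loss1 branch of the graph is traversed
print(a.grad)            # tensor(8.)
print(b.grad)            # None -> loss2 was never backpropagated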
Practical Considerations
When updating parameters, it is common to use a no_grad context so that PyTorch does not track the update operations on the parameters. The update is optimization bookkeeping, not part of the model's computation, so it should not be recorded in the autograd graph; note also that the forward graph is consumed when backward is called, and a fresh graph is built on the next forward pass. This is shown in the training loop example with the with torch.no_grad() block.
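A short sketch of why the context manager matters: the in-place parameter update works inside torch.no_grad, whereas autograd refuses the same in-place update of a leaf tensor that requires grad outside the context.

import torch

params = torch.tensor([1.0, 0.0], requires_grad=True)
loss = (params ** 2).sum()
loss.backward()

with torch.no_grad():
    params -= 0.01 * params.grad   # allowed: the update is not recorded

# Outside the no_grad block, the same in-place update raises a
# RuntimeError, because autograd does not allow in-place changes to a
# leaf tensor that requires grad.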
In summary, the backward function is a powerful tool in PyTorch for computing gradients, which are essential for training neural networks. Understanding how to use it effectively is key to implementing efficient and correct training loops.
FAQ (Frequently asked questions)
How do gradients propagate through a graph with two losses in PyTorch?
What happens when .backward is called on one of the losses in a graph with two losses?