7 Parallelizing your computations
This chapter covers
- Using parallel evaluation to parallelize your calculations with pmap()
- Controlling pmap() behavior using its parameters
- Implementing data-parallel neural network training
In this chapter, we continue our exploration of JAX transformations. Here we start diving into parallelization: running your computations on multiple devices in parallel.
There are several mechanisms to parallelize your computations in JAX. pmap() is the most straightforward, and we'll discuss it in this chapter. (In the next chapter, we'll look at the other mechanisms: xmap(), pjit(), and tensor sharding, and we'll use pmap() in a more complicated multi-host configuration on a TPU Pod slice.)
The pmap() transformation, or parallel map, has an interface similar to that of vmap(), so it is natural to start with it. pmap() implements so-called single-program multiple-data (SPMD) parallelism: you run the same program on multiple devices. The idea is that you split your data into chunks, and each device processes its own chunk simultaneously using the same code. This way you can process more data in the same amount of time just by adding more devices (however, the scaling is usually sub-linear because of communication overhead).
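To make the SPMD idea concrete before we dig into the details, here is a minimal sketch (the toy function scaled_sum() and the data are invented for illustration, not taken from the book's later examples). We split an array along its leading axis into one chunk per device, and pmap() runs the same function on every chunk in parallel:

```python
import jax
import jax.numpy as jnp

# The same program that will run on every device.
def scaled_sum(x):
    return jnp.sum(x) * 2.0

# How many devices are available on this host
# (on a CPU-only machine this is typically 1).
n_devices = jax.local_device_count()

# Split the data into chunks: the leading axis of the input
# must match the number of devices pmap() maps over.
data = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)

# pmap() compiles scaled_sum() and runs each chunk
# on its own device simultaneously.
result = jax.pmap(scaled_sum)(data)
print(result)  # one result per device
```

Note that, unlike vmap(), which vectorizes over a batch axis on a single device, pmap() maps the leading axis over physical devices, so that axis cannot be larger than the number of devices available.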