chapter seven

7 Parallelizing your computations

 

This chapter covers

  • Using parallel evaluation to parallelize your calculations with pmap()
  • Controlling pmap() behavior using its parameters
  • Implementing data parallel neural network training

In this chapter, we continue our exploration of JAX transformations. Here we start diving into parallelization: running your computations on multiple devices simultaneously.

There are several mechanisms for parallelizing your computations in JAX. pmap() is the most straightforward, and it is the subject of this chapter. (In the next chapter, we'll look at the other mechanisms: xmap(), pjit(), and tensor sharding, and we'll use pmap() in a more complicated multi-host configuration on a TPU Pod slice.)

The pmap() transformation, or parallel map, has an interface similar to that of vmap(), so it is a natural place to start. pmap() implements so-called single-program, multiple-data (SPMD) parallelism: you run the same program on multiple devices. The idea is to split your data into chunks and have each device process its own chunk simultaneously using the same code. This way, you can process more data at once simply by adding more devices (although the scaling is usually sub-linear because of communication overhead).
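To make the SPMD idea concrete, here is a minimal sketch that splits an array into one chunk per device and applies the same function to every chunk, first with vmap() and then with pmap(). The function scaled_sum() and the data are hypothetical, chosen only for illustration. So that it also runs on a CPU-only machine, the sketch assumes the XLA flag --xla_force_host_platform_device_count, which makes XLA expose several virtual CPU devices; the flag has to be set before JAX is imported.

```python
import os

# Assumption: on a CPU-only machine, ask XLA to expose 8 virtual CPU devices so
# that pmap() has several devices to map over. Must be set before importing JAX.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp


def scaled_sum(x):
    # Hypothetical "single program": the same code runs on every device,
    # each device working on its own chunk of the data.
    return jnp.sum(2.0 * x)


n_devices = jax.local_device_count()

# Hypothetical data whose size is a multiple of the device count.
data = jnp.arange(n_devices * 4, dtype=jnp.float32)

# Split the data into one chunk per device; the leading axis is the mapped axis.
chunks = data.reshape(n_devices, -1)

# vmap() vectorizes over the leading axis on a single device...
vmapped = jax.vmap(scaled_sum)(chunks)

# ...while pmap() has the same interface but sends each chunk to its own device.
pmapped = jax.pmap(scaled_sum)(chunks)

print("devices:", n_devices, "result shape:", pmapped.shape)
print("vmap and pmap agree:", bool(jnp.allclose(vmapped, pmapped)))
```

Because pmap() maps over the leading axis, the size of that axis cannot exceed the number of available devices. That is why the sketch reshapes the data to (n_devices, -1) instead of hard-coding the number of chunks.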

7.1 Parallelizing computations with pmap()

7.1.1 Setting up a problem

7.1.2 Using pmap() (almost) like vmap()

7.2 Controlling pmap() behavior

7.2.1 Controlling input and output mapping axes

7.2.2 Using named axes and collectives

7.3 Data parallel neural network training example

7.3.1 Preparing data and neural network structure

7.3.2 Implementing data parallel training

7.4 Summary