This chapter covers
- Parallelizing your calculations across multiple devices with pmap()
- Controlling pmap() behavior using its parameters
- Implementing data-parallel neural network training
- Running code in multihost configurations
In this chapter, we continue our exploration of JAX transformations. Here, we dive into parallelization: running your computations on multiple devices simultaneously. This is especially relevant for large-scale neural network training, weather or ocean simulations, and any other task where at least some of the computations do not depend on each other and can be done in parallel. In such cases, you can finish the whole computation faster.
There are several mechanisms for parallelizing your computations in JAX. The most straightforward is pmap(), because you explicitly control how your computation is distributed; we discuss it in this chapter. In the next chapter, we look at tensor sharding, a newer and easier way of achieving parallelization implicitly by letting the compiler automatically partition your functions across devices. In appendix D, we additionally cover two experimental and partly outdated mechanisms, xmap() and pjit(), for readers who are either interested in the historical development of parallelization techniques in JAX or need to work with legacy code that uses them.
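To give a first flavor of explicit parallelization with pmap() before we dig into the details, here is a minimal sketch. The scale() function and the array shapes are purely illustrative; the code simply uses whatever local devices JAX can see, and it also works on a machine with a single CPU device.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# A simple element-wise function we want to run on every device in parallel.
def scale(x, w):
    return x * w

# Shard the inputs: the leading axis must match the number of devices,
# so each device receives one slice of the data.
xs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)
ws = jnp.full((n_devices, 4), 2.0)

# pmap() compiles scale() with XLA and maps it over the leading axis,
# running each slice on its own device.
ys = jax.pmap(scale)(xs, ws)
print(ys.shape)  # (n_devices, 4)
```

The result ys is a sharded array whose pieces live on the participating devices; we will look at this behavior, and at the parameters that control it, throughout the chapter.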