8 Advanced parallelization

 

This chapter covers

  • Using easy-to-revise parallelism with xmap()
  • Compiling and automatically partitioning functions with pjit()
  • Using tensor sharding to achieve parallelization with XLA
  • Running code in multi-host configurations

The goal of this chapter is to introduce you to more advanced topics in parallelization. Some of them, like xmap() and jax.Array, can be good substitutes for pmap() that make parallelization easier and more robust. Others, like pjit() and again jax.Array, help in training large neural networks. Multi-host environments are a separate topic, relevant when you have a really large GPU or TPU cluster that connects many separate computing nodes.
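As a rough illustration of what a sharded jax.Array looks like, the sketch below places an array on a one-dimensional device mesh and lets jit() compile a function for that sharded input. It assumes a JAX release that provides the jax.sharding module (0.4 or later). The axis name "data" and the function f are arbitrary choices for the example. On a single-device machine the mesh simply contains one device, so the code still runs.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all available devices; "data" is an arbitrary axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Split the first array dimension across the "data" mesh axis; replicate the second.
sharding = NamedSharding(mesh, P("data", None))

n = len(devices)
x = jnp.arange(8 * n, dtype=jnp.float32).reshape(4 * n, 2)  # rows divisible by n
x_sharded = jax.device_put(x, sharding)  # a jax.Array laid out across devices

@jax.jit
def f(x):
    # Elementwise op, so XLA can compute each shard independently.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

y = f(x_sharded)
print(y.sharding)            # output sharding chosen by the compiler
print(jnp.allclose(y, 1.0))  # sin^2 + cos^2 == 1 on every element
```

With pmap() you would have to reshape x to (n, 4, 2) yourself so that the leading axis matches the device count; here the layout travels with the array's sharding instead.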

We start by exploring the xmap() transformation, which makes parallelizing functions easier than with pmap(): it needs less code, replaces nested pmap() and vmap() calls, and requires no manual tensor reshaping. It also introduces the named-axis programming model, which helps you write more error-proof code. xmap() is still an experimental feature, and its API might change in the future. If pmap() is enough for you at the moment and you don't need these features, you can safely skip this part.
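The named-axis idea can be sketched without relying on the experimental xmap() API: vmap() also accepts an axis_name, and collectives such as jax.lax.psum() can then refer to a mapped axis by name instead of by position. The hypothetical row_normalize function below nests two vmap() calls, names the axes "rows" and "cols", and divides each element by the sum of its row. It only illustrates the programming model; it is not how xmap() itself is called.

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # x is a single element; "cols" names the inner mapped axis, so psum
    # adds up all elements of the current row without any positional index.
    return x / jax.lax.psum(x, axis_name="cols")

# Inner vmap maps over the elements of a row ("cols");
# outer vmap maps over the rows of the matrix ("rows").
row_normalize = jax.vmap(jax.vmap(normalize, axis_name="cols"), axis_name="rows")

x = jnp.array([[1.0, 3.0],
               [2.0, 2.0]])
print(row_normalize(x))  # each row now sums to 1
```

Changing the psum axis name from "cols" to "rows" would normalize over columns instead, a change that is easy to read and review because the code names the axis rather than counting dimensions.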

8.1 Using xmap() and the named-axis programming model

 
 
 

8.1.1 Working with named axes

 

8.1.2 Parallelism and hardware meshes

 
 
 
 

8.2 Using pjit() for tensor parallelism

 
 
 
 

8.2.1 TPU preparations

 

8.2.2 Basics of pjit()

 
 

8.2.3 MLP example with pjit()

 

8.3 Parallelizing computations with tensor sharding

 
 

8.3.1 Basics

 
 

8.3.2 MLP with tensor sharding

 
 

8.4 Using multi-host configurations

 
 

8.5 Summary

 
 