8 Using tensor sharding

This chapter covers

  • Using tensor sharding to achieve parallelization with XLA
  • Implementing data and tensor parallelism for training neural networks

This chapter introduces you to an alternative, modern way of parallelizing computations in JAX: tensor sharding. The use case is the same as in the previous chapter: run parts of a computation in parallel so the whole computation finishes faster. Tensor sharding is especially useful for the different ways of parallelizing neural network training, be it data or model parallelism, and it can also be applied to inference with large models that do not fit into the memory of a single GPU. Areas other than deep learning benefit from it as well: if you work with large tensors in bioinformatics, cosmology, weather modeling, or elsewhere, tensor sharding provides an easy way to parallelize your computations.

Parallelization with pmap() from the previous chapter requires you to tell the compiler explicitly what you want to do, using per-device code and explicit communication collectives. Another school of thought lets the compiler partition a function over devices automatically, without you specifying too many low-level details. Tensor sharding (or distributed arrays) belongs to this second school. It has been available since JAX version 0.4.1, together with the new jax.Array type.
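To give a flavor of what this looks like before diving into the details, here is a minimal sketch of the idea: you describe a mesh of devices and a sharding for your data, and jit-compiled code runs on the sharded array without any per-device code. The example uses NamedSharding (positional sharding is covered later in the chapter); the axis name "data" and the function f are illustrative choices, and the commented-out XLA flag is one common way to emulate several devices on a CPU-only machine.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# To experiment without real accelerators, you can emulate devices, e.g.:
#   XLA_FLAGS="--xla_force_host_platform_device_count=8" python script.py

# Build a 1D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard a 2D array along its first axis across the "data" mesh axis.
x = jnp.arange(32.0).reshape(8, 4)
sharded_x = jax.device_put(x, NamedSharding(mesh, P("data", None)))

# Ordinary jit-compiled code works on the sharded input; XLA runs each
# shard on its own device and keeps the result sharded the same way.
@jax.jit
def f(a):
    return jnp.sin(a) * 2.0

y = f(sharded_x)
jax.debug.visualize_array_sharding(y)  # shows which device holds which shard

Note that there is no per-device function and no explicit collective here: the sharding attached to the input is enough for the compiler to distribute the work. The rest of the chapter builds this picture up step by step.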

8.1 Basics of tensor sharding

8.1.1 Device mesh

8.1.2 Positional sharding

8.1.3 An example with 2D mesh

8.1.4 Using replication

8.1.5 Sharding constraints

8.1.6 Named sharding

8.1.7 Device placement policy and errors

8.2 MLP with tensor sharding

8.2.1 Eight-way data parallelism

8.2.2 Four-way data parallelism, two-way tensor parallelism

Summary