6 Vectorizing your code

This chapter covers

  • Approaches to vectorizing your code
  • Controlling vmap() behavior using its parameters
  • Analyzing typical cases where you might benefit from auto-vectorization

In chapter 3, you learned how to speed up your calculations by running them on GPUs and TPUs. Then, in chapter 5, you learned another way to speed up your code: compiling it with XLA. Now it’s time to learn two more ways to make computations faster: automatic vectorization and parallelization. This chapter is dedicated to auto-vectorization, while chapters 7 and 8 look at parallelizing your computations.

Auto-vectorization provides you with several benefits. First, it simplifies the programming process: you write a simple function that processes a single element, and it is automatically transformed into a more complex function that works on batches (or arrays) of elements. Second, it can speed up your computations if your hardware resources and program logic allow you to perform computations for many items simultaneously, which is typically much faster than processing the same array item by item. An auto-vectorized function won’t usually be faster than a manually vectorized one (though it won’t be significantly slower either). Where it wins is in a different dimension: developer productivity, because you no longer spend time vectorizing a function by hand.
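As a quick preview of what this chapter covers in detail, here is a minimal sketch of the idea: a function written for single elements is turned into a batched one with jax.vmap(). The dot() function and its inputs are illustrative, not an example taken from later sections.

```python
import jax
import jax.numpy as jnp

# A function written for single elements: the dot product of two vectors.
def dot(v1, v2):
    return jnp.dot(v1, v2)

# vmap() transforms it into a function that works on batches of vectors,
# mapping over the leading axis of both arguments by default.
batched_dot = jax.vmap(dot)

v1s = jnp.ones((10, 3))   # a batch of ten 3-dimensional vectors
v2s = jnp.ones((10, 3))

print(batched_dot(v1s, v2s).shape)  # (10,)
```

How to control which axes are mapped over, and how this compares to naive and manual vectorization, is the subject of the sections that follow.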

6.1 Different ways to vectorize a function

6.1.1 Naive approaches

6.1.2 Manual vectorization

6.1.3 Automatic vectorization

6.1.4 Speed comparisons

6.2 Controlling vmap() behavior

6.2.1 Controlling array axes to map over

6.2.2 Controlling output array axes

6.2.3 Using named arguments

6.2.4 Using decorator style

6.2.5 Using collective operations

6.3 Real-life use cases for vmap()

6.3.1 Batch data processing

6.3.2 Batching neural network models

6.3.3 Per-sample gradients

6.3.4 Vectorizing loops