Appendix D. Experimental Parallelization
This appendix covers
- Using easy-to-revise parallelism with xmap()
- Compiling and automatically partitioning functions with pjit()
This appendix gathers two (and a half) experimental parallelization techniques: xmap() (plus the "half", shmap()) and pjit().
The xmap() transformation is an older technique on its way to deprecation. Still, it may interest those who need to understand legacy code or who want a better picture of how parallelization in JAX has evolved.
The xmap() transformation makes parallelizing functions easier than with pmap(): it needs less code, replaces nested pmap() and vmap() calls, and requires no manual tensor reshaping. It also introduces the named-axis programming model, which helps you write less error-prone code.
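To make the "nested pmap() and vmap() calls plus manual reshaping" concrete, here is a minimal sketch of that pattern using only pmap() and vmap(), which are stable JAX APIs (xmap() itself is not shown, since its availability depends on your JAX version). The per-example function row_sq_norm() and the batch sizes are hypothetical choices for illustration. On a single-CPU machine, jax.local_device_count() returns 1, so the pmap() axis simply has size 1.

```python
import jax
import jax.numpy as jnp

def row_sq_norm(v):
    # Hypothetical per-example computation: squared L2 norm of one feature vector.
    return jnp.sum(v ** 2)

n_devices = jax.local_device_count()
per_device_batch = 4
features = 3

# A flat batch of shape (n_devices * per_device_batch, features).
x = jnp.arange(n_devices * per_device_batch * features, dtype=jnp.float32)
x = x.reshape(n_devices * per_device_batch, features)

# Manual reshape: split the batch into (devices, per-device batch, features).
x_split = x.reshape(n_devices, per_device_batch, features)

# Nested transformations: pmap() maps over devices, vmap() over the per-device batch.
parallel_fn = jax.pmap(jax.vmap(row_sq_norm))
out = parallel_fn(x_split)  # shape: (n_devices, per_device_batch)

# Manual reshape back to a flat batch.
out = out.reshape(-1)  # shape: (n_devices * per_device_batch,)
print(out.shape)
```

Every line that deals with the device and batch dimensions here is bookkeeping that the named-axis model of xmap() aimed to remove: instead of reshaping tensors and nesting transformations by hand, you name the axes and describe how they map onto hardware.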
At some point, xmap() stopped being actively developed in favor of pjit() (the next topic of our story). Despite its deprecated status, it still makes sense to describe it here, as xmap() offers a very natural way of generalizing pmap() and vmap().
If you don’t need this feature and there is no legacy code you need to support, you can safely skip this part.
One possible replacement for xmap() comes from the JAX ecosystem: the Haliax library for building neural networks with named tensors. Another alternative comes from core JAX: shmap(), which currently has the status of a JAX Enhancement Proposal (JEP) and is a likely replacement for xmap().