5 Compiling your code
This chapter covers
- Using Just-in-Time (JIT) compilation to produce performant code for CPU, GPU, or TPU
- Looking at JIT internals: jaxpr, the JAX intermediate language, and HLO (High Level Operations), the intermediate representation of XLA, Google's Accelerated Linear Algebra compiler
- Dealing with JIT limitations
In the previous chapter we learned about autodiff and the grad() transformation. In this chapter we will learn about compilation and another very useful transformation, jit().
JAX uses Google’s XLA compiler to produce efficient code. XLA is the backend that powers machine learning frameworks, originally TensorFlow, on various devices, including CPUs, GPUs, and TPUs. JAX compiles its library calls under the hood, but you can also just-in-time (JIT) compile your own Python functions with the jit() function transformation. JIT compilation optimizes the computation graph: it can fuse a sequence of operations into a single efficient computation and eliminate redundant ones. It improves performance even on the CPU.
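As a minimal sketch of what this looks like in practice, the snippet below wraps a small function in jit() and checks that the compiled version computes the same values as the uncompiled one (the function scaled_tanh is an illustrative example, not from the text; it chains several element-wise operations that XLA can fuse into one kernel):

```python
import jax
import jax.numpy as jnp

# A hypothetical function chaining several element-wise ops;
# XLA can fuse them into a single computation instead of
# materializing an intermediate array for each step.
def scaled_tanh(x):
    return jnp.tanh(2.0 * x) + 1.0

# jit() traces and compiles the function on the first call for the
# given input shapes/dtypes, then reuses the compiled version.
scaled_tanh_jit = jax.jit(scaled_tanh)

x = jnp.arange(4.0)
eager = scaled_tanh(x)        # op-by-op execution
compiled = scaled_tanh_jit(x) # XLA-compiled execution, same values
```

The first call to the jitted function pays a one-time tracing and compilation cost; subsequent calls with the same input shapes and dtypes run the cached compiled code.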
In this chapter, we will cover the mechanics of JIT, learn to use it efficiently, and understand its limitations.