5 Compiling your code


This chapter covers

  • Using Just-in-Time (JIT) compilation to produce performant code for CPU, GPU, or TPU
  • Looking at JIT internals: intermediate representations and the Accelerated Linear Algebra (XLA) compiler
  • Dealing with JIT limitations

In Chapter 1, we compared the performance of a simple JAX function on a CPU and a GPU, with and without JIT. In Chapter 2, we used JIT to compile two functions in the training loop of a simple neural network. So you already have a basic idea of what JIT does: it compiles your function for a target hardware platform and makes it faster.
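As a quick refresher before we dig into the details, here is a minimal sketch of that basic usage: jit() wraps a pure function, and JAX compiles it for the available backend on its first call. The selu function is just a toy workload chosen for illustration.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Scaled exponential linear unit; used here only as a toy workload.
    return lmbda * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

# The compiled version of the same function. Compilation happens lazily,
# on the first call with a given input shape and dtype.
selu_jit = jax.jit(selu)

x = jnp.arange(5.0)
# Both versions compute the same values; the jitted one runs faster
# on repeated calls because the compiled code is cached.
assert jnp.allclose(selu(x), selu_jit(x))
```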

In Chapter 4, we started learning JAX transformations (remember, JAX is all about composable function transformations!). That chapter covered autodiff and the grad() transformation. This chapter covers compilation and the corresponding jit() transformation. In the following chapters, we will learn about further transformations for auto-vectorization and parallelization.
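Composability is the key point here: jit() can wrap a function produced by grad(), giving you a compiled gradient function. A minimal sketch, assuming a toy quadratic loss:

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Toy quadratic loss; its gradient is 2 * w.
    return jnp.sum(w ** 2)

grad_loss = jax.grad(loss)           # from Chapter 4: autodiff
fast_grad_loss = jax.jit(grad_loss)  # this chapter: compile the gradient

w = jnp.array([1.0, 2.0, 3.0])
assert jnp.allclose(fast_grad_loss(w), 2 * w)
```

This stacking of transformations is exactly what the phrase "composable function transformations" means in practice.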

In the landscape of numerical computing, JAX emerged as a compelling framework built on the foundational strength of Google's XLA compiler. XLA is not just another tool in the computational toolkit; it is specifically designed to produce efficient code for high-performance computing tasks. Think of it as an architect meticulously designing blueprints for structures optimized for their purpose.

5.1 Using compilation

5.1.1 Using Just-in-Time (JIT) compilation

5.1.2 Pure functions and compilation process

5.2 JIT internals

5.2.1 Jaxpr, an intermediate representation for JAX programs

5.2.2 XLA

5.2.3 Using Ahead-of-Time (AOT) compilation

5.3 JIT limitations

5.4 Exercise

5.5 Summary