This chapter covers
- Just-in-time (JIT) compilation to produce performant code for CPU, GPU, or TPU
- JIT internals: intermediate representations and the Accelerated Linear Algebra (XLA) compiler
- JIT limitations
In chapter 1, we compared the performance of a simple JAX function on a CPU and a GPU, with and without JIT. In chapter 2, we used JIT to compile two functions in the training loop of a simple neural network. So you already have a basic sense of what JIT does: it compiles your function for a target hardware platform and makes it faster.
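To make that concrete, here is a minimal sketch of the pattern. The selu function and the timing harness are illustrative, not the exact code from chapters 1 and 2:

```python
import timeit
import jax
import jax.numpy as jnp

def selu(x, alpha=1.6732632, lmbda=1.0507010):
    """SELU activation: a small element-wise function worth compiling."""
    return lmbda * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

selu_jit = jax.jit(selu)  # the same function, wrapped for compilation

x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,))

# The first call traces and compiles the function; subsequent calls
# reuse the cached executable, which is where the speedup comes from.
selu_jit(x).block_until_ready()

print("eager: ", timeit.timeit(lambda: selu(x).block_until_ready(), number=100))
print("jitted:", timeit.timeit(lambda: selu_jit(x).block_until_ready(), number=100))
```

The block_until_ready() calls matter because JAX dispatches work asynchronously; without them, we would time only the dispatch, not the computation.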
In chapter 4, we began learning JAX transformations (remember, JAX is about composable function transformations!). That chapter covered autodiff and the grad() transformation. This chapter discusses compilation and the corresponding jit() transformation. In the following chapters, we will learn about transformations for auto-vectorization and parallelization.
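Because the transformations compose, jit() can wrap a function produced by grad(). Here is a small sketch with a hypothetical squared-error loss for a linear model (the loss function and shapes are made up for illustration):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Hypothetical squared-error loss for a linear model.
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

# grad() returns a new function that computes the gradient with respect
# to the first argument; jit() then compiles that function.
grad_loss = jax.grad(loss)            # chapter 4
fast_grad_loss = jax.jit(grad_loss)   # this chapter

w = jnp.zeros(3)
x = jnp.ones((5, 3))
y = jnp.ones(5)
print(fast_grad_loss(w, x, y))  # same values as grad_loss(w, x, y), compiled
```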
In the landscape of numerical computing, JAX emerged as a compelling framework built on Google's XLA (Accelerated Linear Algebra) compiler. XLA is not just another tool in the computational toolkit; it is specifically designed to produce efficient code for high-performance computing tasks. Think of it as an architect meticulously designing blueprints for structures optimized for their purpose.
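You can peek at the representations JAX and XLA work with. The following is a minimal sketch, assuming a recent JAX version in which jit-wrapped functions expose lower(); the function f is arbitrary:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

x = jnp.arange(4.0)

# The jaxpr is JAX's own intermediate representation of the traced function.
print(jax.make_jaxpr(f)(x))

# Lowering produces the program handed off to the XLA compiler.
print(jax.jit(f).lower(x).as_text())
```

We will return to both representations in detail when we look at JIT internals later in this chapter.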