5 Compiling your code

 

This chapter covers

  • Using Just-in-Time (JIT) compilation to produce performant code for CPU, GPU, or TPU
  • Looking at JIT internals: jaxpr, the JAX intermediate language, and HLO, the High Level Operations intermediate representation of XLA, Google’s Accelerated Linear Algebra compiler
  • Dealing with JIT limitations

In the previous chapter, we learned about autodiff and the grad() transformation. In this chapter, we will learn about compilation and another very useful transformation, jit().

JAX uses Google’s XLA compiler to produce efficient code. XLA is the compiler backend that powers machine learning frameworks, originally TensorFlow, on various devices, including CPUs, GPUs, and TPUs. JAX uses compilation under the hood for its library calls, and you can also just-in-time (JIT) compile your own Python functions with the jit() function transformation. JIT compilation optimizes the computation graph: it can fuse a sequence of operations into a single efficient kernel and eliminate redundant computations. This improves performance even on the CPU.
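As a first taste, here is a minimal sketch of applying jit() to a function. The function scaled_tanh is a hypothetical example chosen for illustration; the point is that its chain of element-wise operations can be fused by the compiler into a single computation.

```python
import jax
import jax.numpy as jnp

def scaled_tanh(x):
    # Three element-wise operations that JIT compilation can fuse
    # into a single compiled kernel.
    return 2.0 * jnp.tanh(x) + 1.0

# jit() returns a compiled version of the function.
scaled_tanh_jit = jax.jit(scaled_tanh)

x = jnp.arange(5, dtype=jnp.float32)
# The first call with a given input shape and dtype triggers tracing
# and compilation; subsequent calls reuse the compiled code.
print(scaled_tanh_jit(x))
```

Note that jit() can also be used as a decorator (@jax.jit) directly above the function definition.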

In this chapter, we will cover the mechanics of JIT, learn to use it efficiently, and understand its limitations.
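As a preview of the internals discussed in section 5.2, JAX lets you inspect the intermediate representation it produces for a function. The sketch below uses jax.make_jaxpr() on a hypothetical one-line function; the printed jaxpr is what JAX hands to the XLA compiler.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

# make_jaxpr traces the function with the given argument and returns
# its jaxpr, the intermediate representation of the computation.
print(jax.make_jaxpr(f)(1.0))
```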

5.1 Using compilation

 
 
 

5.1.1 Using Just-in-Time (JIT) compilation

 
 

5.1.2 Pure functions

 
 
 

5.2 JIT internals

 
 
 

5.2.1 Jaxpr, an intermediate representation for JAX programs

 
 
 
 

5.2.2 XLA

 
 

5.2.3 Using Ahead-of-Time (AOT) compilation

 
 
 
 

5.3 JIT limitations

 
 
 

5.4 Summary

 
 