Appendix A. Installing JAX

This appendix covers

  • How to install JAX on your system, whether you target a CPU, GPU, or TPU

JAX is published as two separate Python packages:

  • jax, a pure Python package
  • jaxlib, a mostly C++ package that contains libraries such as XLA, pieces of LLVM used by XLA, MLIR infrastructure with MHLO Python bindings, and JAX-specific C++ libraries for fast JIT and PyTree manipulation.
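Because jax and jaxlib are separate distributions, it can help to check which versions of each are installed. The sketch below uses only the standard library's importlib.metadata, so it runs even when JAX is missing. The function name installed_jax_packages is a hypothetical helper, not part of JAX.

```python
from importlib import metadata


def installed_jax_packages():
    """Return the installed versions of jax and jaxlib.

    Hypothetical helper: a package that is not installed maps to None.
    """
    versions = {}
    for name in ("jax", "jaxlib"):
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_jax_packages().items():
        print(f"{name}: {version or 'not installed'}")
```

A missing jaxlib alongside an installed jax usually means the installation is incomplete.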

The JAX installation process differs depending on your target hardware: a CPU, a GPU, or a TPU.

A.1 Installing JAX on CPU

JAX is designed for high-performance computing and especially shines on TPUs and GPUs. However, thanks to the XLA compiler, you still get a performance boost on a CPU. A CPU installation is also convenient for local development.

The easiest way to install JAX on a CPU is to use pip, the package installer for Python.

pip install --upgrade pip
pip install --upgrade "jax[cpu]"
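One way to confirm that the installation works is to ask JAX which backend and devices it sees, and then to run a small jit-compiled function. The following is a minimal sketch. The function scaled_sum is an illustrative example, and on a CPU-only install the backend is expected to be "cpu".

```python
import jax
import jax.numpy as jnp

# Report the default backend and visible devices; a CPU-only install shows CPU devices.
print("Backend:", jax.default_backend())
print("Devices:", jax.devices())


@jax.jit  # traced and compiled with XLA on the first call
def scaled_sum(x, scale):
    # Multiply every element by `scale`, then sum the result.
    return jnp.sum(x * scale)


x = jnp.arange(4.0)  # [0., 1., 2., 3.]
result = scaled_sum(x, 2.0)
print(result)  # value: 12.0
```

If the import fails or no devices are listed, the installation did not complete correctly.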

At the time of writing, jaxlib supports the following CPU platforms and architectures:

  • Linux x86_64
  • macOS x86_64
  • macOS ARM
  • Windows x86_64, either natively or through WSL2 (Windows Subsystem for Linux)
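To check whether the current machine matches one of the platforms listed above, you can compare Python's platform.system() and platform.machine() values against that list. The sketch below is a hypothetical helper that encodes only the list in this appendix. The official installation documentation remains the authoritative source, and it may list more platforms.

```python
import platform

# Hypothetical helper: encodes ONLY the platform list given above.
LISTED_PLATFORMS = {
    ("Linux", "x86_64"),
    ("Darwin", "x86_64"),   # macOS on Intel
    ("Darwin", "arm64"),    # macOS on Apple silicon
    ("Windows", "x86_64"),  # WSL2 reports itself as Linux
}

# platform.machine() spells the same architecture differently across OSes.
_MACHINE_ALIASES = {"amd64": "x86_64", "x86_64": "x86_64",
                    "arm64": "arm64", "aarch64": "arm64"}


def is_listed_platform(system=None, machine=None):
    """Return True if (system, machine) appears in the list above."""
    system = system or platform.system()
    machine = (machine or platform.machine()).lower()
    machine = _MACHINE_ALIASES.get(machine, machine)
    return (system, machine) in LISTED_PLATFORMS


if __name__ == "__main__":
    print(platform.system(), platform.machine(), is_listed_platform())
```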

On Windows, you may also need to install the Microsoft Visual Studio 2019 Redistributable if it is not already installed on your machine. Please consult the official documentation for more details: https://jax.readthedocs.io/en/latest/installation.html#cpu.

A.2 Installing JAX on GPU

A.2.1 pip installation with CUDA

A.2.2 pip installation with self-installed CUDA/CuDNN

A.2.3 Using Docker containers

A.3 Installing JAX on TPU
