Appendix A. Installing JAX
This appendix covers
- How to install JAX on your system, whether you target a CPU, GPU, or TPU
JAX is published as two separate Python packages:
- jax, a pure Python package
- jaxlib, a mostly C++ package that contains libraries such as XLA, pieces of LLVM used by XLA, MLIR infrastructure with MHLO Python bindings, and JAX-specific C++ libraries for fast JIT compilation and PyTree manipulation.
The JAX installation process differs depending on your target hardware: CPU, GPU, or TPU.
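Because jax and jaxlib are installed as separate distributions, it can help to check which version of each is present. The sketch below uses only the standard library's importlib.metadata; the helper name installed_versions is hypothetical and chosen for this example.

```python
from importlib.metadata import version, PackageNotFoundError


def installed_versions(packages=("jax", "jaxlib")):
    """Hypothetical helper: map each package name to its installed version, or None."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver else 'not installed'}")
```

The jax and jaxlib versions must be compatible with each other, so a mismatch reported here is a common reason for import errors.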
A.1 Installing JAX on CPU
JAX is designed for high-performance computing and especially shines on TPUs and GPUs. However, thanks to the XLA compiler, you can still get a speedup on the CPU. A CPU installation is also convenient for local development.
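As an illustration of where XLA comes in, the following sketch compiles a small function with jax.jit. The SELU-style function and its constants are arbitrary choices for this example, not anything prescribed by JAX; on a CPU-only installation the compiled program runs on the CPU.

```python
import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lmbda=1.05):
    # Illustrative element-wise activation built from jax.numpy operations.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


# jax.jit traces selu and compiles it with XLA for the default backend.
selu_jit = jax.jit(selu)

x = jnp.arange(5.0)
print(selu_jit(x))
```

The first call includes tracing and compilation time; later calls with the same input shape and dtype reuse the compiled program.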
The easiest way to install JAX on the CPU is to use pip, the package installer for Python.
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
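Once the installation finishes, a quick sanity check is to import JAX, ask which backend and devices it sees, and run a tiny computation. This is a minimal sketch; with a jax[cpu] install the default backend should be reported as cpu.

```python
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # "cpu" for a jax[cpu] install
print("Devices:", jax.devices())

# A tiny matrix product to confirm that jaxlib works end to end.
x = jnp.ones((3, 3))
print(float((x @ x).sum()))  # 27.0
```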
The current release of jaxlib supports the following platforms and architectures:
- Linux x86_64
- Mac x86_64
- Mac ARM
- Windows x86_64, either native or under WSL2 (Windows Subsystem for Linux)
On Windows, you may also need to install the Microsoft Visual Studio 2019 Redistributable if it is not already installed on your machine. Please consult the official documentation for more details: https://jax.readthedocs.io/en/latest/installation.html#cpu.
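If you are unsure which entry in the platform list above matches your machine, Python's standard platform module reports the operating system and architecture. The SUPPORTED table and describe_platform helper below are hypothetical and simply mirror the list in this section; the official installation page remains the authoritative source. Note that WSL2 reports itself as Linux, so it matches the Linux x86_64 entry.

```python
import platform

# Hypothetical table mirroring this section's list: (system, machine) -> label.
SUPPORTED = {
    ("Linux", "x86_64"): "Linux x86_64",
    ("Darwin", "x86_64"): "Mac x86_64",
    ("Darwin", "arm64"): "Mac ARM",
    ("Windows", "AMD64"): "Windows x86_64",
}


def describe_platform(system=None, machine=None):
    """Return the matching list entry, or None if the platform is not listed."""
    system = system or platform.system()
    machine = machine or platform.machine()
    return SUPPORTED.get((system, machine))


if __name__ == "__main__":
    print(describe_platform() or "Platform not in the list above")
```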