
1 Intro to JAX

 

This chapter covers

  • What JAX is and how it compares to NumPy
  • Why you should use JAX
  • How JAX compares with TensorFlow and PyTorch

JAX is gaining popularity as more and more researchers adopt it for their work and organizations such as DeepMind contribute to its ecosystem.

In this chapter, we will introduce JAX and its powerful ecosystem. We will explain what JAX is and how it relates to NumPy, PyTorch, and TensorFlow, and we will walk through JAX's strengths to show how they combine into a powerful tool for deep learning research and high-performance computing.

1.1 What is JAX?

JAX is a Python mathematics library with a NumPy interface, developed by Google (the Google Brain team, to be specific). It is heavily used for machine learning research, but it is not limited to that; many other computational problems can be solved with JAX.
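To make the NumPy connection concrete, here is a minimal sketch of JAX's NumPy-like interface; the jax.numpy module is conventionally imported as jnp, and the array values here are just an illustrative example:

```python
# A minimal sketch: jax.numpy mirrors the familiar NumPy API.
import jax.numpy as jnp

x = jnp.arange(10, dtype=jnp.float32)    # an array placed on CPU, GPU, or TPU
y = jnp.sin(x) ** 2 + jnp.cos(x) ** 2    # the usual NumPy-style vectorized math

print(y)  # every element is (numerically) close to 1.0
```

If you know NumPy, this code should look familiar; the main visible difference is the import.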

JAX's creators describe it as Autograd and XLA brought together. Do not be afraid if you are unfamiliar with these names; that is normal, especially if you are just getting into the field.

Autograd (https://github.com/hips/autograd) is a library that efficiently computes derivatives of NumPy code and is a predecessor of JAX; its main developers now work on JAX. In short, Autograd means you can automatically calculate gradients for your computations, which is the essence of deep learning and of many other fields, including numerical optimization, physics simulations, and, more generally, differentiable programming. XLA (Accelerated Linear Algebra) is Google's compiler for linear algebra, which JAX uses to compile and run your NumPy-style code on CPUs, GPUs, and TPUs.
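As a small sketch of what this looks like in practice (the function f below is a toy example, not anything from the JAX documentation), jax.grad builds a new function that computes the derivative of another function, and jax.jit compiles a function with XLA:

```python
import jax

def f(x):
    # A toy scalar function: f(x) = x^2 + 3x, so f'(x) = 2x + 3.
    return x ** 2 + 3.0 * x

df = jax.grad(f)     # automatic differentiation: df computes f'(x)
print(df(2.0))       # 7.0, since 2*2 + 3 = 7

f_fast = jax.jit(f)  # XLA-compiled version of f; same result, faster on repeated calls
print(f_fast(2.0))   # 10.0
```

Note that neither grad nor jit changes how you write f itself; they are transformations applied to an ordinary Python function, a theme we return to throughout the chapter.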

1.1.1 JAX as NumPy

1.1.2 Composable transformations

1.2 Why use JAX?

1.2.1 Computational performance

1.2.2 Functional approach

1.2.3 JAX ecosystem

1.3 How is JAX different from TensorFlow/PyTorch?

1.4 Summary