chapter three

3 Working with tensors

 

This chapter covers

  • Working with NumPy arrays
  • Working with JAX tensors on CPU/GPU/TPU
  • Adapting code to differences between NumPy arrays and JAX DeviceArray
  • Using high-level and low-level interfaces: jax.numpy and jax.lax

In the previous two chapters, we explained what JAX is and why you might use it, and we built a simple neural network with it. In this chapter, we start diving deeper into the JAX core, beginning with tensors.

A tensor, or multi-dimensional array, is the basic data structure in deep learning and scientific computing frameworks. Every program in these fields relies on some form of tensor, be it a 1D array, a 2D matrix, or a higher-dimensional array. The handwritten digit images from the previous chapter, the intermediate activations, and the resulting network predictions are all tensors.
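To make the idea of tensor rank concrete, here is a minimal NumPy sketch of a 1D array, a 2D matrix, and a higher-dimensional array. The variable names, shapes, and the batch layout (batch, height, width, channels) are illustrative assumptions, not values taken from the previous chapter.

```python
import numpy as np

# A 1D array (vector), for example a short row of pixel intensities.
vector = np.array([0.0, 0.5, 1.0])

# A 2D matrix, for example a tiny 2x3 grayscale image.
matrix = np.zeros((2, 3))

# A 4D tensor: a hypothetical batch of 8 RGB images of 28x28 pixels,
# laid out as (batch, height, width, channels). This layout is an assumption.
batch = np.zeros((8, 28, 28, 3), dtype=np.float32)

# Every tensor exposes its number of dimensions, its shape, and its element type.
for name, tensor in [("vector", vector), ("matrix", matrix), ("batch", batch)]:
    print(name, tensor.ndim, tensor.shape, tensor.dtype)
```

The three attributes printed here (`ndim`, `shape`, and `dtype`) are usually the first things to inspect when a tensor operation does not behave as expected.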

NumPy arrays and their API have become the de facto industry standard, which many other frameworks follow. This chapter covers tensors and their operations in JAX and highlights the differences between the NumPy and JAX APIs.

3.1 Image processing with NumPy arrays

Let’s start with a real-life image processing task. Imagine you have a collection of photos you want to process. Some photos have extra space that should be cropped, others have noise artifacts you want to remove, and many are fine as they are, but you want to apply artistic effects to them. For simplicity, let’s focus only on denoising images, as shown in Figure 3.1:

Figure 3.1 Example of image processing we want to implement

3.1.1 Loading and storing images in NumPy arrays

3.1.2 Performing basic image processing with NumPy API

3.2 Tensors in JAX

3.2.1 Switching to JAX NumPy-like API

3.2.2 What is the DeviceArray?

3.2.3 Device-related operations

3.2.4 Asynchronous dispatch

3.2.5 Moving image processing to TPU

3.3 Differences with NumPy

3.3.1 Immutability

3.3.2 Types

3.4 High-level and low-level interfaces: jax.numpy and jax.lax

3.5 Exercises

3.6 Summary