This chapter covers
- Working with NumPy arrays
- Working with JAX arrays on CPU/GPU/TPU
- Adapting code to differences between NumPy arrays and JAX arrays
- Using high-level and low-level interfaces: jax.numpy and jax.lax
In the previous chapter, we developed a simple neural network with JAX. In this chapter, we start diving deeper into the JAX core, beginning with arrays (or tensors; we use these two terms interchangeably).
The tensor, or multidimensional array, is the basic data structure in deep learning and scientific computing frameworks. Every program relies on some form of tensor, be it a 1D array, a 2D matrix, or a higher-dimensional array. The handwritten digit images from the previous chapter, the intermediate activations, and the resulting network predictions are all tensors. NumPy provides the numpy.ndarray type; JAX provides the jax.Array type (previously known as DeviceArray).
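The following minimal sketch shows the two array types side by side and a few tensors of different ranks. It assumes a recent JAX release (0.4.1 or later, where jax.Array replaced DeviceArray); the shapes are illustrative only and loosely mirror 28x28 digit images.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A NumPy array lives in host memory and has type numpy.ndarray.
x_np = np.zeros((28, 28))

# A JAX array has type jax.Array and is placed on the default device
# (CPU, GPU, or TPU, depending on what is available).
x_jax = jnp.zeros((28, 28))

print(type(x_np))                    # <class 'numpy.ndarray'>
print(isinstance(x_jax, jax.Array))  # True

# Tensors of different ranks: a scalar, a vector, a matrix,
# and a batch of 32 illustrative 28x28 "images".
tensors = [jnp.array(1.0), jnp.ones(10), jnp.ones((28, 28)), jnp.ones((32, 28, 28))]
for t in tensors:
    print(t.ndim, t.shape)
```

Note that a JAX array is not a subclass of numpy.ndarray, even though the two share most of their interface.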
NumPy arrays (the numpy.ndarray type) and their API have become a de facto industry standard that many other frameworks follow. JAX provides a (mostly) NumPy-compatible API, so the transition from NumPy to JAX should not be hard, and in many cases you don't need to change anything in your code except the import statement. However, there are some important differences, and we will highlight them throughout the chapter.
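As a minimal sketch of that compatibility, the hypothetical softmax_np and softmax_jnp functions below are identical except for the module prefix, and they produce matching results. The second half previews one difference mentioned above: JAX arrays are immutable, so NumPy-style in-place assignment raises a TypeError, and the functional .at[...].set(...) syntax returns an updated copy instead.

```python
import numpy as np
import jax.numpy as jnp

def softmax_np(x):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - np.max(x))
    return e / np.sum(e)

def softmax_jnp(x):
    # Identical code; only the module prefix changed from np to jnp.
    e = jnp.exp(x - jnp.max(x))
    return e / jnp.sum(e)

logits = np.array([1.0, 2.0, 3.0])
# jnp functions accept NumPy arrays as input; results agree within float tolerance.
print(np.allclose(softmax_np(logits), softmax_jnp(logits)))  # True

# A difference: JAX arrays are immutable, so in-place assignment fails.
a = jnp.arange(5)
try:
    a[0] = 100
    mutated = True
except TypeError:
    mutated = False

# The functional alternative leaves `a` unchanged and returns a new array.
b = a.at[0].set(100)
print(mutated, int(a[0]), int(b[0]))  # False 0 100
```

Keeping arrays immutable is what lets JAX treat functions as pure transformations of their inputs, which the rest of the library relies on.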