This chapter covers
- An introduction to JAX
- When and where to use JAX
- A comparison of JAX with TensorFlow, PyTorch, and NumPy
One more deep learning library? Are you serious?! After everything has converged to the beloved-by-everyone PyTorch and the well-established ecosystem around TensorFlow, why should I bother with JAX? And if I wanted low-level neural network development, there’s good old NumPy or one of its GPU-enabled alternatives. Why would I even want to look at JAX?
The history of deep learning frameworks shows that no framework lasts forever. For instance, where is Theano, which shaped the field significantly? For me, it was the first deep learning library in the modern sense, after the long-forgotten PyBrain2 and others. And where is Caffe? It was hugely popular many years ago, especially for production deployments. Years ago, I developed a driver assistant tool to recognize road signs in real time, running on old Android smartphones with far less computing power than contemporary phones; Caffe was the best choice for that job back then. And what about good old Torch7, which many image style transfer models were built on around 2015? I participated in one such project at the time, and we ran many of these models on our backends. For that matter, where are Chainer, TensorFlow 1, CNTK, Caffe2, and so many others?