1 When and why to use JAX


This chapter covers

  • What is JAX?
  • When and where to use JAX
  • Comparing JAX with TensorFlow, PyTorch, and NumPy

One more deep learning library? Are you serious?! After everything has converged on the beloved-by-everyone PyTorch and the well-established ecosystem around TensorFlow, why should I bother with JAX? And if I wanted low-level neural network development, there's good old NumPy or one of its GPU-enabled alternatives. Why would I even want to look at JAX?

The history of deep learning frameworks shows that no framework lasts forever. For instance, where is Theano, which shaped the field significantly? For me, it was the first deep learning library in the modern sense, coming after the long-forgotten PyBrain2 and others. And where is Caffe? It was hugely popular years ago, especially for production deployments. Back then, I developed a driver-assistance tool that recognized road signs in real time on old Android smartphones with far less computing power than today's phones; Caffe was the best choice for that job. And what about good old Torch7, which many image style transfer models were built on around 2015? I participated in one such project, and we ran many of these models on our backends. For that matter, where are Chainer, TensorFlow 1, CNTK, Caffe2, and so many others?
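Before we get to the details, here is a small taste of what the rest of the chapter expands on. This is only a minimal sketch, assuming JAX is installed; it previews the two ideas behind sections 1.2.1 and 1.2.2: jax.numpy mirrors the familiar NumPy API, and composable transformations such as grad() and jit() add automatic differentiation and compilation on top of plain Python functions. The loss function and data here are made up for illustration.

import jax
import jax.numpy as jnp

# Written exactly like NumPy code, but against jax.numpy
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

grad_loss = jax.grad(loss)  # derivative of loss with respect to w
fast_loss = jax.jit(loss)   # the same function, JIT-compiled via XLA

w = jnp.ones(3)
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.array([1.0, 2.0])
print(fast_loss(w, x, y))   # scalar loss value
print(grad_loss(w, x, y))   # gradient array with the same shape as w

Note that neither transformation required rewriting loss(): that composability is what the following sections are about.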

1.1 Reasons to use JAX

1.1.1 Computational performance

1.1.2 Functional approach

1.1.3 JAX ecosystem

1.2 How is JAX different from NumPy?

1.2.1 JAX as NumPy

1.2.2 Composable transformations

1.3 How is JAX different from TensorFlow and PyTorch?

1.4 Summary
