chapter two
2 Your first program in JAX
This chapter covers
- The MNIST handwritten digit classification problem
- Loading a dataset in JAX
- Creating a simple neural network in JAX
- Auto-vectorizing code with the vmap() function
- Calculating gradients with the grad() function
- Just-in-time compilation with the jit() function
- Pure and impure functions
- The high-level structure of a JAX deep learning project
In the previous chapter, we learned about JAX and its importance. We also described the JAX features that make it so powerful. This chapter will give you a practical understanding of JAX.
JAX is a library for composable transformations of Python and NumPy programs, and technically it is not limited to deep learning research. Still, JAX is widely regarded as a deep learning framework, sometimes called the third after PyTorch and TensorFlow, and many people start learning it for deep learning applications. That is why a simple neural network application that demonstrates the JAX approach to a problem is a valuable first project for many readers.
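To make "composable transformations" concrete, here is a minimal sketch that stacks the three transformations this chapter covers: jax.grad, jax.vmap, and jax.jit. The function f and the input values are arbitrary illustrations chosen for this sketch, not part of the chapter's MNIST project. Each transformation takes a Python function and returns a new function, which is why they can be nested freely.

```python
import jax
import jax.numpy as jnp

# Illustrative function (hypothetical, chosen only for this sketch),
# written with the NumPy-like jax.numpy API.
def f(x):
    return jnp.sin(x) ** 2 + 3.0 * x

# Each transformation returns a new function, so they compose.
df = jax.grad(f)                       # derivative of f for a scalar input
batched_df = jax.vmap(df)              # apply df to every element of a batch
fast_batched_df = jax.jit(batched_df)  # compile the batched derivative with XLA

xs = jnp.array([0.0, 1.0, 2.0])
result = fast_batched_df(xs)

# Analytically, f'(x) = 2*sin(x)*cos(x) + 3 = sin(2x) + 3,
# so the JAX result should match that formula element-wise.
expected = jnp.sin(2.0 * xs) + 3.0
print(result)
print(expected)
```

The order of composition matters: grad is applied to the scalar function first, because jax.grad expects a scalar output, and vmap then lifts that scalar derivative to a whole batch.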