Part 2. Core JAX
Dive deeper into the mechanics of JAX in part 2, where we explore the core features that make JAX a formidable tool for deep learning and scientific computing. Spanning eight chapters, this section provides a thorough examination of JAX’s capabilities, from working with arrays and calculating gradients to compiling, vectorizing, and parallelizing your code. Each chapter focuses on a fundamental aspect of JAX, illustrated with practical examples and in-depth discussions to solidify your understanding and skills.
Chapters 3 through 10 guide you through the intricacies of JAX, building toward mastery of its most powerful features. You’ll start by working with arrays (chapter 3), understanding the nuances that differentiate JAX from NumPy and learning how to use those differences to your advantage. As you progress, you’ll delve into calculating gradients (chapter 4), using JAX’s automatic differentiation to simplify and accelerate the training of neural networks. Chapters 5 and 6 introduce JAX’s just-in-time compilation and auto-vectorization, revealing strategies to significantly boost performance.
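As a small preview of what these chapters cover, the following sketch applies three of JAX’s core transformations—grad, jit, and vmap—to a toy function (the function and array here are illustrative examples, not drawn from the chapters themselves):

```python
import jax
import jax.numpy as jnp

def f(x):
    # A toy scalar-valued function: sum of squares
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)    # automatic differentiation (chapter 4)
fast_f = jax.jit(f)     # just-in-time compilation (chapter 5)
# auto-vectorization (chapter 6): per-element gradient of x**2
batched_grad = jax.vmap(jax.grad(lambda x: x ** 2))

x = jnp.arange(3.0)     # [0., 1., 2.]
print(grad_f(x))        # gradient of sum(x**2) is 2x -> [0., 2., 4.]
print(fast_f(x))        # compiled evaluation -> 5.0
print(batched_grad(x))  # [0., 2., 4.]
```

Each transformation takes a plain Python function and returns a new one, which is the composable style the following chapters explore in depth.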