Chapter 4

4 Training and verifying neural nets in raw CUDA


This chapter covers

  • The challenge of building a neural network from scratch and verifying that it reproduces a reference implementation.
  • Establishing a "gold standard" implementation in PyTorch.
  • Using NumPy and C to peel back layers of abstraction.
  • Building a naive CUDA MLP from our C blueprint.
  • Optimizing the CUDA implementation with the cuBLAS library.
  • Techniques for debugging and profiling low-level GPU code.

4.1 Level 1: The PyTorch Approach

In our journey from high-level frameworks to bare-metal CUDA, we need a reliable point of reference: a "gold standard" against which we can verify every subsequent implementation. For this, we turn to PyTorch. Its combination of a high-level API, automatic differentiation, and highly optimized, battle-tested backend libraries makes it the ideal candidate for our trusted implementation.

The goal in this section is not just to build a PyTorch model, but to build it in a way that is reproducible and comparable to our future low-level versions. This PyTorch implementation will be our benchmark for correctness (do we get the same loss?) and performance. The full code for this section is in the v1.py script.
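Reproducibility is what makes the loss comparison meaningful: if two runs of the same code disagree, comparing a PyTorch loss against a C or CUDA loss tells us nothing. As a minimal sketch (the function names `set_seed` and `losses_match` are illustrative, not taken from v1.py), we can pin down PyTorch's sources of randomness and define what "the same loss" means numerically:

```python
import torch

def set_seed(seed: int = 42) -> None:
    """Seed every RNG PyTorch uses so repeated runs produce identical results."""
    torch.manual_seed(seed)            # CPU (and default CUDA) generators
    torch.cuda.manual_seed_all(seed)   # all CUDA devices; no-op without a GPU
    # Prefer deterministic cuDNN kernels over auto-tuned (possibly nondeterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def losses_match(a: float, b: float, tol: float = 1e-5) -> bool:
    """Floating-point losses from different backends agree only up to a tolerance."""
    return abs(a - b) <= tol
```

An absolute tolerance is a deliberate simplification here; in practice a relative tolerance (as in `torch.allclose`) is more robust when loss magnitudes vary.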

Before we start our PyTorch implementation, let’s introduce the problem we’ll be solving. Our task is to classify handwritten digits using the famous MNIST dataset, often considered the "hello world" of computer vision, as shown in Figure 4.1.

4.1.1 PyTorch Implementation: Setup, Training, and Benchmarking

4.2 Level 2: Backpropagation Theory

4.2.1 Building Intuition: From Slopes to Gradients

4.2.2 Scaling to Vectors and Matrices

4.2.3 Layer-by-Layer Gradient Derivations

4.2.4 Common Pitfalls and Debugging Tips

4.2.5 Vectorized Implementation for Batches

4.3 Level 3: The NumPy Approach

4.3.1 Manual Implementation in NumPy

4.3.2 Verification Technique: Isolate and Conquer

4.3.3 Performance

4.4 Level 4: The C Approach

4.4.1 Setup: Headers, Constants, and Data Structures

4.4.2 Memory Management and Initialization

4.4.3 The Forward Pass in C

4.4.4 The Backward Pass in C