3 Classify images with a vision transformer (ViT)

 

This chapter covers

  • Dividing an image into patches of tokens
  • Building an encoder-only transformer to process sequences of image patches
  • Classifying CIFAR-10 images using a vision transformer (ViT)
  • Visualizing how a ViT pays attention to different parts of an image

In transformer-based text-to-image generation, a pivotal step is converting an image into a sequence of tokens, much as a sentence is converted into a sequence of word tokens in natural language processing. This is where vision transformers (ViTs) come in. ViTs, introduced by Google researchers in the landmark 2020 paper “An Image is Worth 16x16 Words,” brought the power of transformer architectures, originally designed for natural language, to computer vision.[1] Their key innovation is to treat an image as a sequence of patches, enabling transformers to excel at image classification and beyond.

This chapter guides you through the core ideas and practical implementation of ViTs. You’ll build a ViT from scratch and train it to classify images from the widely used CIFAR-10 dataset. The process begins by splitting each image into an 8×8 grid of patches, yielding 64 patches per image; because CIFAR-10 images are 32×32 pixels, each patch covers a 4×4-pixel region. Each patch is flattened and treated as a token in the transformer’s input sequence, as the sketch below illustrates. This clever adaptation brings the strengths of transformers, such as self-attention and long-range dependency modeling, to computer vision tasks.
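To make the patch-to-token idea concrete, here is a minimal sketch, assuming PyTorch. The function name image_to_patches is illustrative only; the chapter builds its own implementation step by step in section 3.3.1. The sketch splits one CIFAR-10-sized image, a tensor of shape (3, 32, 32), into an 8×8 grid of 4×4-pixel patches and flattens each patch into a token vector.

import torch

def image_to_patches(image: torch.Tensor, patch_size: int = 4) -> torch.Tensor:
    """Flatten a (C, H, W) image into a (num_patches, patch_dim) tensor."""
    c, h, w = image.shape
    # Carve the height and width dimensions into non-overlapping tiles
    patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # Shape is now (c, h // patch_size, w // patch_size, patch_size, patch_size);
    # group by grid position, then flatten each patch into one vector
    patches = patches.permute(1, 2, 0, 3, 4)
    return patches.reshape(-1, c * patch_size * patch_size)

image = torch.randn(3, 32, 32)     # stand-in for one CIFAR-10 image
tokens = image_to_patches(image)
print(tokens.shape)                # torch.Size([64, 48])

Each of the 64 tokens is a 48-dimensional vector (3 color channels × 16 pixels), ready to be projected into the transformer’s embedding space.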

3.1 The blueprint to train a vision transformer

3.1.1 How to convert images to sequences

3.1.2 How to train a vision transformer for classification

3.2 The CIFAR-10 dataset

3.2.1 Download and visualize CIFAR-10 images

3.2.2 Prepare datasets for training and testing

3.3 Build a vision transformer (ViT) from scratch

3.3.1 Divide images into patches

3.3.2 Model the positions of different patches in an image

3.3.3 The multi-head self-attention mechanism

3.3.4 Build an encoder-only transformer

3.3.5 Use the vision transformer to create a classifier

3.4 Train and use the vision transformer to classify images

3.4.1 Choose the optimizer and the loss function

3.4.2 Train the vision transformer for image classification

3.4.3 Classify images using the trained ViT

3.5 Summary