Causal Attention

Causal attention, also known as masked attention, is a self-attention mechanism used in large language models (LLMs) that process and generate text sequentially from left to right. This mechanism ensures that a model only considers previous and current inputs in a sequence, effectively preventing future tokens from influencing the current token. This is crucial for tasks such as language modeling, where predictions should be based solely on past and present information.

Overview

In causal attention, future tokens are masked to prevent information leakage. This is achieved by zeroing out the attention weights above the diagonal in the attention matrix, ensuring that only past and present tokens contribute to the computation of context vectors. The use of causal attention is essential in autoregressive models, where the model generates text one token at a time.
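For intuition, the causal mask for a four-token context looks like this (a minimal sketch, not part of the book's listings):

import torch

# 1s mark the positions a token may attend to (itself and earlier tokens);
# 0s mark the future positions that must be hidden.
mask = torch.tril(torch.ones(4, 4))
print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])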

Implementation

The implementation of causal attention involves creating a mask that hides the attention weights for future tokens. One straightforward approach uses PyTorch's tril function to build a lower-triangular mask, multiplies it with the softmax-normalized attention weights to zero out the positions above the diagonal, and then renormalizes each row so that the distribution is computed only over the unmasked positions (sketched below). A more efficient approach, used in the class that follows, builds an upper-triangular mask with triu and sets the masked attention scores to negative infinity before the softmax, which yields the same result in a single normalization step.
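A rough sketch of the simpler mask-and-renormalize approach (the variable names here are illustrative, not taken from the book's listings):

import torch

torch.manual_seed(123)
attn_scores = torch.randn(6, 6)                      # toy attention scores for 6 tokens
attn_weights = torch.softmax(attn_scores, dim=-1)    # ordinary softmax over all positions

mask_simple = torch.tril(torch.ones(6, 6))           # 1s on and below the diagonal
masked_weights = attn_weights * mask_simple          # zero out future positions
row_sums = masked_weights.sum(dim=-1, keepdim=True)
masked_weights = masked_weights / row_sums           # renormalize so each row sums to 1 again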

Compact Causal Attention Class

Below is a compact implementation of a causal attention class in PyTorch:

import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)    # dropout applied to the attention weights
        self.register_buffer(                 # non-trainable upper-triangular mask of future positions
            'mask',
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape          # batch size, sequence length, embedding dimension
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)   # dot products between queries and keys
        attn_scores.masked_fill_(                      # set future positions to -inf, in place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1  # scale by sqrt(d_k) before the softmax
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values            # weighted sum of the value vectors
        return context_vec

In this implementation, a dropout layer is added to reduce overfitting by randomly zeroing out some of the attention weights during training. The register_buffer call stores the mask alongside the model's parameters, so it is moved to the appropriate device (CPU or GPU) together with the model without being treated as a trainable parameter; the mask is applied to the attention scores before the softmax is computed.
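A quick, illustrative way to exercise the class (the batch shape and dimensions below are arbitrary, assuming the class and imports above):

torch.manual_seed(123)
batch = torch.randn(2, 6, 3)          # 2 sequences, 6 tokens each, 3-dimensional embeddings
ca = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)
context_vecs = ca(batch)
print(context_vecs.shape)             # torch.Size([2, 6, 2])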

Figures

Masking Attention Weights

[Figure 3.19](https://livebook.manning.com/build-a-large-language-model-from-scratch/chapter-3/figure--3-19) In causal attention, we mask out the attention weights above the diagonal such that for a given input, the LLM can’t access future tokens when computing the context vectors using the attention weights. For example, for the word “journey” in the second row, we only keep the attention weights for the words before (“Your”) and in the current position (“journey”).

Obtaining Masked Attention Weights

[Figure 3.20](https://livebook.manning.com/build-a-large-language-model-from-scratch/chapter-3/figure--3-20) One way to obtain the masked attention weight matrix in causal attention is to apply the softmax function to the attention scores, zeroing out the elements above the diagonal and normalizing the resulting matrix.

Efficient Masking with Softmax

[Figure 3.21](https://livebook.manning.com/build-a-large-language-model-from-scratch/chapter-3/figure--3-21) A more efficient way to obtain the masked attention weight matrix in causal attention is to mask the attention scores with negative infinity values before applying the softmax function.

In Figure 3.21, the negative infinity values ensure that the softmax assigns zero probability to the masked positions: because e^(-∞) = 0, these positions contribute nothing to the softmax denominator, so the remaining weights automatically sum to 1 without a separate renormalization step. This is what makes the approach computationally efficient.
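As a quick sanity check (a minimal sketch, not from the book's listings), the two approaches from figures 3.20 and 3.21 produce the same attention weights:

import torch

torch.manual_seed(123)
attn_scores = torch.randn(6, 6)
mask = torch.triu(torch.ones(6, 6), diagonal=1).bool()   # True above the diagonal (future positions)

# Efficient approach: set future scores to -inf, then apply softmax once.
masked_scores = attn_scores.masked_fill(mask, -torch.inf)
weights_inf = torch.softmax(masked_scores, dim=-1)

# Simple approach: softmax first, zero out the future, then renormalize each row.
weights = torch.softmax(attn_scores, dim=-1) * (~mask).float()
weights_renorm = weights / weights.sum(dim=-1, keepdim=True)

print(torch.allclose(weights_inf, weights_renorm))       # True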
