Causal Attention
Causal attention, also known as masked attention, is a self-attention mechanism used in large language models (LLMs) that process and generate text sequentially from left to right. This mechanism ensures that a model only considers previous and current inputs in a sequence, effectively preventing future tokens from influencing the current token. This is crucial for tasks such as language modeling, where predictions should be based solely on past and present information.
Overview
In causal attention, future tokens are masked so that no information flows from later positions to earlier ones. In the attention matrix, row i holds the weights that token i assigns to every token in the sequence, so the entries above the diagonal (column j > i) correspond to future tokens. Zeroing out these weights ensures that only past and present tokens contribute to each context vector. Causal attention is essential in autoregressive models, which generate text one token at a time and therefore must not rely on tokens that have not been generated yet.
Implementation
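The implementation of causal attention involves creating a mask that identifies the future positions for each token. One option is PyTorch's torch.tril function, which builds a lower-triangular matrix marking the positions to keep. The attention weights at the remaining positions are zeroed out after the softmax, and each row is then renormalized so that the attention weights are distributed only among the unmasked positions. The class below instead uses torch.triu with diagonal=1 to mark the positions above the diagonal and fills the corresponding attention scores with negative infinity before the softmax, which makes the separate renormalization step unnecessary.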
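A minimal sketch of the tril-based approach, using random scores for an assumed four-token sequence in place of real query-key products: the softmax is applied first, future positions are zeroed with the lower-triangular mask, and each row is renormalized. The variable names are illustrative.

```python
import torch

torch.manual_seed(123)
num_tokens = 4
# Hypothetical unnormalized attention scores for a 4-token sequence
attn_scores = torch.randn(num_tokens, num_tokens)

# Step 1: ordinary softmax over each row (every token sees every token)
attn_weights = torch.softmax(attn_scores, dim=-1)

# Step 2: lower-triangular mask; row i keeps columns 0..i (1 = keep, 0 = future)
mask_keep = torch.tril(torch.ones(num_tokens, num_tokens))
masked = attn_weights * mask_keep

# Step 3: renormalize each row so the kept weights sum to 1 again
masked_norm = masked / masked.sum(dim=-1, keepdim=True)

print(mask_keep)
print(masked_norm)
```

The first row ends up with a single weight of 1.0 on the first token, since it has no predecessors to share attention with.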
Compact Causal Attention Class
Below is a compact implementation of a causal attention class in PyTorch:
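import torch
import torch.nn as nn


class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout)
        # Upper-triangular mask (1 above the diagonal = future token), stored
        # as a buffer so it moves with the module but is not trained
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        # x has shape (batch_size, num_tokens, d_in)
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Transpose only the last two dims to keep the batch dimension:
        # (b, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(1, 2)
        # Set the scores of future tokens to -inf in place; slicing to
        # num_tokens handles sequences shorter than context_length
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        # Scaled softmax: the -inf entries become exactly zero weight
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        # Context vectors: (b, num_tokens, d_out)
        context_vec = attn_weights @ values
        return context_vec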
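In this implementation, a dropout layer randomly zeroes out some of the attention weights during training to reduce overfitting; in evaluation mode it has no effect. The register_buffer call stores the mask as part of the module's state, so it is saved in the state_dict and moved to the same device as the model's parameters, without being treated as a trainable parameter. The mask is applied to the attention scores before the softmax is computed.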
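The following sketch isolates these two components in a hypothetical minimal module, MaskHolder, that contains only the dropout layer and the mask buffer from CausalAttention. It shows that the buffer is part of the state_dict but not of the trainable parameters, and that nn.Dropout zeroes entries and rescales the survivors by 1 / (1 - p) in training mode while acting as the identity in evaluation mode.

```python
import torch
import torch.nn as nn


class MaskHolder(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, context_length, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )


m = MaskHolder(context_length=4, dropout=0.5)

# The mask is saved with the module's state but is not a trainable parameter
print("mask" in m.state_dict())                    # True
print([name for name, _ in m.named_parameters()])  # []

# Dropout is active in training mode and a no-op in evaluation mode
torch.manual_seed(0)
weights = torch.ones(4, 4)
m.train()
dropped = m.dropout(weights)  # each entry is 0 or 1 / (1 - 0.5) = 2
m.eval()
kept = m.dropout(weights)     # identical to weights
```

Because the mask is a buffer, calling m.to(device) moves it together with any parameters, so the slicing in forward never mixes devices.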
Figures
Masking Attention Weights
Figure 3.19 In causal attention, we mask out the attention weights above the diagonal such that for a given input, the LLM can’t access future tokens when computing the context vectors using the attention weights. For example, for the word “journey” in the second row, we only keep the attention weights for the words before (“Your”) and in the current position (“journey”).
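The effect described in Figure 3.19 can be checked numerically. This sketch assumes a simplified single-head attention without trainable projections (queries, keys, and values are all the raw inputs) and random 3-dimensional embeddings for an assumed six-token sentence; the function name causal_attention is hypothetical. It confirms that the row for "journey" has nonzero weights only for "Your" and "journey", and that changing the last token leaves the context vectors of all earlier tokens unchanged.

```python
import torch


def causal_attention(q, k, v):
    """Hypothetical single-head causal attention; q, k, v: (num_tokens, d)."""
    n = q.shape[0]
    scores = q @ k.T / k.shape[-1] ** 0.5
    # True above the diagonal, i.e. at future positions
    future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights, weights @ v


torch.manual_seed(123)
tokens = ["Your", "journey", "starts", "with", "one", "step"]  # assumed sentence
x = torch.randn(len(tokens), 3)  # assumed random 3-d embeddings
weights, ctx = causal_attention(x, x, x)

# Row 1 ("journey") attends only to "Your" and "journey"
print(tokens[1], weights[1])

# Replacing the last token ("step") does not change earlier context vectors
x_changed = x.clone()
x_changed[-1] = torch.randn(3)
_, ctx_changed = causal_attention(x_changed, x_changed, x_changed)
print(torch.allclose(ctx[:-1], ctx_changed[:-1]))
```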
Obtaining Masked Attention Weights
Figure 3.20 One way to obtain the masked attention weight matrix in causal attention is to apply the softmax function to the attention scores, zero out the elements above the diagonal, and renormalize each row of the resulting matrix so that it sums to 1.
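Written out for a single row i, the renormalization in Figure 3.20 reduces to a softmax over the unmasked positions only. The notation is introduced here for illustration: ω_ij is the attention score of query i for key j, T is the number of tokens, a_ij is the ordinary softmax weight, and ã_ij is the masked, renormalized weight.

```latex
a_{ij} = \frac{\exp(\omega_{ij})}{\sum_{k=1}^{T} \exp(\omega_{ik})},
\qquad
\tilde{a}_{ij} =
\begin{cases}
\dfrac{a_{ij}}{\sum_{k=1}^{i} a_{ik}} & j \le i \\[6pt]
0 & j > i
\end{cases}

\text{For } j \le i:\quad
\tilde{a}_{ij}
= \frac{\exp(\omega_{ij}) \,/\, \sum_{m=1}^{T} \exp(\omega_{im})}
       {\sum_{k=1}^{i} \exp(\omega_{ik}) \,/\, \sum_{m=1}^{T} \exp(\omega_{im})}
= \frac{\exp(\omega_{ij})}{\sum_{k=1}^{i} \exp(\omega_{ik})}
```

The full-row denominator cancels, so the renormalized weights depend only on the scores of past and present tokens.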
Efficient Masking with Softmax
Figure 3.21 A more efficient way to obtain the masked attention weight matrix in causal attention is to mask the attention scores with negative infinity values before applying the softmax function.
In Figure 3.21, the negative infinity values make the softmax function assign exactly zero probability to the masked positions, because exp(-inf) = 0. The masking therefore happens before normalization, and the softmax denominator already sums over only the unmasked positions. This produces the same weights as the approach in Figure 3.20 while replacing the separate zeroing and renormalization steps with a single softmax call.
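The equivalence of the two masking strategies in Figures 3.20 and 3.21 can be verified directly. This sketch uses random scores for an assumed five-token sequence; the variable names are illustrative.

```python
import torch

torch.manual_seed(42)
n = 5
scores = torch.randn(n, n)  # hypothetical attention scores
# True above the diagonal, i.e. at future positions
future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)

# Approach 1 (Figure 3.20): softmax, zero the future positions, renormalize
w = torch.softmax(scores, dim=-1).masked_fill(future, 0.0)
w_renorm = w / w.sum(dim=-1, keepdim=True)

# Approach 2 (Figure 3.21): fill future scores with -inf, then one softmax
w_inf = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)

print(torch.allclose(w_renorm, w_inf))  # True
```

The second approach is the one CausalAttention uses, via masked_fill_ on the scores before torch.softmax.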