Self-Attention

[!IMPORTANT] "Attention Is All You Need." This single concept, the core of the Transformer (Vaswani et al., 2017), revolutionized Natural Language Processing by letting a model relate any two positions in a sequence directly, regardless of how far apart they are.

1. The Problem with Recurrence

Before Transformers, Recurrent Neural Networks (RNNs) and LSTMs were the kings of sequence modeling. However, they had a fatal flaw: Sequential Processing.

To understand word t, the network had to process words 0 through t-1. This meant:

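  1. No Parallelization: The hidden state at step t depends on the hidden state at step t-1, so the steps of a sequence cannot be computed in parallel. Training was slow.
  2. Long-Term Memory Loss: Information from the beginning of a long sentence often faded away by the end. During training, gradients shrink as they are propagated back through many time steps (the “vanishing gradient” problem); LSTMs mitigate this but do not eliminate it.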
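To make both problems concrete, here is a minimal sketch of a toy scalar RNN, h_t = tanh(w * h_{t-1} + u * x_t), written in plain Python so it runs without PyTorch. The weights `w` and `u` and the helper names are hypothetical choices for illustration. The forward loop cannot skip ahead because each step needs the previous hidden state, and the chain-rule product for dh_T/dh_0 shrinks geometrically with sequence length whenever each factor w * (1 - h_t^2) is below 1 in magnitude.

```python
import math

def rnn_forward(xs, w=0.5, u=1.0, h0=0.0):
    """Toy scalar RNN, h_t = tanh(w * h_{t-1} + u * x_t); returns [h_0, ..., h_T].

    Each iteration needs the previous hidden state, so the loop over
    time steps cannot run in parallel.
    """
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def grad_last_wrt_first(xs, w=0.5, u=1.0, h0=0.0):
    """dh_T / dh_0 by the chain rule: the product of w * (1 - h_t**2) over all steps."""
    hs = rnn_forward(xs, w, u, h0)
    grad = 1.0
    for h in hs[1:]:
        # d tanh(w * h_prev + u * x) / d h_prev = w * (1 - tanh(...)**2) = w * (1 - h**2)
        grad *= w * (1.0 - h * h)
    return grad

for T in (5, 20, 50):
    print(f"T={T:2d}  dh_T/dh_0 = {grad_last_wrt_first([1.0] * T):.3e}")
```

With `w = 0.5`, every factor is below 0.5, so the gradient that reaches the first step is smaller than 0.5^T and becomes negligible long before T = 50.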

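Self-Attention addresses both problems by looking at all words at once: every position is connected directly to every other position, and all positions can be computed in parallel.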

2. The Intuition: Query, Key, and Value

The core idea of attention is retrieval. Imagine a database:

  • Query (Q): What you are looking for.
  • Key (K): The labels in the database.
  • Value (V): The actual content.

In Self-Attention, every word in the sentence generates its own Q, K, and V vectors by multiplying its embedding by three learned weight matrices.

  • To compute the new representation of a word, we compare its Query with the Keys of all words in the sentence (including itself).
  • The match scores (dot products) are normalized with a softmax into attention weights that sum to 1; these determine how much “attention” to pay to each word.
  • We use these weights to compute a weighted sum of the Values.
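The three steps above can be traced by hand for a single word. This sketch uses plain Python, a hypothetical three-word sentence with 2-dimensional embeddings, and hypothetical projection matrices `W_q`, `W_k`, `W_v` (in a trained model these are learned). To match the bullets, it uses the raw dot product and leaves out the 1/√d_k scaling introduced in Section 3.

```python
import math

def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical embeddings and projection matrices (learned in a real model).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_q = [[1.0, 0.0], [0.0, 1.0]]
W_k = [[1.0, 0.0], [0.0, 1.0]]
W_v = [[2.0, 0.0], [0.0, 2.0]]

# Every word produces its own query, key, and value vector.
Q = [matvec(W_q, xi) for xi in x]
K = [matvec(W_k, xi) for xi in x]
V = [matvec(W_v, xi) for xi in x]

i = 0  # compute the new representation of word 0
scores = [dot(Q[i], k) for k in K]   # compare Q_i with every key, including its own
weights = softmax(scores)            # normalize match scores into attention weights
context = [sum(w * v[d] for w, v in zip(weights, V)) for d in range(len(V[0]))]
print("weights:", [round(w, 3) for w in weights])
print("context:", [round(c, 3) for c in context])
```

Words 0 and 2 receive equal weight here because their keys match word 0's query equally well.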

[!TIP] Think of it as a “soft” hash map lookup. Instead of returning one value for a key match, it returns a blend of all values, weighted by how well their keys matched the query.
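A small sketch of the "soft hash map" view, in plain Python. `hard_lookup` is an ordinary dictionary access; `soft_lookup` (a hypothetical helper) scores every key against the query with a dot product, turns the scores into softmax weights, and returns the weighted blend of the values. Keys are 2-d tuples and values are scalars purely for readability.

```python
import math

def hard_lookup(table, query):
    """Ordinary hash-map lookup: an exact key match returns exactly one value."""
    return table[query]

def soft_lookup(keys, values, query):
    """Blend of all values, weighted by the softmax of key-query dot products."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return sum(w * v for w, v in zip(weights, values)), weights

keys = [(1.0, 0.0), (0.0, 1.0), (0.7, 0.7)]
values = [10.0, 20.0, 30.0]
table = dict(zip(keys, values))

print(hard_lookup(table, (1.0, 0.0)))  # 10.0
blend, weights = soft_lookup(keys, values, (1.0, 0.0))
print(round(blend, 2), [round(w, 3) for w in weights])
```

Scaling the query up sharpens the softmax: with the query `(10.0, 0.0)` almost all of the weight lands on the matching key, so the soft lookup approaches the hard one.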

3. The Math: Scaled Dot-Product Attention

The formula for Scaled Dot-Product Attention is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Where:

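  • $Q$: Matrix of queries (shape $N \times d_k$)
  • $K$: Matrix of keys (shape $N \times d_k$)
  • $V$: Matrix of values (shape $N \times d_v$)
  • $N$: Number of tokens in the sequence.
  • $d_k$: Dimension of the query and key vectors.
  • softmax: Applied to each row of $QK^\top / \sqrt{d_k}$, so every query's weights over the $N$ keys sum to 1.
  • $\sqrt{d_k}$: Scaling factor. For large $d_k$ the dot products grow in magnitude, pushing the softmax into regions where its gradients are extremely small; dividing by $\sqrt{d_k}$ keeps the scores in a range where the softmax still trains well.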
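A quick way to see why the √d_k factor matters: if the components of q and k are independent with mean 0 and variance 1, then q · k is a sum of d_k terms that each have variance 1, so its variance is d_k. The sketch below (plain Python, a Monte Carlo estimate with a fixed seed; the sample count is an arbitrary choice) estimates that variance and shows that dividing by √d_k brings it back to about 1.

```python
import random
import statistics

def dot_product_variance(d_k, n_samples=2000, seed=0):
    """Empirical variance of q.k for random q, k with i.i.d. N(0, 1) components."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    return statistics.pvariance(dots)

for d_k in (4, 64, 256):
    raw = dot_product_variance(d_k)
    scaled = raw / d_k  # dividing q.k by sqrt(d_k) divides its variance by d_k
    print(f"d_k={d_k:4d}  var(q.k)={raw:8.2f}  var(q.k / sqrt(d_k))={scaled:5.2f}")
```

Without the scaling, scores with a standard deviation of 16 (for d_k = 256) would make the softmax nearly one-hot, which is exactly the small-gradient regime the scaling avoids.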

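4. Visualizing Self-Attention

The attention weights form an N × N matrix: row i shows how much word i attends to each word in the sentence (including itself), and every row sums to 1 because of the softmax. Such matrices are often plotted as heatmaps to inspect which words a model relates to each other.

Note: Hand-made attention patterns are only illustrative. Real weights depend on projections learned during training.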
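To show an attention pattern without an interactive widget, this sketch prints the attention matrix as a text table. It is a simulated pattern: the token vectors are hypothetical, hand-picked so that "it" lies close to "cat", and they are used directly as both queries and keys (no learned projections), with the 1/√d_k scaling from Section 3 applied.

```python
import math

def softmax(row):
    peak = max(row)
    exps = [math.exp(s - peak) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(Q, K):
    """Row i = softmax over j of (Q_i . K_j) / sqrt(d_k)."""
    d_k = len(K[0])
    return [softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K])
            for q in Q]

def show(tokens, weights):
    """Print the attention matrix as a text table; each row sums to 1."""
    width = max(len(t) for t in tokens) + 2
    print(" " * width + "".join(t.rjust(width) for t in tokens))
    for tok, row in zip(tokens, weights):
        print(tok.rjust(width) + "".join(f"{w:.2f}".rjust(width) for w in row))

# Hypothetical 3-d token vectors, used directly as queries and keys.
tokens = ["the", "cat", "sat", "it"]
vectors = [
    [1.0, 0.0, 0.0],
    [0.0, 2.0, 0.0],
    [0.0, 0.0, 2.0],
    [0.0, 2.0, 0.5],
]
W = attention_weights(vectors, vectors)
show(tokens, W)
```

In the "it" row, "cat" receives far more weight than "the" or "sat" because its vector is the closest match.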

5. Implementation in PyTorch

Let’s build ScaledDotProductAttention from scratch.

[!NOTE] This module is the building block for Multi-Head Attention, which we’ll cover in the next chapter.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
  def __init__(self, d_k):
    super().__init__()
    self.scale = 1.0 / math.sqrt(d_k)

  def forward(self, q, k, v, mask=None):
    """
    Args:
      q: Query tensor (batch_size, n_heads, seq_len_q, d_k)
      k: Key tensor   (batch_size, n_heads, seq_len_k, d_k)
      v: Value tensor (batch_size, n_heads, seq_len_k, d_v)
          (keys and values always share the same sequence length)
      mask: Optional tensor broadcastable to
          (batch_size, n_heads, seq_len_q, seq_len_k); positions where
          mask == 0 are blocked. Typical shapes:
            (batch_size, 1, 1, seq_len_k) for a padding mask, or
            (seq_len_q, seq_len_k) lower-triangular for decoder causal
            masking (look-ahead mask).

    Returns:
      context: Context vectors (batch_size, n_heads, seq_len_q, d_v)
      attn_weights: Attention weights (batch_size, n_heads, seq_len_q, seq_len_k)
    """

    # 1. Matmul Q and K^T
    # shape: (batch_size, n_heads, seq_len_q, seq_len_k)
    attn_scores = torch.matmul(q, k.transpose(-2, -1))

    # 2. Scale
    attn_scores = attn_scores * self.scale

    # 3. Apply Mask (Optional)
    if mask is not None:
      # Fill blocked positions with the most negative value the dtype can
      # represent (effectively -infinity), so softmax gives them ~0 weight.
      # Unlike a hard-coded -1e9, this value is also representable in float16.
      attn_scores = attn_scores.masked_fill(
          mask == 0, torch.finfo(attn_scores.dtype).min
      )

    # 4. Softmax over the key dimension
    # shape: (batch_size, n_heads, seq_len_q, seq_len_k)
    attn_weights = F.softmax(attn_scores, dim=-1)

    # 5. Matmul with V
    # shape: (batch_size, n_heads, seq_len_q, d_v)
    context = torch.matmul(attn_weights, v)

    return context, attn_weights

# Example Usage
d_k = 64
seq_len = 10
batch_size = 2
n_heads = 1 # Simplified for single head example

# Create random tensors
q = torch.randn(batch_size, n_heads, seq_len, d_k)
k = torch.randn(batch_size, n_heads, seq_len, d_k)
v = torch.randn(batch_size, n_heads, seq_len, d_k) # d_v usually equals d_k

attention = ScaledDotProductAttention(d_k)
output, weights = attention(q, k, v)

print(f"Output Shape: {output.shape}")
# Output Shape: torch.Size([2, 1, 10, 64])
print(f"Attention Weights Shape: {weights.shape}")
# Attention Weights Shape: torch.Size([2, 1, 10, 10])
```
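The causal (look-ahead) mask mentioned in the docstring can be checked without PyTorch. This plain-Python sketch (helper names are hypothetical) builds a lower-triangular mask, fills blocked scores with a large negative number before the softmax, as the module's mask step does, and confirms that each query puts zero weight on later positions.

```python
import math

def causal_mask(n):
    """1 where query i may attend to key j (j <= i), 0 for future positions."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

def masked_softmax(scores, mask_row, fill=-1e9):
    """Plain-Python analogue of masked_fill(mask == 0, fill) followed by softmax on one row."""
    filled = [s if m else fill for s, m in zip(scores, mask_row)]
    peak = max(filled)
    exps = [math.exp(s - peak) for s in filled]  # exp of a huge negative number is 0.0
    total = sum(exps)
    return [e / total for e in exps]

n = 4
scores = [[0.5, 1.0, -0.3, 2.0] for _ in range(n)]  # same raw scores for every query
mask = causal_mask(n)
for i in range(n):
    print([round(w, 3) for w in masked_softmax(scores[i], mask[i])])
```

The first query can only see itself, so all of its weight goes to position 0; the last query sees every position.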

6. Computational Complexity

Why is Self-Attention sometimes slow?

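  • Complexity: $O(N^2 \cdot d)$, where $N$ is the sequence length and $d$ is the representation dimension.
    • For every token, we compute attention scores with every other token, and the $N \times N$ weight matrix must be stored.
    • If the sequence length $N$ doubles, computation quadruples.
  • Comparison with RNN: $O(N \cdot d^2)$
    • An RNN is linear in sequence length but quadratic in hidden size, and its $N$ steps must run one after another.
    • Self-Attention is cheaper per layer when $N < d$, which is common for sentence-length inputs, and its work parallelizes across positions. For very long documents the quadratic term dominates, which motivated sparse-attention models such as Longformer.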
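The scaling claims above can be checked with a back-of-the-envelope count. This sketch keeps only the leading-order terms, N² · d for self-attention and N · d² for a recurrent layer, with constant factors dropped; the width d = 512 is a hypothetical choice. The ratio between the two is simply N / d, so attention is the cheaper layer whenever the sequence is shorter than the representation dimension.

```python
def self_attention_ops(n, d):
    """Leading-order cost of a self-attention layer, O(n^2 * d), constants dropped."""
    return n * n * d

def rnn_ops(n, d):
    """Leading-order cost of a recurrent layer, O(n * d^2), constants dropped."""
    return n * d * d

d = 512  # hypothetical model width
for n in (128, 256, 512, 1024, 4096):
    a, r = self_attention_ops(n, d), rnn_ops(n, d)
    print(f"n={n:5d}  attention={a:>14,}  rnn={r:>14,}  attention/rnn={a / r:.2f}")
```

Doubling n quadruples the attention count but only doubles the RNN count, which is why long-document models look for ways to avoid the full N × N matrix.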

In the next chapter, we will see how Multi-Head Attention combines multiple independent attention mechanisms to capture different types of relationships.