Transformer Architecture

Important: The Transformer is not just a single attention layer. It is a deep stack of Encoder and Decoder blocks, designed so that gradients flow effectively through many layers.

1. The Big Picture: Encoder-Decoder

The original Transformer (Vaswani et al., 2017) was designed for machine translation. It consists of two stacks:

  1. Encoder: Processes the input sequence (e.g., an English sentence) into a contextualized representation.
  2. Decoder: Generates the output sequence (e.g., a French sentence) one token at a time, attending to the Encoder’s output.

[Figure: Architecture diagram. Encoder Stack (N×): Multi-Head Attn, Add & Norm, Feed Forward. Decoder Stack (N×): Masked Multi-Head Attn, Feed Forward.]
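PyTorch ships a reference encoder-decoder module, `torch.nn.Transformer`, that wires the two stacks together. The sketch below assumes PyTorch is installed and uses small illustrative sizes, with random tensors standing in for embedded source and target sentences; it only shows how the pieces connect and how the decoder is given a causal mask.

```python
import torch
import torch.nn as nn

# Small, illustrative sizes (the original paper used d_model=512, 8 heads, N=6 layers).
d_model, n_heads, n_layers = 64, 4, 2
model = nn.Transformer(
    d_model=d_model,
    nhead=n_heads,
    num_encoder_layers=n_layers,
    num_decoder_layers=n_layers,
    dim_feedforward=128,
    batch_first=True,  # tensors are (batch, seq_len, d_model)
)

# Random stand-ins for embedded, position-encoded sentences.
src = torch.randn(2, 10, d_model)  # source sentence (e.g., English) for the encoder
tgt = torch.randn(2, 7, d_model)   # target prefix (e.g., French) for the decoder

# Causal mask: decoder position t may not attend to positions after t.
tgt_mask = model.generate_square_subsequent_mask(tgt.size(1))

out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 7, 64]): one vector per target position
```

Note that `nn.Transformer` covers only the two stacks: token embeddings, positional encodings, and the final projection to vocabulary logits are left to the surrounding model.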

2. Multi-Head Attention

Why have one attention head when you can have many?

Multi-Head Attention (MHA) splits the model dimension d_model into h heads of size d_k = d_model / h (e.g., 512 → 8 × 64).

  • Parallel heads: Each head has its own Q, K, V projections, so it can learn to attend to a different relationship (e.g., one head might track neighboring tokens, another subject-verb agreement).
  • Concatenation: The head outputs are concatenated and projected back to d_model.

Formula:

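$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O, \qquad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$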
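To see that "one projection, then split" matches the per-head formula, the sketch below computes multi-head self-attention twice with the same random weights: once head by head, with head i using its own slice of columns from each projection matrix followed by Concat(...)W^O, and once with the reshape trick used in most implementations. The sizes and the `x @ W` weight convention are illustrative assumptions, not taken from a specific library.

```python
import math
import torch

torch.manual_seed(0)

# Illustrative sizes (the paper uses d_model=512, h=8).
batch, seq_len, d_model, h = 2, 5, 16, 4
d_k = d_model // h

x = torch.randn(batch, seq_len, d_model)
# Full projection matrices; head i uses columns i*d_k:(i+1)*d_k of W_q, W_k, W_v.
W_q, W_k, W_v, W_o = (torch.randn(d_model, d_model) / math.sqrt(d_model) for _ in range(4))


def attention(q, k, v):
    """softmax(q k^T / sqrt(d_k)) v, applied over the last two dimensions."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


# Formula-style: compute each head separately, then Concat(...) W^O.
heads = []
for i in range(h):
    cols = slice(i * d_k, (i + 1) * d_k)
    heads.append(attention(x @ W_q[:, cols], x @ W_k[:, cols], x @ W_v[:, cols]))
loop_out = torch.cat(heads, dim=-1) @ W_o


# Vectorized: project once, reshape to (batch, h, seq_len, d_k), attend, merge back.
def split(t):
    return t.view(batch, seq_len, h, d_k).transpose(1, 2)


ctx = attention(split(x @ W_q), split(x @ W_k), split(x @ W_v))
vec_out = ctx.transpose(1, 2).reshape(batch, seq_len, d_model) @ W_o

print(loop_out.shape)  # torch.Size([2, 5, 16])
print(torch.allclose(loop_out, vec_out, atol=1e-4))  # True
```

The vectorized form is preferred in practice because a single large matrix multiply is much faster than h small ones.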

3. Positional Encoding

Transformers process all tokens in parallel, and self-attention by itself has no notion of token order. Without Positional Encodings, the model would see “The dog bit the man” and “The man bit the dog” as the same “bag of words”.
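The "bag of words" claim can be checked directly: without positional information, self-attention is permutation-equivariant, so reordering the input tokens only reorders the outputs. The sketch below uses PyTorch's built-in `torch.nn.MultiheadAttention` with random stand-in embeddings (not real word vectors) and swaps the tokens at the "dog" and "man" positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One self-attention layer with no positional information (sizes are arbitrary).
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
attn.eval()

# Stand-in embeddings for "The dog bit the man" (5 tokens).
x = torch.randn(1, 5, 16)
perm = torch.tensor([0, 4, 2, 3, 1])  # swap "dog" (1) and "man" (4)
xp = x[:, perm]                       # "The man bit the dog"

with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(xp, xp, xp)

# Each token gets exactly the same output vector in both sentences;
# only the order of the rows changes, so word order is invisible.
print(torch.allclose(out_perm, out[:, perm], atol=1e-5))  # True
```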

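We add a fixed (not learned) vector to each input embedding to represent its position. For position pos and dimension-pair index i:

  • Even dimensions: $PE_{(pos,\,2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right)$
  • Odd dimensions: $PE_{(pos,\,2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right)$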
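A minimal PyTorch sketch of the sinusoidal table defined above, assuming an even d_model; the function name `sinusoidal_positional_encoding` is hypothetical. It computes 1 / 10000^(2i/d_model) in log space for numerical stability, then fills even columns with sine and odd columns with cosine.

```python
import math
import torch


def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Return a (max_len, d_model) table: sin on even dims, cos on odd dims."""
    assert d_model % 2 == 0, "this sketch assumes an even d_model"
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)       # 0, 2, 4, ...
    div_term = torch.exp(-math.log(10000.0) * two_i / d_model)     # 1 / 10000^(2i/d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div_term)  # PE(pos, 2i)
    pe[:, 1::2] = torch.cos(pos * div_term)  # PE(pos, 2i+1)
    return pe


pe = sinusoidal_positional_encoding(max_len=50, d_model=512)
embeddings = torch.randn(1, 50, 512)  # stand-in token embeddings
x = embeddings + pe.unsqueeze(0)      # add position info, broadcast over the batch
print(pe.shape)  # torch.Size([50, 512])
```

Because the table is fixed, it can be computed once and stored as a buffer rather than a trainable parameter.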

[Figure: Positional Encoding Explorer. X-axis: dimension index; Y-axis: position index; red = +1, blue = -1. The frequency decays from left to right, giving every position a unique fingerprint.]

4. Implementation: Multi-Head Attention

Here is how we assemble ScaledDotProductAttention into a Multi-Head layer.

```python
import torch.nn as nn

# ScaledDotProductAttention is assumed to be defined earlier: it takes
# (q, k, v, mask) and returns (context, attn_weights).


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.d_model = d_model

        # Linear projections for Q, K, V, and Output
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(self.d_k)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        # 1. Linear Projections and Split Heads
        # Reshape: (batch, seq_len, n_heads, d_k) -> Transpose: (batch, n_heads, seq_len, d_k)
        # (-1 lets k/v have a different seq_len from q, as in cross-attention)
        q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # 2. Apply Attention
        # context: (batch, n_heads, seq_len, d_k)
        # mask must broadcast to (batch, n_heads, seq_len_q, seq_len_k)
        context, attn_weights = self.attention(q, k, v, mask)

        # 3. Concatenate Heads
        # Transpose: (batch, seq_len, n_heads, d_k) -> Reshape: (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 4. Final Linear Projection
        output = self.w_o(context)

        return output, attn_weights
```
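The MultiHeadAttention class relies on a `ScaledDotProductAttention` module from earlier in the series. If you want to run it standalone, the sketch below is a minimal stand-in with the same call shape, `(q, k, v, mask) -> (context, weights)`; its mask convention (0 = blocked) is an assumption, not necessarily the one used earlier. It is sanity-checked against PyTorch's `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+) using a causal mask.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """Minimal stand-in: softmax(Q K^T / sqrt(d_k)) V.

    Assumed convention: `mask` broadcasts to the score shape, and
    positions where mask == 0 are blocked.
    """

    def __init__(self, d_k):
        super().__init__()
        self.scale = math.sqrt(d_k)

    def forward(self, q, k, v, mask=None):
        scores = q @ k.transpose(-2, -1) / self.scale  # (..., seq_q, seq_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        return weights @ v, weights


# Compare against PyTorch's built-in kernel with a causal mask.
torch.manual_seed(0)
q, k, v = (torch.randn(2, 8, 6, 64) for _ in range(3))  # (batch, heads, seq, d_k)
causal = torch.tril(torch.ones(6, 6))                    # 1 = allowed, 0 = future position
ctx, w = ScaledDotProductAttention(64)(q, k, v, causal)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(ctx, ref, atol=1e-4))  # True
```

With this stand-in in scope, `MultiHeadAttention(512, 8)(x, x, x, mask)` accepts a mask of shape (seq_len, seq_len) or (batch, 1, seq_len, seq_len), since either broadcasts over the head dimension.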

5. Add & Norm (Residual Connections)

Each sub-layer (Attention, Feed-Forward) is wrapped in a residual connection followed by Layer Normalization:

$$\text{Output} = \text{LayerNorm}(x + \text{Sublayer}(x))$$

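The residual path lets gradients flow directly to earlier layers, which mitigates vanishing gradients and makes it practical to train deep stacks of blocks. This is the original "Post-LN" arrangement; many later models, such as GPT-2, use "Pre-LN" instead, computing x + Sublayer(LayerNorm(x)).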
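The Post-LN rule Output = LayerNorm(x + Sublayer(x)) can be sketched as a small wrapper module. The class name `AddAndNorm` is hypothetical; following the original paper, dropout is applied to the sub-layer output before the residual addition. The example wraps a position-wise feed-forward network with the paper's inner size of 2048.

```python
import torch
import torch.nn as nn


class AddAndNorm(nn.Module):
    """Post-LN residual wrapper: LayerNorm(x + Dropout(Sublayer(x)))."""

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # The identity term "x +" gives gradients a path that bypasses the sublayer.
        return self.norm(x + self.dropout(sublayer(x)))


d_model = 512
ffn = nn.Sequential(nn.Linear(d_model, 2048), nn.ReLU(), nn.Linear(2048, d_model))
block = AddAndNorm(d_model).eval()  # eval() disables dropout for a deterministic check

x = torch.randn(2, 10, d_model)
y = block(x, ffn)
print(y.shape)  # torch.Size([2, 10, 512])
```

Each output token vector is re-centered and re-scaled by LayerNorm, so its mean is about 0 and its variance about 1 at initialization.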

6. Summary

  • Encoder: Builds contextual representations of the input (bidirectional context).
  • Decoder: Generates tokens one at a time (unidirectional / causal context).
  • Positional Encoding: Injects order information.
  • Multi-Head Attention: Captures diverse relationships in parallel.

In the next chapter, we will explore how these architectures are trained using objectives like Masked Language Modeling.