Transformer Architecture
The Transformer is not just one attention layer. It is a deep stack of Encoder and Decoder blocks, designed so that gradients flow effectively through many layers.
1. The Big Picture: Encoder-Decoder
The original Transformer (Vaswani et al., 2017) was designed for machine translation. It dispenses with recurrence (as used in LSTMs) and convolutions entirely, relying on attention mechanisms to draw global dependencies between inputs and outputs. The architecture consists of two primary stacks:
- Encoder: Processes the input sequence (e.g., an English sentence) in parallel, transforming it into a rich, contextualized mathematical representation.
- Decoder: Generates the output sequence (e.g., a French sentence) autoregressively (one token at a time), continuously attending to the Encoder’s output and its own previously generated tokens.
Imagine a translation agency. The Encoder is a team of analysts who read the entire English document at once. They don't just read words; they draw lines connecting pronouns to their nouns and verbs to their subjects, creating a heavily annotated "context map". The Decoder is the translator who writes the French version word by word. Before writing the next word, the translator looks at their own previously written French words (Causal Attention) and consults the Encoder's context map (Cross-Attention) to decide what comes next.
[Figure: Transformer architecture diagram]
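PyTorch ships this encoder-decoder layout as `torch.nn.Transformer`. The sketch below is only a shape check under stated assumptions: the inputs are random stand-ins for already-embedded tokens (no tokenizer, embedding layer, or positional encoding), and the sizes are deliberately tiny rather than the paper's. It shows how the source sequence, the partially generated target sequence, and the causal mask fit together.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny sizes for illustration; the original paper used d_model=512, 8 heads,
# 6 encoder layers, 6 decoder layers and a 2048-wide feed-forward layer.
model = nn.Transformer(
    d_model=32,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=64,
    batch_first=True,  # tensors are (batch, seq_len, d_model)
)

src = torch.randn(2, 10, 32)  # stand-in for embedded source tokens ("English")
tgt = torch.randn(2, 7, 32)   # stand-in for embedded target tokens generated so far ("French")

# Causal mask: decoder position i may only attend to positions <= i.
# Entries above the diagonal are -inf, all other entries are 0.
tgt_mask = model.generate_square_subsequent_mask(tgt.size(1))

out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 7, 32])
```

The decoder returns one vector per target position; during generation it is run again after each new token, with the growing output fed back in as `tgt`.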
2. Multi-Head Attention
Before understanding Multi-Head Attention, we need the core operation it is built from: Query, Key, Value (Q, K, V) attention.
- Query (Q): What I am looking for (e.g., “I am a verb, looking for my subject”).
- Key (K): What I have (e.g., “I am a noun, acting as a subject”).
- Value (V): What I actually mean (the actual semantic content transferred if a match occurs).
Imagine searching a library database. Your search string is the Query. The database compares your Query against the titles and metadata of all books, which are the Keys; how well your Query matches each Key determines that book's attention score. Finally, the actual contents of the books you retrieve are the Values. One difference from a real library: attention does not return only the best match, but a blend of all the Values, weighted by their scores.
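The library lookup corresponds to the scaled dot-product formula from the original paper, $\text{softmax}(QK^\top / \sqrt{d_k})\,V$, where $d_k$ is the Key dimension. In this toy sketch the vectors are hand-picked for illustration (not learned), so you can watch the Query favour the Key it points toward while still receiving a blend of both Values.

```python
import math
import torch

# One Query, two Keys, two Values (d_k = 2); numbers chosen by hand for illustration.
Q = torch.tensor([[1.0, 0.0]])              # "what I am looking for"
K = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # "what each book's metadata says"
V = torch.tensor([[10.0], [20.0]])          # "what each book actually contains"

scores = Q @ K.T / math.sqrt(K.size(-1))  # similarity of the Query to each Key
weights = torch.softmax(scores, dim=-1)   # turn scores into a distribution that sums to 1
output = weights @ V                      # weighted blend of the Values

print(weights)  # ≈ [[0.67, 0.33]]: the first Key matches better
print(output)   # ≈ [[13.3]]: pulled toward the first Value, but not equal to it
```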
Why have one attention head when you can have many?
Multi-Head Attention (MHA) splits the model dimension $d_{model}$ into $h$ heads (e.g., 512 → 8 heads of 64 dimensions each). Each head receives its own slice of the projected Queries, Keys, and Values and runs attention on it independently.
- Parallelism: All heads run side by side, each in its own learned representation subspace. One head might focus on grammar (subject-verb agreement), while another tracks semantic relationships (e.g., “bank” to “river” vs “bank” to “money”).
- Concatenation: The outputs of all heads are concatenated back to $d_{model}$ dimensions and passed through a final linear projection that mixes information across heads.
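PyTorch's built-in `torch.nn.MultiheadAttention` implements this split, attend, concatenate, project pattern. The minimal sketch below uses the paper's sizes (512 dimensions, 8 heads) on random inputs and asks for per-head attention weights, assuming a PyTorch release recent enough to accept the `average_attn_weights` argument. With random, untrained weights the heads are not specialised yet; that only emerges through training.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads = 512, 8
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)

x = torch.randn(2, 10, d_model)  # (batch, seq_len, d_model): random stand-in embeddings

# Self-attention: the same tensor supplies Queries, Keys and Values.
# average_attn_weights=False keeps one attention map per head.
out, weights = mha(x, x, x, average_attn_weights=False)

print(d_model // n_heads)  # 64 dimensions per head
print(out.shape)           # torch.Size([2, 10, 512]): back to d_model after the output projection
print(weights.shape)       # torch.Size([2, 8, 10, 10]): one (seq_len x seq_len) map per head
```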
Hardware Reality: GPU Efficiency
Unlike Recurrent Neural Networks (RNNs), which process tokens sequentially (token $t$ waits for token $t-1$), Multi-Head Attention processes all tokens of a sequence simultaneously. The Q, K, and V projections and the attention scores themselves are large matrix multiplications, which modern GPUs are highly optimized to execute in parallel. This is a primary reason Transformers dominate large-scale training.
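To make the parallelism point concrete: projecting tokens into Queries involves no dependence between positions, so a per-token loop and a single matrix multiplication give the same result. The sketch below uses a random matrix as a stand-in for a learned projection. An RNN cannot be collapsed this way, because step $t$ needs the hidden state produced at step $t-1$.

```python
import torch

torch.manual_seed(0)
seq_len, d_model = 6, 16
x = torch.randn(seq_len, d_model)    # one sequence of token embeddings (random stand-ins)
w_q = torch.randn(d_model, d_model)  # stand-in for a learned Query projection

# Token-at-a-time: one small vector-matrix product per position
q_loop = torch.stack([x[t] @ w_q for t in range(seq_len)])

# All tokens at once: a single matrix multiplication, the form GPUs excel at
q_parallel = x @ w_q

print(torch.allclose(q_loop, q_parallel))  # True (same values, up to float rounding)
```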
3. The Feed-Forward Network (FFN)
Attention tells each token where to look; the Position-wise Feed-Forward Network (FFN) determines what to do with the information it gathers.
After the attention mechanism, each token’s representation is passed through the same FFN (shared weights), applied to every position independently. This is typically a two-layer multi-layer perceptron (MLP) with a ReLU (as in the original paper) or GELU activation in between.
- Dimensionality Expansion: The hidden layer of the FFN is typically 4x larger than the model dimension (e.g., 512 → 2048 → 512).
- Purpose: The attention layer aggregates information from other tokens. The FFN then processes this aggregated information locally, one token at a time, and is often described as the model’s “memory bank” for facts and learned transformations.
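A minimal position-wise FFN in PyTorch, using the sizes above ($512 \to 2048 \to 512$) and ReLU; the class name is our own, and dropout is omitted to keep the sketch small. The check at the end confirms the "position-wise" property: running a single token through the layer gives the same result (up to floating-point rounding) as running the whole sequence and reading off that token's row.

```python
import torch
import torch.nn as nn

class PositionwiseFFN(nn.Module):  # illustrative name, not a PyTorch class
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),  # expand: 512 -> 2048
            nn.ReLU(),                 # many later models use nn.GELU() here
            nn.Linear(d_ff, d_model),  # project back: 2048 -> 512
        )

    def forward(self, x):
        # nn.Linear acts on the last dimension only, so every position is
        # transformed independently with the same weights.
        return self.net(x)

torch.manual_seed(0)
ffn = PositionwiseFFN()
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
y = ffn(x)

one_token = ffn(x[:, 3:4])   # feed only position 3
print(y.shape)                                          # torch.Size([2, 10, 512])
print(torch.allclose(y[:, 3:4], one_token, atol=1e-5))  # True
```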
4. Positional Encoding
Transformers process all tokens simultaneously. Because there is no sequential processing, the model natively has no concept of order. Without Positional Encodings, the model would see “The dog bit the man” and “The man bit the dog” as identical “bags of words”.
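A quick way to see the problem: feed a self-attention layer a sequence, then the same tokens in a different order. Without positional information, the outputs are simply shuffled the same way, so no token's output depends on where it sits. This minimal sketch uses `torch.nn.MultiheadAttention` with random stand-in embeddings for the five words.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)

# Random stand-in embeddings for "The dog bit the man" (one batch, 5 tokens).
tokens = torch.randn(1, 5, 16)
perm = torch.tensor([0, 4, 2, 3, 1])  # swap "dog" and "man": "The man bit the dog"
shuffled = tokens[:, perm]

out, _ = attn(tokens, tokens, tokens)
out_shuffled, _ = attn(shuffled, shuffled, shuffled)

# Each word gets the same output vector in both sentences; only the rows move.
print(torch.allclose(out_shuffled, out[:, perm], atol=1e-5))  # True
```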
To fix this, we add a fixed vector to each input embedding to represent its precise position in the sequence. The original paper uses sine and cosine functions of different frequencies:
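- Even dimensions: $PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$
- Odd dimensions: $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$
Here $pos$ is the token's position and $i$ indexes the sine/cosine pair, so dimensions $2i$ and $2i+1$ share the same frequency.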
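The two formulas translate almost line for line into PyTorch. This is a minimal sketch, assuming an even $d_{model}$ and a precomputed table of shape `(max_len, d_model)`; the function name is our own. The frequency term is computed in log space, a common way to write $1 / 10000^{2i/d_{model}}$.

```python
import math
import torch

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Illustrative helper (name is our own): table of shape (max_len, d_model)."""
    assert d_model % 2 == 0, "d_model must be even"
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)        # 0, 2, 4, ...
    inv_freq = torch.exp(-math.log(10000.0) * two_i / d_model)     # 1 / 10000^(2i/d_model)

    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * inv_freq)  # even dimensions
    pe[:, 1::2] = torch.cos(pos * inv_freq)  # odd dimensions
    return pe

pe = sinusoidal_positional_encoding(max_len=50, d_model=16)
embeddings = torch.randn(2, 50, 16)  # stand-in token embeddings (batch, seq_len, d_model)
x = embeddings + pe                  # the table broadcasts over the batch dimension
print(pe.shape)  # torch.Size([50, 16])
print(pe[0])     # position 0: sin(0) = 0 in even dims, cos(0) = 1 in odd dims
```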
[Figure: Positional Encoding Explorer. Heatmap of encoding values; x-axis: dimension index, y-axis: position index; red = +1, blue = -1.]
The interleaved sine and cosine waves give every position a unique fingerprint, and the wave frequency decreases from the low dimensions (left) to the high dimensions (right). Nearby positions have similar patterns. The original paper also notes that, for any fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$, which should make it easy for the model to learn to attend by relative position.
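The relative-distance property can be checked numerically. Because $\sin a \sin b + \cos a \cos b = \cos(a - b)$, the dot product of two sinusoidal encodings depends only on the offset between the positions, not on where they sit in the sequence. The sketch below uses plain Python, an illustrative $d_{model}$ of 64, and helper names of our own.

```python
import math

def positional_encoding(pos, d_model=64):
    """Illustrative helper: sinusoidal encoding of one position as a list of floats."""
    vec = []
    for i in range(d_model // 2):
        angle = pos / 10000 ** (2 * i / d_model)
        vec += [math.sin(angle), math.cos(angle)]  # dimensions 2i and 2i+1
    return vec

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Same offset (3), different absolute positions -> same similarity
print(math.isclose(dot(positional_encoding(10), positional_encoding(13)),
                   dot(positional_encoding(40), positional_encoding(43))))  # True

# Neighbours are more similar than distant positions
print(dot(positional_encoding(10), positional_encoding(11))
      > dot(positional_encoding(10), positional_encoding(50)))  # True
```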
5. Add & Norm (Residual Connections & Layer Normalization)
Each sub-layer (Multi-Head Attention, Feed-Forward Network) is wrapped in a residual connection followed by Layer Normalization.
- Residual Connections ($x + …$): Provide a “shortcut” for gradients during backpropagation, bypassing complex layers. This mitigates the vanishing gradient problem, enabling the training of very deep networks (GPT-3, for example, stacks 96 layers).
- Layer Normalization: Stabilizes training by normalizing each token's vector across its features to zero mean and unit variance, then applying a learned scale and shift.
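A minimal sketch of one "Add & Norm" step as the original paper arranges it, $\text{LayerNorm}(x + \text{Sublayer}(x))$, using a single `nn.Linear` as a hypothetical stand-in for attention or the FFN. The printout shows what Layer Normalization does: each token's vector ends up with approximately zero mean and unit variance across its features (the learned scale and shift start at 1 and 0, so they do not change this yet).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
sublayer = nn.Linear(d_model, d_model)  # hypothetical stand-in for attention or the FFN
norm = nn.LayerNorm(d_model)            # learned scale starts at 1, shift at 0

x = torch.randn(2, 5, d_model) * 3 + 7  # (batch, seq_len, d_model), deliberately off-centre
y = norm(x + sublayer(x))               # residual add, then normalise

# Statistics are taken per token, over the feature dimension
mean = y.mean(dim=-1)
var = ((y - y.mean(dim=-1, keepdim=True)) ** 2).mean(dim=-1)
print(mean)  # every entry ≈ 0
print(var)   # every entry ≈ 1
```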
Senior Depth: Pre-LN vs. Post-LN
The original Transformer used Post-Layer Normalization (Post-LN), where normalization is applied after the residual addition. While effective, it is prone to instability early in training and depends on a carefully tuned learning-rate warmup phase.
Modern models (like GPT-2 and GPT-3) use Pre-Layer Normalization (Pre-LN), where normalization is applied to the input of each sub-layer, before the residual addition. Llama uses the same pre-normalization layout, with RMSNorm in place of LayerNorm.
Pre-LN gives better-behaved gradients at initialization, which removes the strict need for learning-rate warmup (large models often still use a short one) and allows more stable scaling to very large parameter counts.
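The two layouts differ only in where the normalization sits. The sketch below is a minimal illustration with class names of our own and a hypothetical sub-layer forced to output zeros: the Pre-LN block then passes its input through untouched, because the residual path never goes through a LayerNorm, while the Post-LN block re-normalizes it. That clean identity path is one intuition for why Pre-LN gradients are better behaved at initialization.

```python
import torch
import torch.nn as nn

class PostLNBlock(nn.Module):  # illustrative name
    """Original Transformer layout: LayerNorm(x + Sublayer(x))."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreLNBlock(nn.Module):  # illustrative name
    """GPT-2-style layout: x + Sublayer(LayerNorm(x))."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

d_model = 16
silent = nn.Linear(d_model, d_model)  # hypothetical sub-layer, zeroed so it outputs 0
nn.init.zeros_(silent.weight)
nn.init.zeros_(silent.bias)

x = torch.randn(2, 5, d_model) * 3 + 7
print(torch.equal(PreLNBlock(d_model, silent)(x), x))      # True: identity path intact
print(torch.allclose(PostLNBlock(d_model, silent)(x), x))  # False: output is re-normalized
```

Because the residual stream itself is never normalized inside Pre-LN blocks, Pre-LN stacks typically add one final LayerNorm after the last block (GPT-2 does).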
6. Implementation: Multi-Head Attention
Here is how we assemble ScaledDotProductAttention into a complete PyTorch Multi-Head layer. Pay attention to how the dimensions are elegantly reshaped to process all heads simultaneously.
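```python
import math

import torch
import torch.nn as nn

class ScaledDotProductAttention(nn.Module):
    # Minimal stand-in so this snippet runs on its own; the full version is assumed
    # to come from an earlier chapter. Assumed mask convention: 0 = blocked position.
    def __init__(self, d_k):
        super().__init__()
        self.scale = math.sqrt(d_k)

    def forward(self, q, k, v, mask=None):
        scores = q @ k.transpose(-2, -1) / self.scale  # (batch, n_heads, seq_q, seq_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        return weights @ v, weights

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.d_model = d_model
        # Linear projections for Q, K, V, and Output
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        # One attention module serves all heads: n_heads acts as an extra batch dimension
        self.attention = ScaledDotProductAttention(self.d_k)

    def forward(self, q, k, v, mask=None):
        # mask (optional) must broadcast to (batch, n_heads, seq_q, seq_k)
        batch_size = q.size(0)
        # 1. Linear Projections and Split Heads
        # Reshape: (batch, seq_len, d_model) -> (batch, seq_len, n_heads, d_k)
        # Transpose to group by heads: (batch, n_heads, seq_len, d_k)
        q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # 2. Apply Attention independently for each head
        # context shape: (batch, n_heads, seq_len, d_k)
        context, attn_weights = self.attention(q, k, v, mask)
        # 3. Concatenate Heads
        # Transpose back: (batch, seq_len, n_heads, d_k)
        # Reshape to flatten heads: (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # 4. Final Linear Projection to mix head outputs
        output = self.w_o(context)
        return output, attn_weights

# Usage: self-attention over a random batch
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)
out, attn = mha(x, x, x)  # out: (2, 10, 512), attn: (2, 8, 10, 10)
```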
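The reshapes in steps 1 and 3 are easy to get wrong, so here is a standalone check, assuming the same `(batch, seq_len, d_model)` layout as the class above and a random tensor standing in for the projected features. It shows that each head works on a contiguous slice of every token's features, and that the merge exactly undoes the split.

```python
import torch

batch, seq_len, n_heads, d_k = 2, 5, 8, 64
d_model = n_heads * d_k
x = torch.randn(batch, seq_len, d_model)  # stand-in for projected Q (or K, or V)

# Step 1: split d_model into heads -> (batch, n_heads, seq_len, d_k)
heads = x.view(batch, seq_len, n_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 5, 64])

# Head 3 sees features 192..255 of every token
print(torch.equal(heads[:, 3], x[..., 3 * d_k:4 * d_k]))  # True

# Step 3: merge heads back -> (batch, seq_len, d_model)
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
print(torch.equal(merged, x))  # True
```

Calling `.view` on the transposed tensor without `.contiguous()` raises a `RuntimeError`, which is why step 3 includes it; `.reshape` would make the copy automatically when needed.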
7. Summary
- Encoder: Extracts rich contextual features using bidirectional multi-head attention.
- Decoder: Generates tokens autoregressively, masking future tokens and attending to the Encoder’s output.
- Multi-Head Attention: Captures diverse relationships in parallel, highly optimized for GPU matrix multiplication.
- Feed-Forward Network: Transforms each token independently; often described as the model’s “memory bank”.
- Positional Encoding: Injects crucial order information into attention layers that are otherwise blind to token order.
- Add & Norm: Residual connections and Layer Normalization stabilize the network and enable deep scaling.
In the next chapter, we will explore how these architectures are trained using objectives like Masked Language Modeling.