Recurrent Neural Networks (RNNs)
Traditional feedforward neural networks assume that all inputs (and outputs) are independent of each other. But for many tasks, the sequence matters. Whether it’s predicting the next word in a sentence, analyzing a stock price trend, or understanding a video clip, the order of data points contains crucial information.
This is where Recurrent Neural Networks (RNNs) come in. They maintain a "memory" that captures information about what has been computed so far.
1. The Need for Sequence Models
Why can’t we just use a standard Feedforward Neural Network (FNN)?
- Variable Length Inputs/Outputs: Sentences can be 3 words or 30 words long. FNNs typically require fixed-size inputs.
- Parameter Sharing: Learning that “apple” implies a fruit should be applicable whether “apple” appears at the start or end of a sentence.
- Temporal Dependencies: The meaning of a word often depends on the words that came before it.
[!NOTE] An RNN processes sequences by iterating through the sequence elements and maintaining a hidden state containing information relative to the past.
2. RNN Architecture
The core idea of an RNN is the recurrence relation. At each time step t, the hidden state h<sub>t</sub> is updated based on the current input x<sub>t</sub> and the previous hidden state h<sub>t-1</sub>.
The equation is:
h<sub>t</sub> = tanh(W<sub>xh</sub> x<sub>t</sub> + W<sub>hh</sub> h<sub>t-1</sub> + b<sub>h</sub>)
Where:
- x<sub>t</sub> is the input vector at time t.
- h<sub>t</sub> is the hidden state vector at time t.
- W<sub>xh</sub> is the weight matrix for the input-to-hidden connection.
- W<sub>hh</sub> is the weight matrix for the hidden-to-hidden connection.
- b<sub>h</sub> is the bias vector.
- tanh is the activation function (squashing values between -1 and 1).
The output y<sub>t</sub> is then computed from the hidden state:
y<sub>t</sub> = W<sub>hy</sub> h<sub>t</sub> + b<sub>y</sub>
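The two equations above can be written out directly as a short sketch. The dimensions and the helper name `rnn_step` are illustrative choices, not part of any library API:

```python
import torch

# Hypothetical dimensions chosen for illustration
input_size, hidden_size, output_size = 4, 3, 2

# The weight matrices from the equations above (randomly initialized here)
W_xh = torch.randn(hidden_size, input_size) * 0.1
W_hh = torch.randn(hidden_size, hidden_size) * 0.1
W_hy = torch.randn(output_size, hidden_size) * 0.1
b_h = torch.zeros(hidden_size)
b_y = torch.zeros(output_size)

def rnn_step(x_t, h_prev):
    """One recurrence step: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)."""
    h_t = torch.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)
    y_t = W_hy @ h_t + b_y  # output: y_t = W_hy h_t + b_y
    return h_t, y_t

# Process a 5-step sequence: the same weights are reused at every step
h = torch.zeros(hidden_size)
for t in range(5):
    x_t = torch.randn(input_size)
    h, y = rnn_step(x_t, h)

print(h.shape, y.shape)  # torch.Size([3]) torch.Size([2])
```

Note that the loop applies the *same* `W_xh`, `W_hh`, and `W_hy` at every time step; this is the parameter sharing discussed in Section 1.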
3. Backpropagation Through Time (BPTT)
Training an RNN is similar to training a regular neural network, but with a twist. Because the parameters W<sub>xh</sub>, W<sub>hh</sub>, and W<sub>hy</sub> are shared across all time steps, the gradient of the loss function depends on the calculations at previous time steps.
This process is called Backpropagation Through Time (BPTT). It’s essentially standard backpropagation applied to the unrolled computational graph.
[!WARNING] Vanishing Gradient Problem: As gradients are propagated back through many time steps, they tend to vanish (approach zero) or explode (become extremely large).
If the recurrent weights are small (e.g., magnitudes below 1), multiplying by them repeatedly causes the gradient to decay exponentially with the number of time steps. This makes it difficult for the RNN to learn long-term dependencies (e.g., remembering a subject from the beginning of a long paragraph).
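A toy scalar recurrence makes the decay concrete. This sketch strips away the tanh and the matrices, leaving only repeated multiplication by a single "recurrent weight" of 0.5:

```python
import torch

# Simplified linear recurrence: h_t = w * h_{t-1}, unrolled for 20 steps.
# With w = 0.5, the gradient dh_20/dw = 20 * 0.5**19 is already tiny.
w = torch.tensor(0.5, requires_grad=True)  # recurrent weight < 1
h = torch.tensor(1.0)                      # initial state
for _ in range(20):                        # 20 time steps
    h = w * h

h.backward()
print(f"{w.grad.item():.2e}")  # 3.81e-05 — the gradient has nearly vanished
```

With a weight above 1 the same loop exhibits the opposite failure mode: the gradient explodes.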
4. PyTorch Implementation
Here is a clean implementation of a vanilla RNN using PyTorch.
```python
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Define the RNN layer
        # batch_first=True means input shape is (batch, seq, feature)
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        # Define the output layer (fully connected)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x shape: (batch_size, seq_length, input_size)
        # Initialize the hidden state with zeros
        # Shape: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(1, x.size(0), self.hidden_size, device=x.device)
        # Forward propagate the RNN
        # out shape: (batch_size, seq_length, hidden_size)
        # hn shape: (num_layers, batch_size, hidden_size)
        out, hn = self.rnn(x, h0)
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

# Example usage
input_size = 10
hidden_size = 20
output_size = 1  # e.g., a binary classification logit
seq_length = 5
batch_size = 3

model = SimpleRNN(input_size, hidden_size, output_size)
input_data = torch.randn(batch_size, seq_length, input_size)
output = model(input_data)
print(f"Input shape: {input_data.shape}")   # torch.Size([3, 5, 10])
print(f"Output shape: {output.shape}")      # torch.Size([3, 1])
```
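To connect this back to BPTT: a single optimization step is enough to see it in action, since `loss.backward()` propagates gradients through every unrolled time step automatically. The sketch below uses the same shapes as above, with random tensors standing in for a real dataset and MSE chosen arbitrarily as the loss:

```python
import torch
import torch.nn as nn

# A hypothetical single training step on a sequence regression task.
rnn = nn.RNN(input_size=10, hidden_size=20, batch_first=True)
fc = nn.Linear(20, 1)
criterion = nn.MSELoss()
params = list(rnn.parameters()) + list(fc.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

x = torch.randn(3, 5, 10)   # (batch, seq, feature) — random stand-in data
target = torch.randn(3, 1)

out, _ = rnn(x)             # unrolls the recurrence over all 5 time steps
pred = fc(out[:, -1, :])    # decode the last time step's hidden state
loss = criterion(pred, target)

optimizer.zero_grad()
loss.backward()             # BPTT: gradients flow back through all time steps
optimizer.step()
print(pred.shape)           # torch.Size([3, 1])
```

Because the recurrent weights are shared, the gradient each weight receives is the sum of its contributions across all time steps, which is exactly where the vanishing-gradient behavior from Section 3 arises.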
5. Summary
- RNNs process sequential data by maintaining a hidden state.
- Parameter Sharing allows the model to apply the same logic at every time step.
- BPTT is used to train RNNs, but suffers from Vanishing Gradients.
- Long-Term Dependencies are hard for vanilla RNNs to capture.
In the next chapter, we will see how LSTMs and GRUs solve the vanishing gradient problem.