LSTMs and GRUs

Standard RNNs suffer from short-term memory. If a sequence is long enough, they’ll have a hard time carrying information from earlier time steps to later ones. For example, if you are trying to predict the last word in “I grew up in France… I speak fluent French”, the RNN needs to remember “France” from way back to predict “French”.

Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs) are specialized kinds of RNNs, capable of learning long-term dependencies.

1. Long Short-Term Memory (LSTM)

The key to LSTMs is the cell state (often denoted as C<sub>t</sub>). It acts like a conveyor belt, running straight down the entire chain, with only some minor linear interactions. It’s very easy for information to flow along it unchanged.

LSTMs have the ability to remove or add information to the cell state, carefully regulated by structures called gates. Gates are a way to optionally let information through. They are composed of a sigmoid neural net layer and a pointwise multiplication operation.

The Gates

  1. Forget Gate: Decides what information we’re going to throw away from the cell state.
    • f<sub>t</sub> = &sigma;(W<sub>f</sub> &middot; [h<sub>t-1</sub>, x<sub>t</sub>] + b<sub>f</sub>)
  2. Input Gate: Decides what new information we’re going to store in the cell state.
    • i<sub>t</sub> = &sigma;(W<sub>i</sub> &middot; [h<sub>t-1</sub>, x<sub>t</sub>] + b<sub>i</sub>)
    • C̃<sub>t</sub> = tanh(W<sub>C</sub> &middot; [h<sub>t-1</sub>, x<sub>t</sub>] + b<sub>C</sub>)
  3. Cell State Update:
    • C<sub>t</sub> = f<sub>t</sub> * C<sub>t-1</sub> + i<sub>t</sub> * C̃<sub>t</sub>
  4. Output Gate: Decides what we’re going to output.
    • o<sub>t</sub> = &sigma;(W<sub>o</sub> &middot; [h<sub>t-1</sub>, x<sub>t</sub>] + b<sub>o</sub>)
    • h<sub>t</sub> = o<sub>t</sub> * tanh(C<sub>t</sub>)
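The four equations above can be sketched as a single time step in plain NumPy. This is a minimal illustration with randomly initialized toy weights, not a library API; all names here are chosen for this example.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_C, W_o, b_f, b_i, b_C, b_o):
    """One LSTM time step, following the four gate equations above."""
    z = np.concatenate([h_prev, x_t])    # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ z + b_f)         # forget gate
    i_t = sigmoid(W_i @ z + b_i)         # input gate
    c_tilde = np.tanh(W_C @ z + b_C)     # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde   # cell state update
    o_t = sigmoid(W_o @ z + b_o)         # output gate
    h_t = o_t * np.tanh(c_t)             # new hidden state
    return h_t, c_t

# Toy dimensions: input size 3, hidden size 4, random weights
rng = np.random.default_rng(0)
n_in, n_h = 3, 4
Ws = [rng.normal(size=(n_h, n_h + n_in)) for _ in range(4)]
bs = [np.zeros(n_h) for _ in range(4)]
h, c = np.zeros(n_h), np.zeros(n_h)
h, c = lstm_step(rng.normal(size=n_in), h, c, *Ws, *bs)
print(h.shape, c.shape)  # (4,) (4,)
```

Note that because h<sub>t</sub> = o<sub>t</sub> * tanh(C<sub>t</sub>), the hidden state is always bounded in (-1, 1), while the cell state C<sub>t</sub> is not.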

Intuition: How the Gates Behave

Each gate outputs a value between 0 and 1 that controls the flow of information.

  • Forget Gate: 0 means completely forget the old state, 1 means keep it.
  • Input Gate: 0 means ignore new input, 1 means add it fully.
  • Output Gate: 0 means output nothing, 1 means output the activated cell state.

Worked Example

With gate values f<sub>t</sub> = 0.5, i<sub>t</sub> = 0.8, o<sub>t</sub> = 0.5, a previous cell state C<sub>t-1</sub> = 0.5, and candidate C̃<sub>t</sub> = 0.5:

  • C<sub>t</sub> = (f<sub>t</sub> * C<sub>t-1</sub>) + (i<sub>t</sub> * C̃<sub>t</sub>) = 0.25 + 0.40 = 0.65
  • h<sub>t</sub> = o<sub>t</sub> * tanh(C<sub>t</sub>) = 0.5 * 0.57 ≈ 0.29
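The arithmetic of this worked example can be checked in a few lines of Python, using the same values:

```python
import math

f_t, i_t, o_t = 0.5, 0.8, 0.5   # gate values from the example above
c_prev, c_tilde = 0.5, 0.5      # previous cell state and candidate

c_t = f_t * c_prev + i_t * c_tilde   # cell state update
h_t = o_t * math.tanh(c_t)           # hidden state
print(round(c_t, 2), round(h_t, 2))  # 0.65 0.29
```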

2. Gated Recurrent Units (GRU)

A slightly more dramatic variation on the LSTM is the Gated Recurrent Unit, or GRU. It combines the forget and input gates into a single “update gate”. It also merges the cell state and hidden state, and makes some other changes. The resulting model is simpler than standard LSTM models.

The equations are:

  • Reset Gate: r<sub>t</sub> = &sigma;(W<sub>r</sub> &middot; [h<sub>t-1</sub>, x<sub>t</sub>])
  • Update Gate: z<sub>t</sub> = &sigma;(W<sub>z</sub> &middot; [h<sub>t-1</sub>, x<sub>t</sub>])
  • New Memory: h̃<sub>t</sub> = tanh(W &middot; [r<sub>t</sub> * h<sub>t-1</sub>, x<sub>t</sub>])
  • Final Memory: h<sub>t</sub> = (1 - z<sub>t</sub>) * h<sub>t-1</sub> + z<sub>t</sub> * h̃<sub>t</sub>
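As with the LSTM, these equations can be sketched as one NumPy time step. This is a minimal illustration with random toy weights; biases are omitted to match the equations above, and all names are chosen for this example.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, W_r, W_z, W_h):
    """One GRU time step, following the equations above (biases omitted)."""
    z_in = np.concatenate([h_prev, x_t])
    r_t = sigmoid(W_r @ z_in)                # reset gate
    z_t = sigmoid(W_z @ z_in)                # update gate
    # New memory: reset gate scales how much of h_prev is used
    h_tilde = np.tanh(W_h @ np.concatenate([r_t * h_prev, x_t]))
    # Final memory: update gate interpolates between old and new state
    return (1 - z_t) * h_prev + z_t * h_tilde

# Toy dimensions: input size 3, hidden size 4
rng = np.random.default_rng(0)
n_in, n_h = 3, 4
W_r, W_z, W_h = (rng.normal(size=(n_h, n_h + n_in)) for _ in range(3))
h = np.zeros(n_h)
h = gru_step(rng.normal(size=n_in), h, W_r, W_z, W_h)
print(h.shape)  # (4,)
```

The interpolation in the last line makes the update gate's dual role visible: when z<sub>t</sub> is near 0 the old state is kept (like a closed forget gate), and when it is near 1 the new memory replaces it (like an open input gate).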

[!TIP] Which one to use?

  • LSTM: More powerful, more parameters. Good default.
  • GRU: Simpler, faster to train. Good for smaller datasets or when compute is limited.

3. PyTorch Implementation

PyTorch makes it incredibly easy to switch between RNN, LSTM, and GRU. You essentially just change the class name.

import torch
import torch.nn as nn

class LSTMModel(nn.Module):
  def __init__(self, input_size, hidden_size, num_layers, output_size):
    super(LSTMModel, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers

    # LSTM layer
    self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

    # Fully connected layer
    self.fc = nn.Linear(hidden_size, output_size)

  def forward(self, x):
    # Initialize hidden state and cell state
    h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
    c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

    # Forward propagate LSTM
    # out: tensor of shape (batch_size, seq_length, hidden_size)
    out, (hn, cn) = self.lstm(x, (h0, c0))

    # Decode the hidden state of the last time step
    out = self.fc(out[:, -1, :])
    return out

# Example Usage
input_size = 10
hidden_size = 20
num_layers = 2
output_size = 1
seq_length = 5
batch_size = 3

model = LSTMModel(input_size, hidden_size, num_layers, output_size)
input_data = torch.randn(batch_size, seq_length, input_size)
output = model(input_data)

print(f"Output shape: {output.shape}")
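To see that the switch really is just a class-name change, here is the same model rewritten with nn.GRU. The only structural differences are that the GRU has no cell state (so only h0 is initialized) and returns a single hidden state instead of a tuple.

```python
import torch
import torch.nn as nn

class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Same signature as nn.LSTM; only the class name changes
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # GRUs have no cell state, so only the hidden state is initialized
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                         device=x.device)
        out, hn = self.gru(x, h0)        # single state returned, not a tuple
        return self.fc(out[:, -1, :])    # decode the last time step

model = GRUModel(input_size=10, hidden_size=20, num_layers=2, output_size=1)
out = model(torch.randn(3, 5, 10))      # (batch, seq_length, input_size)
print(out.shape)  # torch.Size([3, 1])
```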

4. Summary

  • LSTMs introduce a Cell State and Gates (Forget, Input, Output) to regulate information flow.
  • The Forget Gate allows the network to reset memory when it’s no longer relevant.
  • GRUs are a simplified version of LSTMs, merging the cell and hidden states.
  • Both are standard tools for handling long sequences in deep learning.

Next, we’ll look at Seq2Seq Models, which use two RNNs (Encoder and Decoder) to translate sequences.