Transfer Learning: Standing on Giants
Training a Convolutional Neural Network from scratch typically requires a large labeled dataset (often millions of images) and days to weeks of GPU time. Most of us don’t have that.
Transfer Learning is the technique of taking a model trained on a large task (like ImageNet with 1.2M images) and adapting it to a smaller, specific task (like distinguishing hot dogs from not hot dogs).
[!TIP] In practice, you will rarely train a CNN from scratch. You will usually initialize it with weights pretrained on ImageNet.
1. How it Works
Deep CNNs learn hierarchical features:
- Early Layers: Detect edges, textures, curves (Generic features).
- Middle Layers: Detect parts (eyes, wheels, leaves).
- Later Layers: Detect whole objects (faces, cars).
The early features are universal. An edge is an edge, whether it’s in a photo of a cat or an X-ray. Transfer Learning reuses these generic features and retrains only the later, task-specific layers (sometimes just the final classifier).
2. Strategies: Feature Extraction vs. Fine-Tuning
There are two main approaches:
A. Feature Extraction
We treat the pretrained network as a fixed feature extractor.
- Freeze all the convolutional layers (weights don’t update).
- Replace the final Fully Connected (FC) layer with a new one matching your number of classes.
- Train only the new FC layer.
B. Fine-Tuning
We unfreeze some of the convolutional layers (starting from the top, possibly all of them) and train them together with the new classifier.
- Freeze the bottom layers (generic features).
- Unfreeze the top layers (specific features).
- Train with a very low learning rate to avoid wrecking the pretrained weights.
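3. Gradient Flow with Frozen Layers
During backpropagation, gradients travel from the loss back toward the input. A frozen layer (requires_grad = False) receives no weight update. If every layer below some point is frozen, there is nothing left to update there, so autograd stops propagating gradients at the lowest trainable layer and the backward pass covers only the trainable top of the network.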
4. PyTorch Implementation
Here is how to perform Feature Extraction using a pretrained ResNet18.
```python
import torch
import torch.nn as nn
import torchvision.models as models

# 1. Load the pretrained model
# 'pretrained=True' is deprecated; pass 'weights' instead
weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)

# 2. Freeze all parameters (Feature Extraction)
for param in model.parameters():
    param.requires_grad = False

# 3. Replace the classifier
# ResNet's final layer is called 'fc'. We replace it.
# model.fc.in_features is the input size of the original fc layer (512 for ResNet18)
num_classes = 2  # e.g., Cat vs Dog
model.fc = nn.Linear(model.fc.in_features, num_classes)
# Note: the new layer (model.fc) has requires_grad=True by default!

# 4. Check trainable parameters
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Trainable: {name}")
# Output:
# Trainable: fc.weight
# Trainable: fc.bias
```
When to Fine-Tune?
If you have more data, or your images look very different from ImageNet, consider unfreezing more layers, starting from the top.
```python
# Unfreeze the last block (layer4)
for param in model.layer4.parameters():
    param.requires_grad = True
# Gradients now flow back from fc through layer4 and stop there;
# layer3 and earlier layers receive no gradients.
```
5. Summary
- Transfer Learning allows us to train high-accuracy models with small datasets.
- Feature Extraction: Freeze the base, train only the head. Best for small datasets similar to ImageNet.
- Fine-Tuning: Unfreeze some base layers. Best for larger datasets or different domains (e.g., Medical images).
- PyTorch: Use `param.requires_grad = False` to freeze layers.
In the next section, we will review everything we’ve learned in this module.