Batch Normalization: Stabilizing Training
Training deep neural networks is hard because the distribution of each layer’s inputs changes during training, as the parameters of the previous layers change. This phenomenon is called Internal Covariate Shift.
Batch Normalization (BatchNorm) is a technique that addresses this by normalizing the layer inputs: for each mini-batch, it rescales each feature's activations to have a mean of 0 and a variance of 1 (before an optional learned rescaling).
1. The Problem: Shifting Sands
Imagine trying to learn to juggle, but the balls keep changing weight and size every few seconds. That’s what a hidden layer feels like without BatchNorm. It constantly has to adapt to a new distribution of inputs from the previous layer.
2. The Solution: Normalize
For each mini-batch of m samples, we calculate the mean μ_B and variance σ_B²:

μ_B = (1/m) Σᵢ xᵢ,  σ_B² = (1/m) Σᵢ (xᵢ − μ_B)²

Then we normalize the input x (ε is a small constant for numerical stability):

x̂ᵢ = (xᵢ − μ_B) / √(σ_B² + ε)

Then we scale and shift it using learnable parameters γ (scale) and β (shift):

yᵢ = γ · x̂ᵢ + β
- γ and β allow the network to undo the normalization if it decides that’s better for learning (e.g., if it needs a non-zero mean).
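As a concrete check of the formulas above, here is a minimal NumPy sketch with a made-up four-sample batch of a single feature (values chosen for illustration):

```python
import numpy as np

# Toy mini-batch: 4 samples of a single feature
x = np.array([[1.0], [2.0], [3.0], [6.0]])

mu = x.mean(axis=0)    # batch mean mu_B -> 3.0
var = x.var(axis=0)    # batch variance sigma_B^2 -> 3.5
eps = 1e-5             # small constant for numerical stability

x_hat = (x - mu) / np.sqrt(var + eps)  # normalized activations

# With the initial gamma = 1, beta = 0, the output is x_hat itself:
print(x_hat.mean())  # ~0
print(x_hat.var())   # ~1
```

Regardless of how skewed the incoming batch is, the normalized activations land on a standard scale, which is exactly what the next layer sees.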
3. Interactive: Distribution Visualizer
See how the distribution of activations behaves over time (epochs) with and without Batch Normalization.
[Interactive chart: activation distributions across epochs, without BatchNorm vs. with BatchNorm]
4. Benefits
- Faster Training: You can use higher learning rates.
- Regularization: The batch statistics add slight noise during training, which reduces overfitting (a mild effect similar to Dropout).
- Less Sensitive to Initialization: You don’t need to be as careful with weight initialization.
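The regularization effect comes from the fact that every mini-batch produces slightly different statistics, so the same activation gets normalized slightly differently each time it is seen. A small sketch with made-up data makes this jitter visible:

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=1.5, size=(1000, 1))  # made-up activations

# Each mini-batch yields a slightly different estimate of the mean;
# this jitter is the noise behind BatchNorm's mild regularizing effect.
batch_means = []
for _ in range(3):
    batch = data[rng.choice(len(data), size=32, replace=False)]
    batch_means.append(float(batch.mean()))

print(batch_means)         # three different per-batch estimates
print(float(data.mean()))  # full-data mean, ~2.0
```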
5. Implementation Details
Training vs. Inference
- Training: Calculate mean/variance from the current batch.
- Inference: Use the running average of mean/variance collected during training.
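The running average is an exponential moving average. A minimal sketch, using the same momentum convention as the implementation below and an idealized stream of batch means:

```python
momentum = 0.9      # fraction of the old running value to keep each step

running_mean = 0.0  # typical initialization
for _ in range(50):
    batch_mean = 2.0  # 50 idealized mini-batches, each with mean 2.0
    running_mean = momentum * running_mean + (1 - momentum) * batch_mean

print(running_mean)  # close to 2.0 (exactly 2 * (1 - 0.9**50))
```

The estimate converges geometrically toward the true statistics, which is why inference needs no batch at all: it reads off these frozen averages.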
Python Code (NumPy)
```python
import numpy as np

class BatchNorm:
    def __init__(self, num_features, momentum=0.9, epsilon=1e-5):
        self.momentum = momentum
        self.epsilon = epsilon
        # Learnable parameters
        self.gamma = np.ones((1, num_features))
        self.beta = np.zeros((1, num_features))
        # Running stats for inference
        self.running_mean = np.zeros((1, num_features))
        self.running_var = np.ones((1, num_features))

    def forward(self, x, mode='train'):
        if mode == 'train':
            # 1. Calculate mean and variance of the batch
            batch_mean = np.mean(x, axis=0, keepdims=True)
            batch_var = np.var(x, axis=0, keepdims=True)
            # 2. Update running stats (exponential moving average)
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * batch_mean
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * batch_var
            # 3. Normalize
            self.x_centered = x - batch_mean
            self.std_inv = 1. / np.sqrt(batch_var + self.epsilon)
            self.x_norm = self.x_centered * self.std_inv
            # 4. Scale and shift
            out = self.gamma * self.x_norm + self.beta
            return out
        elif mode == 'test':
            # Use running stats collected during training
            x_norm = (x - self.running_mean) / np.sqrt(self.running_var + self.epsilon)
            out = self.gamma * x_norm + self.beta
            return out
        else:
            raise ValueError(f"Unknown mode: {mode!r}")
```
[!TIP] Placement: There is an ongoing debate about where to place BatchNorm relative to the activation function. The original paper applies it before the activation (Conv → BN → ReLU); some practitioners place it after the activation instead, and pre-activation ResNet (v2) reorders the whole block as BN → ReLU → Conv. All of these variants usually work well in practice.