Batch Normalization: Stabilizing Training

Training deep neural networks is hard because the distribution of each layer’s inputs changes during training, as the parameters of the previous layers change. This phenomenon is called Internal Covariate Shift.

Batch Normalization (BatchNorm) is a technique that addresses this by normalizing the layer inputs: for each mini-batch, it rescales the activations to have a mean of 0 and a variance of 1 (before a learnable scale and shift, described below).

1. The Problem: Shifting Sands

Imagine trying to learn to juggle, but the balls keep changing weight and size every few seconds. That’s what a hidden layer feels like without BatchNorm. It constantly has to adapt to a new distribution of inputs from the previous layer.

2. The Solution: Normalize

For each mini-batch, we calculate the mean (μ_B) and variance (σ_B²). Then we normalize the input x:

x_hat = (x − μ_B) / √(σ_B² + ε)

Then we scale and shift it using learnable parameters γ (scale) and β (shift):

y = γ x_hat + β
  • γ and β allow the network to undo the normalization if it decides that’s better for learning (e.g., if it needs a non-zero mean).
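The two steps above can be checked on a toy mini-batch (the values, ε, and identity γ/β are just illustrative):

```python
import numpy as np

# Toy mini-batch: 4 samples, 2 features
x = np.array([[1.0, 10.0],
              [2.0, 20.0],
              [3.0, 30.0],
              [4.0, 40.0]])

eps = 1e-5
mu = x.mean(axis=0)                     # per-feature mean μ_B
var = x.var(axis=0)                     # per-feature variance σ_B²
x_hat = (x - mu) / np.sqrt(var + eps)   # normalize

gamma, beta = np.ones(2), np.zeros(2)   # identity scale and shift
y = gamma * x_hat + beta

print(y.mean(axis=0))  # ≈ [0, 0]
print(y.std(axis=0))   # ≈ [1, 1]
```

With γ = 1 and β = 0, the output is simply the normalized batch; during training the network is free to move γ and β away from these values.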

3. Interactive: Distribution Visualizer

See how the distribution of activations behaves over time (epochs) with and without Batch Normalization.

[Interactive widget: side-by-side activation distributions, "Without BatchNorm" vs. "With BatchNorm", each reporting a running mean and standard deviation.]

4. Benefits

  1. Faster Training: You can use higher learning rates.
  2. Regularization: The batch statistics add a small amount of noise to each example’s activations, which can reduce overfitting (a mild effect similar to Dropout).
  3. Less Sensitive to Initialization: You don’t need to be as careful with weight initialization.
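The regularization effect in point 2 comes from the fact that a sample’s normalized value depends on which other samples share its mini-batch. A small sketch (the random data and batch size are arbitrary):

```python
import numpy as np

np.random.seed(0)
x0 = np.array([1.0, -1.0])   # one fixed sample

def bn_train(batch, eps=1e-5):
    """Training-mode normalization using the batch's own statistics."""
    mu, var = batch.mean(axis=0), batch.var(axis=0)
    return (batch - mu) / np.sqrt(var + eps)

# The same sample x0 placed into two different mini-batches...
out_a = bn_train(np.vstack([x0, np.random.randn(7, 2)]))[0]
out_b = bn_train(np.vstack([x0, np.random.randn(7, 2)]))[0]

# ...yields two different normalized values: batch-dependent noise
print(out_a)
print(out_b)
```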

5. Implementation Details

Training vs. Inference

  • Training: Calculate mean/variance from the current batch.
  • Inference: Use the running average of mean/variance collected during training.
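The running average is typically an exponential moving average (EMA). A minimal standalone sketch (the momentum value 0.9 mirrors the implementation below; the data is synthetic):

```python
import numpy as np

np.random.seed(0)
momentum = 0.9
running_mean = np.zeros(3)

# Training: blend each batch mean into the running estimate
for _ in range(500):
    batch = np.random.randn(32, 3) + 2.0   # true per-feature mean ≈ 2
    batch_mean = batch.mean(axis=0)
    running_mean = momentum * running_mean + (1 - momentum) * batch_mean

# At inference, running_mean stands in for the batch mean
print(running_mean)  # ≈ [2, 2, 2]
```

After enough batches, the EMA converges to the true feature means, so inference no longer depends on the composition of any particular batch.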

Python Code (NumPy)

import numpy as np

class BatchNorm:
  def __init__(self, num_features, momentum=0.9, epsilon=1e-5):
    self.momentum = momentum
    self.epsilon = epsilon

    # Learnable parameters
    self.gamma = np.ones((1, num_features))
    self.beta = np.zeros((1, num_features))

    # Running stats for inference
    self.running_mean = np.zeros((1, num_features))
    self.running_var = np.ones((1, num_features))

  def forward(self, x, mode='train'):
    if mode == 'train':
      # 1. Calculate mean and variance of batch
      batch_mean = np.mean(x, axis=0, keepdims=True)
      batch_var = np.var(x, axis=0, keepdims=True)

      # 2. Update running stats
      self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * batch_mean
      self.running_var = self.momentum * self.running_var + (1 - self.momentum) * batch_var

      # 3. Normalize
      self.x_centered = x - batch_mean
      self.std_inv = 1. / np.sqrt(batch_var + self.epsilon)
      self.x_norm = self.x_centered * self.std_inv

      # 4. Scale and Shift
      out = self.gamma * self.x_norm + self.beta
      return out

    elif mode == 'test':
      # Use running stats collected during training
      x_norm = (x - self.running_mean) / np.sqrt(self.running_var + self.epsilon)
      out = self.gamma * x_norm + self.beta
      return out

    else:
      raise ValueError(f"Unknown mode: {mode!r}")

[!TIP] Placement: There is an ongoing debate about where to place BatchNorm relative to the activation function. The original paper puts it before the nonlinearity (Conv → BN → ReLU); some practitioners instead place it after (Conv → ReLU → BN), and pre-activation architectures like ResNet v2 reorder the whole block as BN → ReLU → Conv. All of these usually work well in practice.