Cross-Entropy Loss

In Deep Learning, we train models by minimizing a Loss Function. For classification tasks (predicting categories), the standard choice is Cross-Entropy Loss (also known as Log Loss).

1. Definition

Cross-Entropy measures the difference between two probability distributions P (the true distribution) and Q (the predicted distribution).

H(P, Q) = − Σx P(x) log Q(x)
  • P(x): The true label distribution, usually a one-hot vector (e.g., [0, 1, 0] for class 2).
  • Q(x): The model’s prediction, usually output from a Softmax function (e.g., [0.1, 0.8, 0.1]).
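For one-hot labels, the sum collapses to a single term: −log Q(true class). A minimal NumPy sketch using the example vectors above:

```python
import numpy as np

# True distribution (one-hot, class 2) and a model prediction, as in the example above
P = np.array([0.0, 1.0, 0.0])
Q = np.array([0.1, 0.8, 0.1])

# H(P, Q) = -sum_x P(x) * log Q(x); only the true-class term survives for one-hot P
cross_entropy = -np.sum(P * np.log(Q))
print(f"{cross_entropy:.4f}")  # equals -log(0.8) ≈ 0.2231
```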

2. Relation to KL Divergence

Cross-Entropy is directly related to Entropy and KL Divergence:

H(P, Q) = H(P) + D_{KL}(P || Q)
  • H(P): The entropy of the true distribution. If labels are one-hot (certainty), H(P) = 0.
  • D_KL(P || Q): The KL divergence between truth and prediction.

[!TIP] Why minimize Cross-Entropy? Since H(P) is constant for a fixed dataset, minimizing Cross-Entropy is mathematically equivalent to minimizing KL Divergence. We are trying to make our predicted distribution Q as close as possible to the true distribution P.
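The identity can be checked numerically with a soft (non-one-hot) true distribution, where H(P) > 0; the distributions below are illustrative values:

```python
import numpy as np

# A soft true distribution P and a prediction Q (illustrative values)
P = np.array([0.2, 0.7, 0.1])
Q = np.array([0.1, 0.8, 0.1])

H_PQ = -np.sum(P * np.log(Q))     # cross-entropy H(P, Q)
H_P  = -np.sum(P * np.log(P))     # entropy H(P)
D_KL = np.sum(P * np.log(P / Q))  # KL divergence D_KL(P || Q)

# H(P, Q) = H(P) + D_KL(P || Q) holds term-for-term
print(f"{H_PQ:.6f} == {H_P + D_KL:.6f}")
```

Since H(P) does not depend on the model, gradient descent on H(P, Q) only "sees" the D_KL term.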

3. Binary Cross-Entropy (BCE)

For binary classification (0 or 1), the formula simplifies. Let y be the true label (0 or 1) and ŷ be the predicted probability of class 1.

Loss = − [y log(ŷ) + (1−y) log(1−ŷ)]
  • If y=1: Loss is −log(ŷ). We want ŷ → 1.
  • If y=0: Loss is −log(1−ŷ). We want ŷ → 0.
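Both branches can be sketched in one small function; the `eps` clamp guarding log(0) is an implementation detail not in the formula:

```python
import math

def bce(y, y_hat, eps=1e-15):
    """Binary cross-entropy for one sample; eps clamps y_hat away from 0 and 1 to avoid log(0)."""
    y_hat = min(max(y_hat, eps), 1 - eps)
    return -(y * math.log(y_hat) + (1 - y) * math.log(1 - y_hat))

print(f"{bce(1, 0.9):.4f}")  # confident and right: small loss (-log 0.9 ≈ 0.1054)
print(f"{bce(1, 0.1):.4f}")  # confident but wrong: large loss (-log 0.1 ≈ 2.3026)
```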

The Penalty of Confidence

Consider how the loss changes with the model's prediction confidence, assuming the true label is 1 (positive). As the predicted probability ŷ moves from 1 toward 0, the loss −log(ŷ) grows slowly at first and then explodes as the prediction gets closer to 0 (confident but wrong). At ŷ = 0.50, the loss is −log(0.5) ≈ 0.69.
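The explosion near ŷ = 0 can be traced in a few lines:

```python
import math

# Loss -log(y_hat) for a true label of 1, at decreasing predicted probabilities
for y_hat in [0.9, 0.5, 0.1, 0.01, 0.001]:
    print(f"y_hat={y_hat:<6} loss={-math.log(y_hat):.4f}")
```

Each factor-of-10 drop in ŷ adds a constant ln(10) ≈ 2.30 to the loss, so the penalty is unbounded as ŷ → 0.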

4. Hardware Reality: Log-Sum-Exp Trick

When computing Cross-Entropy in code, the Softmax step is a common source of numerical instability. Softmax involves exponentials e^x: if x is large (e.g., 100), e^100 is astronomically large and causes floating-point overflow.

To solve this, hardware implementations use the Log-Sum-Exp (LSE) trick. Instead of computing log(sum(exp(x))) directly, we use the identity:

log(Σᵢ e^(xᵢ)) = a + log(Σᵢ e^(xᵢ − a))

By choosing a = max(x), we ensure the largest exponent is 0 (since e^0 = 1), preventing overflow. This is why PyTorch’s CrossEntropyLoss takes logits (raw scores) instead of probabilities—it applies this optimization internally.
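A minimal sketch of the trick, contrasting the naive computation (which overflows for large logits) with the max-subtracted version:

```python
import math

def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x))) via the max-subtraction identity."""
    a = max(xs)
    return a + math.log(sum(math.exp(x - a) for x in xs))

logits = [1000.0, 999.0, 998.0]

# Naive version: math.exp(1000.0) overflows a float64
try:
    naive = math.log(sum(math.exp(x) for x in logits))
except OverflowError:
    naive = float("inf")

print(f"naive={naive}, stable={log_sum_exp(logits):.4f}")
```

The stable result is 1000 + log(e⁰ + e⁻¹ + e⁻²) ≈ 1000.41, recovered without ever exponentiating a large number.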

5. Code Implementation

Calculate Cross-Entropy Loss in Java, Go, and Python.

Java

import java.util.Arrays;

public class CrossEntropy {

    public static double crossEntropyLoss(double[] predictedProbs, int trueClassIndex) {
        // Avoid log(0)
        double p = Math.max(predictedProbs[trueClassIndex], 1e-15);
        return -Math.log(p); // Natural log (nats)
    }

    // Stable Softmax using max subtraction
    public static double[] softmax(double[] logits) {
        double maxLogit = Double.NEGATIVE_INFINITY;
        for (double val : logits) maxLogit = Math.max(maxLogit, val);

        double sumExp = 0.0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - maxLogit);
            sumExp += probs[i];
        }

        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sumExp;
        }
        return probs;
    }

    public static void main(String[] args) {
        double[] logits = {2.0, 1.0, 0.1};
        int trueClass = 0; // The first class is the correct one

        double[] probs = softmax(logits);
        double loss = crossEntropyLoss(probs, trueClass);

        System.out.println("Probabilities: " + Arrays.toString(probs));
        System.out.printf("Loss: %.4f%n", loss);
    }
}

Go

package main

import (
	"fmt"
	"math"
)

// CrossEntropyLoss calculates loss for a single sample
func CrossEntropyLoss(predictedProbs []float64, trueClassIndex int) float64 {
	p := predictedProbs[trueClassIndex]
	if p < 1e-15 {
		p = 1e-15
	}
	return -math.Log(p) // Natural log
}

// Softmax with stability fix
func Softmax(logits []float64) []float64 {
	maxLogit := -math.MaxFloat64
	for _, val := range logits {
		if val > maxLogit {
			maxLogit = val
		}
	}

	sumExp := 0.0
	probs := make([]float64, len(logits))
	for i, val := range logits {
		probs[i] = math.Exp(val - maxLogit)
		sumExp += probs[i]
	}

	for i := range probs {
		probs[i] /= sumExp
	}
	return probs
}

func main() {
	logits := []float64{2.0, 1.0, 0.1}
	trueClass := 0

	probs := Softmax(logits)
	loss := CrossEntropyLoss(probs, trueClass)

	fmt.Printf("Probabilities: %.4f\n", probs)
	fmt.Printf("Loss: %.4f\n", loss)
}

Python

import numpy as np
import torch
import torch.nn as nn

# 1. NumPy Implementation (Manual)
def cross_entropy_numpy(predictions, targets):
    """
    predictions: (N, C) array of probabilities (after softmax)
    targets: (N,) array of true class indices
    """
    N = predictions.shape[0]
    # Select the probability of the true class for each sample
    # predictions[range(N), targets] gets the prob of true class
    log_likelihood = -np.log(predictions[range(N), targets])
    loss = np.sum(log_likelihood) / N
    return loss

# Example Data
# 3 samples, 3 classes
y_pred_probs = np.array([
    [0.7, 0.2, 0.1], # Correct class 0
    [0.1, 0.8, 0.1], # Correct class 1
    [0.2, 0.2, 0.6]  # Correct class 2
])
y_true = np.array([0, 1, 2])

loss_np = cross_entropy_numpy(y_pred_probs, y_true)
print(f"NumPy Loss: {loss_np:.4f}")

# 2. PyTorch Implementation
# Note: nn.CrossEntropyLoss expects raw logits, NOT probabilities (it applies Softmax internally)
loss_fn = nn.CrossEntropyLoss()

# Logits (before softmax)
logits = torch.tensor([
    [2.0, 1.0, 0.1],
    [0.5, 3.0, 0.2],
    [0.1, 0.2, 1.5]
])
targets = torch.tensor([0, 1, 2])

loss_torch = loss_fn(logits, targets)
print(f"PyTorch Loss: {loss_torch.item():.4f}")

Key Takeaways

  1. Differentiable: Unlike “Accuracy” (which is discrete), Cross-Entropy is smooth and differentiable, allowing Gradient Descent to work.
  2. Penalizes Confidence: Being confidently wrong (e.g., predicting 0.99 probability for the wrong class) incurs a massive loss penalty.
  3. Softmax + CE: In PyTorch, CrossEntropyLoss combines LogSoftmax and NLLLoss for numerical stability.