# Cross-Entropy Loss

In Deep Learning, we train models by minimizing a Loss Function. For classification tasks (predicting categories), the standard choice is Cross-Entropy Loss (also known as Log Loss).
## 1. Definition
Cross-Entropy measures the difference between two probability distributions, P (the true distribution) and Q (the predicted distribution):

H(P, Q) = −Σₓ P(x) · log(Q(x))

- **P(x)**: The true label distribution, usually a one-hot vector (e.g., `[0, 1, 0]` for class 2).
- **Q(x)**: The model's prediction, usually the output of a Softmax function (e.g., `[0.1, 0.8, 0.1]`).
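A quick NumPy sketch, using the example vectors above: when P is one-hot, the sum collapses to a single term, so the loss reduces to the negative log-probability of the true class.

```python
import numpy as np

P = np.array([0.0, 1.0, 0.0])   # one-hot truth: class 2 (index 1)
Q = np.array([0.1, 0.8, 0.1])   # softmax prediction

# H(P, Q) = -sum(P * log Q); with one-hot P this is just -log Q[true]
ce = -np.sum(P * np.log(Q))
print(ce, -np.log(Q[1]))  # both equal -log(0.8) ≈ 0.2231
```

This is why the per-sample loss in the code section below is simply `-log(p)` of the true class.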
## 2. Relation to KL Divergence
Cross-Entropy is directly related to Entropy and KL Divergence:

H(P, Q) = H(P) + D_KL(P ‖ Q)

- **H(P)**: The entropy of the true distribution. If labels are one-hot (certainty), H(P) = 0.
- **D_KL(P ‖ Q)**: The divergence between truth and prediction.

> [!TIP]
> **Why minimize Cross-Entropy?** Since H(P) is constant for a fixed dataset, minimizing Cross-Entropy is mathematically equivalent to minimizing KL Divergence. We are trying to make our predicted distribution Q as close as possible to the true distribution P.
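The identity H(P, Q) = H(P) + D_KL(P ‖ Q) can be checked numerically. A minimal sketch with a hypothetical *soft* label distribution P (so that H(P) ≠ 0):

```python
import numpy as np

P = np.array([0.1, 0.7, 0.2])  # soft true distribution (hypothetical)
Q = np.array([0.2, 0.5, 0.3])  # predicted distribution

entropy = -np.sum(P * np.log(P))        # H(P)
cross_entropy = -np.sum(P * np.log(Q))  # H(P, Q)
kl = np.sum(P * np.log(P / Q))          # D_KL(P || Q)

print(f"H(P) + D_KL  = {entropy + kl:.6f}")
print(f"H(P, Q)      = {cross_entropy:.6f}")  # identical
```

With one-hot labels, `entropy` would be 0 and the two losses coincide exactly, which is why the distinction rarely matters in practice.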
## 3. Binary Cross-Entropy (BCE)
For binary classification (0 or 1), the formula simplifies. Let y be the true label (0 or 1) and ŷ be the predicted probability of class 1:

BCE = −[y · log(ŷ) + (1 − y) · log(1 − ŷ)]

- If y = 1: the loss is −log(ŷ). We want ŷ → 1.
- If y = 0: the loss is −log(1 − ŷ). We want ŷ → 0.
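A minimal sketch of BCE for a single sample, with a small epsilon to guard against log(0) (the clamp value is a common convention, not mandated by the formula):

```python
import math

def bce(y, y_hat, eps=1e-15):
    """Binary cross-entropy for one sample; eps guards log(0)."""
    y_hat = min(max(y_hat, eps), 1 - eps)
    return -(y * math.log(y_hat) + (1 - y) * math.log(1 - y_hat))

print(bce(1, 0.9))  # small loss: confident and right
print(bce(1, 0.1))  # large loss: confident and wrong
```

Note how the same prediction (0.9 vs. 0.1) flips from a cheap mistake to an expensive one depending on the true label.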
### Interactive: The Penalty of Confidence

Explore how the loss changes with the model's prediction confidence. We assume the True Label is 1 (Positive).

- Move the slider to change the Predicted Probability (ŷ).
- Observe how the Loss explodes as the prediction gets closer to 0 (confident but wrong).
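The curve the slider traces can be tabulated directly, since the loss for a positive label is just −log(ŷ):

```python
import math

# True label y = 1, so loss = -log(y_hat)
for y_hat in [0.99, 0.9, 0.5, 0.1, 0.01]:
    print(f"y_hat = {y_hat:>4}: loss = {-math.log(y_hat):.4f}")
```

Each 10× drop in ŷ adds a constant ~2.30 (ln 10) to the loss, so the penalty grows without bound as ŷ → 0.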
## 4. Hardware Reality: Log-Sum-Exp Trick
When computing Cross-Entropy in code (as with Softmax), we often encounter numerical instability. Softmax involves exponentials: e^x. If x is large (e.g., 100), e^100 is huge and can cause floating-point overflow.

To solve this, hardware implementations use the Log-Sum-Exp (LSE) trick. Instead of computing log(Σᵢ exp(xᵢ)) directly, we use the identity:

log(Σᵢ exp(xᵢ)) = a + log(Σᵢ exp(xᵢ − a))

By choosing a = max(x), we ensure the largest exponent is 0 (since e^0 = 1), preventing overflow. This is why PyTorch's `CrossEntropyLoss` takes logits (raw scores) instead of probabilities: it applies this optimization internally.
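The failure mode and the fix are easy to demonstrate. A sketch with deliberately large logits: the naive softmax overflows to inf/inf = NaN, while subtracting max(x) first gives the correct answer.

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])  # large logits

# Naive: exp(1000) overflows float64 -> inf, and inf/inf -> nan
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(x) / np.sum(np.exp(x))

# Stable: subtract max(x); the largest exponent becomes exp(0) = 1
z = x - np.max(x)
stable = np.exp(z) / np.sum(np.exp(z))

print("naive: ", naive)   # [nan nan nan]
print("stable:", stable)  # sums to 1
```

Note that the stable result depends only on the *differences* between logits, which is exactly why the max-shift changes nothing mathematically.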
## 5. Code Implementation

Calculate Cross-Entropy Loss in Java, Go, and Python.
### Java

```java
import java.util.Arrays;

public class CrossEntropy {
    public static double crossEntropyLoss(double[] predictedProbs, int trueClassIndex) {
        // Avoid log(0)
        double p = Math.max(predictedProbs[trueClassIndex], 1e-15);
        return -Math.log(p); // Natural log (nats)
    }

    // Stable Softmax using max subtraction
    public static double[] softmax(double[] logits) {
        double maxLogit = Double.NEGATIVE_INFINITY;
        for (double val : logits) maxLogit = Math.max(maxLogit, val);
        double sumExp = 0.0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - maxLogit);
            sumExp += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sumExp;
        }
        return probs;
    }

    public static void main(String[] args) {
        double[] logits = {2.0, 1.0, 0.1};
        int trueClass = 0; // The first class is the correct one
        double[] probs = softmax(logits);
        double loss = crossEntropyLoss(probs, trueClass);
        System.out.println("Probabilities: " + Arrays.toString(probs));
        System.out.printf("Loss: %.4f%n", loss);
    }
}
```
### Go

```go
package main

import (
	"fmt"
	"math"
)

// CrossEntropyLoss calculates loss for a single sample
func CrossEntropyLoss(predictedProbs []float64, trueClassIndex int) float64 {
	p := predictedProbs[trueClassIndex]
	if p < 1e-15 {
		p = 1e-15 // Avoid log(0)
	}
	return -math.Log(p) // Natural log
}

// Softmax with stability fix (max subtraction)
func Softmax(logits []float64) []float64 {
	maxLogit := math.Inf(-1)
	for _, val := range logits {
		if val > maxLogit {
			maxLogit = val
		}
	}
	sumExp := 0.0
	probs := make([]float64, len(logits))
	for i, val := range logits {
		probs[i] = math.Exp(val - maxLogit)
		sumExp += probs[i]
	}
	for i := range probs {
		probs[i] /= sumExp
	}
	return probs
}

func main() {
	logits := []float64{2.0, 1.0, 0.1}
	trueClass := 0
	probs := Softmax(logits)
	loss := CrossEntropyLoss(probs, trueClass)
	fmt.Printf("Probabilities: %.4f\n", probs)
	fmt.Printf("Loss: %.4f\n", loss)
}
```
### Python

```python
import numpy as np
import torch
import torch.nn as nn

# 1. NumPy Implementation (Manual)
def cross_entropy_numpy(predictions, targets):
    """
    predictions: (N, C) array of probabilities (after softmax)
    targets: (N,) array of true class indices
    """
    N = predictions.shape[0]
    # Select the probability of the true class for each sample:
    # predictions[range(N), targets] gets the prob of the true class
    log_likelihood = -np.log(predictions[range(N), targets])
    loss = np.sum(log_likelihood) / N
    return loss

# Example data: 3 samples, 3 classes
y_pred_probs = np.array([
    [0.7, 0.2, 0.1],  # Correct class 0
    [0.1, 0.8, 0.1],  # Correct class 1
    [0.2, 0.2, 0.6]   # Correct class 2
])
y_true = np.array([0, 1, 2])

loss_np = cross_entropy_numpy(y_pred_probs, y_true)
print(f"NumPy Loss: {loss_np:.4f}")

# 2. PyTorch Implementation
# Note: nn.CrossEntropyLoss expects raw logits, NOT probabilities
# (it applies Softmax internally)
loss_fn = nn.CrossEntropyLoss()

# Logits (before softmax)
logits = torch.tensor([
    [2.0, 1.0, 0.1],
    [0.5, 3.0, 0.2],
    [0.1, 0.2, 1.5]
])
targets = torch.tensor([0, 1, 2])

loss_torch = loss_fn(logits, targets)
print(f"PyTorch Loss: {loss_torch.item():.4f}")
```
## Key Takeaways

- **Differentiable**: Unlike accuracy (which is discrete), Cross-Entropy is smooth and differentiable, allowing Gradient Descent to work.
- **Penalizes Confidence**: Being confidently wrong (e.g., predicting 0.99 probability for the wrong class) incurs a massive loss penalty.
- **Softmax + CE**: In PyTorch, `CrossEntropyLoss` combines `LogSoftmax` and `NLLLoss` for numerical stability.
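That LogSoftmax-then-NLL decomposition can be sketched in plain NumPy: compute a stable log-softmax (the LSE trick from section 4), then take the negative log-probability of the true class.

```python
import numpy as np

def log_softmax(x):
    # log_softmax(x) = x - logsumexp(x), computed stably via max-shift
    z = x - np.max(x)
    return z - np.log(np.sum(np.exp(z)))

logits = np.array([2.0, 1.0, 0.1])
true_class = 0

# NLL of the log-probabilities == cross-entropy on the raw logits
loss = -log_softmax(logits)[true_class]
print(f"{loss:.4f}")  # 0.4170, matching the Java/Go examples above
```

Fusing the two steps avoids ever materializing the probabilities, so no intermediate value can underflow to 0 and blow up the log.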