Joint and Marginal Distributions

In the real world, variables rarely exist in isolation. Your height and weight are correlated. The temperature and humidity are linked. To model these relationships, we need Multivariate Probability Distributions.

This chapter explores how two random variables, X and Y, behave together (Joint) and how we can recover their individual behaviors (Marginal).

1. The Joint Probability Distribution

The Joint Probability Distribution P(X, Y) gives the probability that X takes a particular value x and Y takes a particular value y at the same time. It is the “God’s Eye View” of the system: it contains all the information about how the variables relate to each other.

For discrete variables, we write this as the Joint Probability Mass Function (PMF):

$$P(X = x, Y = y) = P(x, y)$$

Where:

  • 0 ≤ P(x, y) ≤ 1 for all pairs (x, y).
  • The sum of all probabilities must be exactly 1: $\sum_x \sum_y P(x, y) = 1$.
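Both conditions can be checked in code by scanning every cell. This is a minimal sketch that assumes the joint PMF is stored as a nested list of rows. The function name `is_valid_joint_pmf` and the tolerance `tol` are our own choices. The 3x3 table reuses the numbers from the implementation section of this chapter.

```python
import math

def is_valid_joint_pmf(joint, tol=1e-9):
    """Return True if every entry is in [0, 1] and all entries sum to 1.

    `joint` is assumed to be a list of rows: joint[i][j] = P(X = i, Y = j).
    `tol` (an assumption) absorbs floating-point rounding in the total.
    """
    cells = [p for row in joint for p in row]
    # Condition 1: each probability lies between 0 and 1.
    if any(p < 0 or p > 1 for p in cells):
        return False
    # Condition 2: the double sum over x and y equals 1.
    return math.isclose(sum(cells), 1.0, abs_tol=tol)

joint = [
    [0.10, 0.05, 0.00],  # X = 0
    [0.05, 0.20, 0.10],  # X = 1
    [0.00, 0.10, 0.40],  # X = 2
]
print(is_valid_joint_pmf(joint))                 # True
print(is_valid_joint_pmf([[0.5, 0.6], [0, 0]]))  # False: the total is 1.1
```

The tolerance matters because adding decimal probabilities in floating point rarely lands on exactly 1.0.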

[!NOTE] Think of the Joint Distribution as a topographical map. The pairs (x, y) are coordinates (latitude and longitude), and the probability P(x, y) is the elevation at that point.

Visualizing the Joint Distribution

Imagine we have two variables:

  1. X: The number of hours a student studies (0, 1, 2).
  2. Y: The grade they receive (C, B, A).

The joint distribution is a table (or matrix) where each cell represents a specific outcome, like “Studied 2 hours AND got an A”.
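In code, such a table can be a dictionary keyed by (hours, grade) pairs, so each cell is one outcome. The probabilities below are illustrative only. They reuse the 3x3 matrix from this chapter's implementation section, with the column index 0, 1, 2 assumed to mean C, B, A.

```python
# Hypothetical joint PMF P(hours studied, grade).
# Values are illustrative, borrowed from the 3x3 example in this chapter.
joint = {
    (0, "C"): 0.10, (0, "B"): 0.05, (0, "A"): 0.00,
    (1, "C"): 0.05, (1, "B"): 0.20, (1, "A"): 0.10,
    (2, "C"): 0.00, (2, "B"): 0.10, (2, "A"): 0.40,
}

# One cell answers one question: "Studied 2 hours AND got an A".
print(joint[(2, "A")])  # 0.4

# An outcome with probability zero is still a cell in the table.
print(joint[(0, "A")])  # 0.0
```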


2. Marginalization: Recovering Individual Behavior

What if we only care about the grade Y, regardless of how long the student studied? We need to marginalize out the variable X.

The Marginal Probability P(Y) is found by summing the joint probabilities over all possible values of X:

$$P(Y = y) = \sum_x P(X = x, Y = y)$$

This is called the Sum Rule of probability.

  • To get the marginal distribution of Y (one value per column), sum down each column, i.e., over all rows (all values of X).
  • To get the marginal distribution of X (one value per row), sum along each row, i.e., over all columns (all values of Y).
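The sum rule applied to a dictionary-keyed table is a short loop: for each grade, add up the cells for every possible number of hours. This sketch defines an illustrative (hours, grade) table so it runs on its own. The values are assumptions, not real data, and `marginal` is a hypothetical helper name.

```python
from collections import defaultdict

# Illustrative joint PMF P(hours, grade); the values are assumptions.
joint = {
    (0, "C"): 0.10, (0, "B"): 0.05, (0, "A"): 0.00,
    (1, "C"): 0.05, (1, "B"): 0.20, (1, "A"): 0.10,
    (2, "C"): 0.00, (2, "B"): 0.10, (2, "A"): 0.40,
}

def marginal(joint, keep):
    """Sum rule: keep one coordinate (0 = hours, 1 = grade), sum out the other."""
    result = defaultdict(float)
    for outcome, p in joint.items():
        result[outcome[keep]] += p
    return dict(result)

p_grade = marginal(joint, keep=1)  # sums over all hours x
p_hours = marginal(joint, keep=0)  # sums over all grades y
print({g: round(p, 2) for g, p in p_grade.items()})  # {'C': 0.15, 'B': 0.35, 'A': 0.5}
print({h: round(p, 2) for h, p in p_hours.items()})  # {0: 0.15, 1: 0.35, 2: 0.5}
```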

[!TIP] The term “Marginal” comes from accounting. In a ledger table, the sums of rows and columns were written in the margins of the page.


3. Interactive: The Marginalizer 3000

Explore the relationship between the joint grid and the marginals.

  • Grid: Represents the Joint Distribution P(X, Y). Click a cell to increase its probability.
  • Right Bar: The Marginal Distribution of Rows P(X).
  • Bottom Bar: The Marginal Distribution of Columns P(Y).
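The relationship between the grid and the two bars can also be sketched without the interactive page. Here we assume that a click adds a fixed amount (`bump`, a hypothetical parameter) to one cell, and that the grid is then renormalized so the total stays 1. The page's actual update rule is not stated, so treat this as an illustration.

```python
def click(joint, i, j, bump=0.1):
    """Hypothetical widget update: add `bump` to cell (i, j), then renormalize."""
    grid = [row[:] for row in joint]  # copy so the caller's grid is unchanged
    grid[i][j] += bump
    total = sum(sum(row) for row in grid)
    return [[p / total for p in row] for row in grid]

def marginals(joint):
    """Right bar P(X): row sums. Bottom bar P(Y): column sums."""
    p_x = [sum(row) for row in joint]
    p_y = [sum(col) for col in zip(*joint)]
    return p_x, p_y

grid = [[1 / 9] * 3 for _ in range(3)]  # start from a uniform 3x3 grid
grid = click(grid, 2, 2, bump=1.0)      # click the bottom-right cell
p_x, p_y = marginals(grid)
print([round(p, 3) for p in p_x])  # [0.167, 0.167, 0.667]
print([round(p, 3) for p in p_y])  # [0.167, 0.167, 0.667]
```

Raising a single cell moves mass into both its row's bar and its column's bar at once, which is why the two marginals change together.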

4. Hardware Reality: The Curse of Dimensionality

Why do we care so much about marginalizing and factoring distributions? Why not just store the big table P(X, Y, Z, …)?

The Exponential Explosion

If we have N binary variables, the joint distribution table has $2^N$ entries.

  • 2 Variables: $2^2 = 4$ entries (Easy)
  • 10 Variables: $2^{10} = 1024$ entries (KB)
  • 30 Variables: $2^{30} \approx 1$ billion entries (GB)
  • 100 Variables: $2^{100} \approx 1.27 \times 10^{30}$ entries.
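The entry counts above, and the memory they would need, are easy to tabulate. This sketch assumes each probability is stored as an 8-byte double. Under that assumption, $2^{10}$ entries take 8 KiB (8,192 bytes) and $2^{30}$ entries take 8 GiB.

```python
def table_size(n_vars, bytes_per_entry=8):
    """Entries and bytes for a full joint table over n_vars binary variables.

    bytes_per_entry=8 assumes one 64-bit float per probability.
    """
    entries = 2 ** n_vars  # one cell per assignment of all variables
    return entries, entries * bytes_per_entry

for n in (2, 10, 30, 100):
    entries, size = table_size(n)
    print(f"{n:>3} variables: {entries:.3g} entries, {size:.3g} bytes")
```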

Hardware Fact: There are only about $10^{80}$ atoms in the observable universe. Even if we could store one probability per atom, we couldn’t store the joint distribution for just 300 binary variables, since $2^{300} \approx 2 \times 10^{90}$.
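Python's arbitrary-precision integers make the atom comparison exact. Taking the $10^{80}$ estimate as given, this loop finds the smallest number of binary variables whose joint table would have more entries than there are atoms.

```python
ATOMS = 10 ** 80  # rough estimate of atoms in the observable universe

n = 1
while 2 ** n <= ATOMS:  # exact integer comparison, no floating-point rounding
    n += 1
print(n)                 # 266
print(2 ** 300 > ATOMS)  # True
```

So 300 variables is comfortably past the limit: 266 already suffice.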

Memory Locality

Even for smaller N, accessing a giant multidimensional array is a cache nightmare.

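  • Row-Major Order: C and C++ store a 2D array row by row. Java’s double[][] is an array of row arrays, so each row is contiguous but separate rows need not be adjacent. Summing along each row (marginalizing out Y to get P(X)) is fast because memory is read sequentially (spatial locality).
  • Column Access: Summing down one column at a time (marginalizing out X to get P(Y)) jumps a full row’s width between consecutive reads, which causes cache misses once rows are large. Swapping the loop order (rows outermost, adding each element into a running per-column total) restores sequential access.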
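The two access patterns can be written out against an explicit row-major buffer. This is only a sketch: Python lists do not show real cache effects, so the point is the index order. Both functions compute the same P(Y). `column_first` strides by `n_cols` between reads, while `row_first` reads the buffer front to back. The function names and the flat-list layout are our own assumptions.

```python
def column_first(flat, n_rows, n_cols):
    """P(Y) by walking down each column: consecutive reads are n_cols apart."""
    p_y = [0.0] * n_cols
    for j in range(n_cols):
        for i in range(n_rows):
            p_y[j] += flat[i * n_cols + j]
    return p_y

def row_first(flat, n_rows, n_cols):
    """P(Y) by reading rows in storage order and adding into per-column totals."""
    p_y = [0.0] * n_cols
    for i in range(n_rows):
        for j in range(n_cols):
            p_y[j] += flat[i * n_cols + j]  # index increases by 1 each step
    return p_y

# The 3x3 joint table from this chapter, flattened row by row.
flat = [0.10, 0.05, 0.00,
        0.05, 0.20, 0.10,
        0.00, 0.10, 0.40]
print(column_first(flat, 3, 3))
print(row_first(flat, 3, 3))
```

In a language with true contiguous 2D storage, such as C or C++, the `row_first` order is the one that keeps the inner loop on adjacent memory.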

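This is why we need Probabilistic Graphical Models (PGMs) such as Bayesian Networks. They represent the joint distribution as a product of smaller conditional distributions. When each variable depends directly on only a few others, this cuts storage from exponential to roughly linear in the number of variables.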
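The space savings become concrete when we count parameters for N binary variables. The network shape here is an assumption chosen for simplicity: a chain X1 → X2 → ... → XN, where each variable depends only on its predecessor. The full table needs $2^N - 1$ free numbers, since the last one is fixed by the sum-to-1 constraint. The chain needs 1 number for P(X1) plus 2 for each P(Xi | Xi-1). The `chain_joint` helper (a hypothetical name) also rebuilds a full table from the factors to show that the product is a valid distribution. To keep that part short, it additionally assumes every link shares the same conditional table.

```python
from itertools import product

def full_joint_params(n):
    """Free parameters of a full joint table over n binary variables."""
    return 2 ** n - 1  # the last entry is fixed by the sum-to-1 constraint

def chain_params(n):
    """Free parameters of a hypothetical chain X1 -> X2 -> ... -> Xn.

    P(X1) needs 1 number; each P(Xi | Xi-1) needs 2 (one per parent value).
    """
    return 1 + 2 * (n - 1)

def chain_joint(p_first, p_next, n):
    """Build P(x1, ..., xn) = P(x1) * prod_i P(xi | xi-1) for binary variables.

    p_first = P(X1 = 1); p_next[b] = P(Xi = 1 | Xi-1 = b), shared by every
    link in the chain (an assumption that keeps the example short).
    """
    joint = {}
    for xs in product((0, 1), repeat=n):
        p = p_first if xs[0] == 1 else 1 - p_first
        for prev, cur in zip(xs, xs[1:]):
            q = p_next[prev]
            p *= q if cur == 1 else 1 - q
        joint[xs] = p
    return joint

for n in (10, 30, 100):
    print(f"N={n}: full table {full_joint_params(n)}, chain {chain_params(n)}")

j3 = chain_joint(0.3, {0: 0.2, 1: 0.9}, 3)
print(len(j3), round(sum(j3.values()), 10))  # 8 1.0
```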


5. Continuous Case: Integrals

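For continuous variables, we replace sums with integrals. The distribution is described by a Joint Probability Density Function (PDF) f(x, y). Unlike a PMF, f(x, y) is a density rather than a probability. The probability that (X, Y) falls in a region A is the volume under f over that region, $P((X, Y) \in A) = \iint_A f(x, y)\,dx\,dy$.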

Definition

The volume under the entire surface must be 1:

$$\int_{-\infty}^{\infty} \int_{-\infty}^{\infty} f(x, y)\,dx\,dy = 1$$
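This condition can be checked numerically. The sketch assumes an example density f(x, y) = x + y on the unit square (0 elsewhere), which integrates to 1/2 + 1/2 = 1. The midpoint rule stands in for the integral, and the grid size `n` is our own choice.

```python
def f(x, y):
    """Example joint PDF (assumption): x + y on [0, 1] x [0, 1], else 0."""
    return x + y if 0 <= x <= 1 and 0 <= y <= 1 else 0.0

def volume(pdf, n=200):
    """Approximate the double integral with the midpoint rule.

    Only the unit square is sampled because the example pdf is 0 outside it.
    """
    h = 1.0 / n
    total = 0.0
    for i in range(n):
        for j in range(n):
            total += pdf((i + 0.5) * h, (j + 0.5) * h)
    return total * h * h

print(round(volume(f), 6))  # 1.0
```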

Marginalization

To find the marginal PDF of X, we “integrate out” Y:

$$f_X(x) = \int_{-\infty}^{\infty} f(x, y)\,dy$$

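Imagine shining a flashlight through a translucent 3D object (the joint PDF) onto a wall. The shadow on the wall is the marginal distribution. Its darkness at each point reflects how much material the light passed through, just as f_X(x) accumulates all the density along y.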
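Marginalization can also be checked against a closed form. Take an assumed example density f(x, y) = x + y on the unit square (0 elsewhere). Integrating out y gives $f_X(x) = \int_0^1 (x + y)\,dy = x + \tfrac{1}{2}$ for $0 \le x \le 1$. A midpoint-rule sketch agrees with it; the grid size `n` is an arbitrary choice.

```python
def f(x, y):
    """Example joint PDF (assumption): x + y on [0, 1] x [0, 1], else 0."""
    return x + y if 0 <= x <= 1 and 0 <= y <= 1 else 0.0

def marginal_x(pdf, x, n=1000):
    """f_X(x): integrate out y over [0, 1] with the midpoint rule."""
    h = 1.0 / n
    return sum(pdf(x, (j + 0.5) * h) for j in range(n)) * h

for x in (0.0, 0.25, 1.0):
    print(x, round(marginal_x(f, x), 6), x + 0.5)  # numeric vs exact x + 1/2
```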


6. Implementing in Code

We represent discrete joint distributions as 2D arrays (matrices). Below are implementations in Java, Go, and Python.

Java Example

import java.util.Arrays;

public class JointMarginal {
    public static void main(String[] args) {
        // 1. Define a Joint Distribution P(X, Y)
        // Rows = X (0, 1, 2), Cols = Y (0, 1, 2)
        double[][] jointProb = {
            {0.10, 0.05, 0.00}, // X=0
            {0.05, 0.20, 0.10}, // X=1
            {0.00, 0.10, 0.40}  // X=2
        };

        // 2. Marginalize out Y to get P(X)
        // Sum across columns (rows remain)
        double[] marginalX = new double[jointProb.length];
        for (int i = 0; i < jointProb.length; i++) {
            double sum = 0;
            for (int j = 0; j < jointProb[i].length; j++) {
                sum += jointProb[i][j];
            }
            marginalX[i] = sum;
        }

        // 3. Marginalize out X to get P(Y)
        // Sum across rows (columns remain)
        double[] marginalY = new double[jointProb[0].length];
        for (int j = 0; j < jointProb[0].length; j++) {
            double sum = 0;
            for (int i = 0; i < jointProb.length; i++) {
                sum += jointProb[i][j];
            }
            marginalY[j] = sum;
        }

        System.out.println("P(X): " + Arrays.toString(marginalX));
        System.out.println("P(Y): " + Arrays.toString(marginalY));
    }
}

Go Example

package main

import (
	"fmt"
)

func main() {
	// 1. Define Joint Distribution P(X, Y)
	// 3x3 grid stored as a slice of rows
	jointProb := [][]float64{
		{0.10, 0.05, 0.00}, // X=0
		{0.05, 0.20, 0.10}, // X=1
		{0.00, 0.10, 0.40}, // X=2
	}

	// 2. Marginalize out Y to get P(X)
	// Sum each row
	marginalX := make([]float64, len(jointProb))
	for i, row := range jointProb {
		sum := 0.0
		for _, val := range row {
			sum += val
		}
		marginalX[i] = sum
	}

	// 3. Marginalize out X to get P(Y)
	// Sum each column
	nRows, nCols := len(jointProb), len(jointProb[0])
	marginalY := make([]float64, nCols)
	for col := 0; col < nCols; col++ {
		sum := 0.0
		for row := 0; row < nRows; row++ {
			sum += jointProb[row][col]
		}
		marginalY[col] = sum
	}

	fmt.Printf("P(X): %.2f\n", marginalX)
	fmt.Printf("P(Y): %.2f\n", marginalY)
}

Python (NumPy) Example

import numpy as np

# 1. Define a Joint Distribution P(X, Y)
# Rows = X (0, 1, 2), Cols = Y (0, 1, 2)
joint_prob = np.array([
    [0.10, 0.05, 0.00],  # X=0
    [0.05, 0.20, 0.10],  # X=1
    [0.00, 0.10, 0.40]   # X=2
])

# 2. Marginalize out Y to get P(X)
# Sum across columns (axis=1)
marginal_x = np.sum(joint_prob, axis=1)
print(f"P(X): {marginal_x}")
# Output: [0.15 0.35 0.5 ]

# 3. Marginalize out X to get P(Y)
# Sum across rows (axis=0)
marginal_y = np.sum(joint_prob, axis=0)
print(f"P(Y): {marginal_y}")
# Output: [0.15 0.35 0.5 ]  (same as P(X) only because this example matrix is symmetric)

7. Summary

  • Joint Distribution: P(X, Y) is the complete map of reality, but it suffers from the Curse of Dimensionality ($2^N$ entries for N binary variables).
  • Marginal Distribution: P(X) is the “shadow” of the joint distribution on a single axis.
  • Marginalization: The process of summing out unwanted variables.
  • Hardware: Efficient marginalization traverses the array in its storage order (row by row for row-major layouts) to preserve memory locality.
