Automatic Differentiation: The Magic of PyTorch

[!NOTE] This module explores the core principles of Automatic Differentiation, the machinery that lets PyTorch turn your forward pass into gradients, building it up from first principles.

1. Introduction: Who computes the gradients?

In calculus class, you computed derivatives by hand. In the early days of neural networks (the 1980s), researchers derived gradients on paper, simplified them, and hand-coded them in C or C++. In modern frameworks (PyTorch, TensorFlow), you write only the Forward Pass; the framework computes the Backward Pass (the gradients) automatically. This is Automatic Differentiation (AutoDiff).

It is NOT numerical differentiation (finite differences: slow, and imprecise due to floating-point round-off). It is NOT symbolic differentiation (as in Mathematica, which can suffer from expression explosion). It is the exact application of the Chain Rule to a graph of elementary operations.
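To see why finite differences are imprecise, here is a minimal pure-Python sketch (the function name finite_diff is illustrative): shrinking the step h first reduces the truncation error, but eventually floating-point round-off dominates and the estimate gets worse again. Exact AutoDiff has neither problem.

```python
import math

def finite_diff(f, x, h):
    """Numerical differentiation: approximate f'(x) with a forward difference."""
    return (f(x + h) - f(x)) / h

x = 1.0
exact = math.cos(x)  # d/dx sin(x) = cos(x), known in closed form

errors = {}
for h in (1e-2, 1e-6, 1e-12):
    approx = finite_diff(math.sin, x, h)
    errors[h] = abs(approx - exact)
    print(f"h={h:g}  |error|={errors[h]:.2e}")
```

The error shrinks from h=1e-2 to h=1e-6, then grows again at h=1e-12 as round-off takes over.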


2. The Computational Graph

Every calculation in your code builds a Directed Acyclic Graph (DAG).

  • Nodes: Operations (Addition, Multiplication, Sin, Exp).
  • Edges: Tensors (Data flowing between operations).

Example: y = (x + w) × b

  1. Input Nodes: x, w, b.
  2. Intermediate Node: a = x + w.
  3. Output Node: y = a × b.

To find \frac{\partial y}{\partial x}, we just traverse the graph backwards!
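Concretely, applying the Chain Rule along the path x \to a \to y:

\frac{\partial y}{\partial x} = \frac{\partial y}{\partial a} \cdot \frac{\partial a}{\partial x} = b \cdot 1 = b, \qquad \frac{\partial y}{\partial b} = a = x + w.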


3. Forward vs Reverse Mode

Why do we always talk about “Backpropagation”? Why not “Forwardpropagation” of gradients?

Forward Mode

Computes \frac{\partial v}{\partial x} for every node v as we go forward.

  • Mechanism: We track the value and the derivative w.r.t. one input simultaneously.
  • Best for: Functions with few inputs and many outputs.
  • Cost: Proportional to the number of inputs.
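Forward mode can be sketched with dual numbers: each quantity carries a pair (value, derivative), and every operation updates both using its local derivative rule. A minimal pure-Python sketch (the class name Dual is illustrative, not a PyTorch API):

```python
class Dual:
    """A number paired with its derivative w.r.t. one chosen (seeded) input."""
    def __init__(self, val, dot):
        self.val = val   # primal value
        self.dot = dot   # derivative w.r.t. the seeded input

    def __add__(self, other):
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.val + other.val, self.dot + other.dot)

    def __mul__(self, other):
        # Product rule: (u * v)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

# y = (x + w) * b, differentiated w.r.t. x in a single forward pass
x = Dual(2.0, 1.0)   # seed: dx/dx = 1
w = Dual(1.0, 0.0)
b = Dual(3.0, 0.0)
y = (x + w) * b
print(y.val, y.dot)  # 9.0 and dy/dx = b = 3.0
```

Note the cost: one pass gives the derivative w.r.t. one seeded input. To also get dy/dw and dy/db you would re-run with the seed moved, which is exactly why forward mode scales with the number of inputs.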

Reverse Mode (Backpropagation)

Computes y first, then goes backward to find \frac{\partial y}{\partial x}, \frac{\partial y}{\partial w}, \dots

  • Mechanism: We compute the output, then propagate the error signal backwards.
  • Best for: Functions with many inputs (billion weights) and few outputs (1 loss value).
  • Cost: Proportional to the number of outputs.
  • Winner: Deep Learning! We typically have 1 Loss value and 100B parameters. Reverse mode gives us all 100B gradients in one pass.

4. Python: Autograd from Scratch

Let’s build a tiny AutoDiff engine (inspired by Andrej Karpathy’s Micrograd) to understand what PyTorch does under the hood.

import math

class Value:
    def __init__(self, data, _children=(), _op='', label=''):
        self.data = data
        self.grad = 0.0
        self._backward = lambda: None
        self._prev = set(_children)
        self._op = _op
        self.label = label

    def __repr__(self):
        return f"Value(data={self.data})"

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')

        def _backward():
            # Local gradients for addition are 1.0
            self.grad += 1.0 * out.grad
            other.grad += 1.0 * out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other), '*')

        def _backward():
            # Local gradients for multiplication are the other value
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def backward(self):
        # Topological sort: each node is appended only after all of its inputs,
        # so reversed(topo) visits every node before the nodes that feed it
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)

        # Go backwards
        self.grad = 1.0
        for node in reversed(topo):
            node._backward()

# Usage
x = Value(2.0, label='x')
w = Value(-3.0, label='w')
b = Value(10.0, label='b')
a = x * w; a.label = 'a'
y = a + b; y.label = 'y'

y.backward()
print(f"y = {y.data}")
print(f"dy/dx = {x.grad}")
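For comparison, the same computation in PyTorch (assuming torch is installed) takes a few lines; requires_grad=True tells autograd to track the tensor in the graph:

```python
import torch

# Same graph as the scratch example: y = x * w + b
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(-3.0, requires_grad=True)
b = torch.tensor(10.0, requires_grad=True)

y = x * w + b   # forward pass builds the graph
y.backward()    # reverse-mode AutoDiff, one pass

print(y.item())       # 4.0
print(x.grad.item())  # dy/dx = w = -3.0
```

The numbers match the scratch engine: the Value class above is a miniature of exactly this machinery.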

5. Interactive Visualizer: Graph Builder

Visualize the Computational Graph for y = (x + w) × b.

  • Inputs: x=2, w=1, b=3.
  • Forward Pass (Blue): Values flow up: 2 + 1 = 3, then 3 \times 3 = 9.
  • Backward Pass (Red): Gradients flow down: \frac{\partial y}{\partial y} = 1 at the output, then the signal splits into \frac{\partial y}{\partial a} = b = 3 and \frac{\partial y}{\partial b} = a = 3.
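The numbers above can be sanity-checked numerically with central finite differences in plain Python (a check on the arithmetic only, not how AutoDiff actually works):

```python
def y_fn(x, w, b):
    return (x + w) * b

x, w, b = 2.0, 1.0, 3.0
y_val = y_fn(x, w, b)
print(y_val)  # forward pass: 9.0

h = 1e-6
# Central differences approximate the partial derivatives
dy_dx = (y_fn(x + h, w, b) - y_fn(x - h, w, b)) / (2 * h)
dy_db = (y_fn(x, w, b + h) - y_fn(x, w, b - h)) / (2 * h)
print(round(dy_dx, 4), round(dy_db, 4))  # both 3.0, matching b and a
```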

6. Summary

  • Computational Graph: Represents math as a Directed Acyclic Graph (DAG) of operations.
  • AutoDiff: Applies the Chain Rule recursively on this graph.
  • Reverse Mode: The secret sauce of Deep Learning that lets us compute gradients for billions of parameters in a single backward pass.