KL Divergence for two probability distributions in PyTorch

Introduction

KL divergence measures how one probability distribution diverges from another. In PyTorch, there are two main ways to compute it: use torch.nn.functional.kl_div when you already have tensors of probabilities or log-probabilities, or use torch.distributions.kl_divergence when you are working with distribution objects such as Normal or Categorical.

Know the Formula and the Direction

For discrete distributions P and Q defined over the same outcomes x, the KL divergence of P from Q, written KL(P || Q), is:

  • KL(P || Q) = sum over x of P(x) * log(P(x) / Q(x))

The order matters. KL(P || Q) is not the same as KL(Q || P). This is one of the most common mistakes in code because the tensors may look symmetric even though the math is not.
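
To make the asymmetry concrete, the sketch below evaluates both directions by hand for two small example distributions. The helper name kl_discrete and the example values are illustrative assumptions, and the helper assumes strictly positive probabilities.

```python
import torch


def kl_discrete(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """KL(A || B) for 1-D probability vectors with no zero entries."""
    return torch.sum(a * (a.log() - b.log()))


p = torch.tensor([0.7, 0.2, 0.1])
q = torch.tensor([0.6, 0.3, 0.1])

kl_pq = kl_discrete(p, q)  # KL(P || Q)
kl_qp = kl_discrete(q, p)  # KL(Q || P)

print(f"KL(P || Q) = {kl_pq.item():.4f}")
print(f"KL(Q || P) = {kl_qp.item():.4f}")
```

With these values the two directions come out at roughly 0.027 and 0.029: close, but not equal.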

Use torch.nn.functional.kl_div Correctly

PyTorch's functional API, F.kl_div(input, target), expects input to be log-probabilities and, by default (log_target=False), target to be plain probabilities.

python
import torch
import torch.nn.functional as F

p = torch.tensor([0.7, 0.2, 0.1], dtype=torch.float32)
q = torch.tensor([0.6, 0.3, 0.1], dtype=torch.float32)

# input must be log-probabilities; target is plain probabilities by default
log_q = q.log()
kl = F.kl_div(log_q, p, reduction="sum")  # KL(P || Q)
print(kl.item())

This computes KL(P || Q) even though log_q appears first: the function evaluates target * (log(target) - input) elementwise, so the target p plays the role of P and the input supplies log Q. That argument order surprises a lot of people, so check the API contract whenever you write a new call.
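
When the target is also available in log space, F.kl_div accepts log_target=True. The sketch below reuses the same example values and checks that both call styles agree. The variable names are illustrative.

```python
import torch
import torch.nn.functional as F

p = torch.tensor([0.7, 0.2, 0.1])
q = torch.tensor([0.6, 0.3, 0.1])

# Default: input is log Q, target is P as plain probabilities.
kl_prob_target = F.kl_div(q.log(), p, reduction="sum")

# log_target=True: both input and target are log-probabilities.
kl_log_target = F.kl_div(q.log(), p.log(), reduction="sum", log_target=True)

print(kl_prob_target.item(), kl_log_target.item())
```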

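If you want the mean KL per sample over a batch, use batchmean. It divides the summed result by the size of the first dimension, so the tensors need a leading batch dimension:

python
# shape (1, 3): one sample, three classes
kl = F.kl_div(log_q.unsqueeze(0), p.unsqueeze(0), reduction="batchmean")

Choose the reduction deliberately because the scales differ: sum adds every elementwise term, batchmean divides that sum by the batch size, and mean divides it by the total number of elements, which does not match the mathematical definition of KL divergence.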
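
As a rough illustration of how far the scales can drift apart, the sketch below applies each reduction to a small batch of two distributions over three classes. The batch values are made up for the example.

```python
import torch
import torch.nn.functional as F

# Two samples, three classes each (every row sums to one).
p = torch.tensor([[0.7, 0.2, 0.1],
                  [0.5, 0.3, 0.2]])
q = torch.tensor([[0.6, 0.3, 0.1],
                  [0.4, 0.4, 0.2]])
log_q = q.log()

kl_sum = F.kl_div(log_q, p, reduction="sum")              # total over all elements
kl_batchmean = F.kl_div(log_q, p, reduction="batchmean")  # sum / batch size (2)
kl_mean = F.kl_div(log_q, p, reduction="mean")            # sum / element count (6)
kl_none = F.kl_div(log_q, p, reduction="none")            # elementwise terms, shape (2, 3)

print(kl_sum.item(), kl_batchmean.item(), kl_mean.item())
print(kl_none.sum(dim=-1))  # per-sample KL values
```

Here batchmean is the sum divided by 2 and mean is the sum divided by 6, so the same inputs give three different loss values. PyTorch may also emit a warning when reduction="mean" is used with this function.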

Compute It Manually for Clarity

When debugging or teaching, a manual calculation is often easier to reason about.

python
import torch

p = torch.tensor([0.7, 0.2, 0.1], dtype=torch.float32)
q = torch.tensor([0.6, 0.3, 0.1], dtype=torch.float32)

# KL(P || Q) = sum over x of P(x) * log(P(x) / Q(x))
kl_manual = torch.sum(p * torch.log(p / q))
print(kl_manual.item())

This makes the direction explicit and helps confirm whether a library call is being used correctly.

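In practical code, handle zeros explicitly. A zero in q where p is positive makes the result infinite, and a zero in p turns 0 * log(0) into nan in floating point even though that term is 0 by convention. Clamp or smooth the probabilities if needed.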
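
One way to handle zeros, sketched below, is to use torch.xlogy, which returns 0 when its first argument is 0, and to clamp q away from zero before the log. The helper name safe_kl and the eps value are assumptions for illustration, and clamping q slightly changes the result when q really contains zeros.

```python
import torch
import torch.nn.functional as F


def safe_kl(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """KL(P || Q) along the last dim, treating 0 * log(0) as 0."""
    q = q.clamp_min(eps)  # avoid log(0) where q is exactly zero
    # xlogy(p, x) computes p * log(x) and returns 0 where p == 0.
    return (torch.xlogy(p, p) - torch.xlogy(p, q)).sum(dim=-1)


p = torch.tensor([0.8, 0.2, 0.0])  # contains an exact zero
q = torch.tensor([0.6, 0.3, 0.1])

naive = torch.sum(p * torch.log(p / q))  # 0 * log(0) produces nan
stable = safe_kl(p, q)
reference = F.kl_div(q.log(), p, reduction="sum")

print(naive.item(), stable.item(), reference.item())
```

The naive expression returns nan here, while the helper and F.kl_div agree on a finite value.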

Use Distribution Objects When You Have Actual Distributions

If you are working with probability distribution classes, use torch.distributions.kl_divergence.

python
import torch
from torch.distributions import Normal, kl_divergence

p = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
q = Normal(loc=torch.tensor(1.0), scale=torch.tensor(1.5))

# Argument order follows the math: kl_divergence(p, q) is KL(P || Q)
kl = kl_divergence(p, q)
print(kl.item())

This is the right choice for VAEs and other probabilistic models where the distributions are parameterized objects rather than plain categorical tensors.
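
As a sanity check, the result for two univariate Normals can be compared with the well-known closed form log(sigma_q / sigma_p) + (sigma_p^2 + (mu_p - mu_q)^2) / (2 * sigma_q^2) - 1/2. The sketch below does that, and also applies the same API to Categorical objects built from probability tensors. The variable names are illustrative.

```python
import math

import torch
from torch.distributions import Categorical, Normal, kl_divergence

# Normal case: compare the library result with the closed-form expression.
mu_p, sigma_p = 0.0, 1.0
mu_q, sigma_q = 1.0, 1.5
kl_normal = kl_divergence(Normal(mu_p, sigma_p), Normal(mu_q, sigma_q))
closed_form = (
    math.log(sigma_q / sigma_p)
    + (sigma_p**2 + (mu_p - mu_q) ** 2) / (2 * sigma_q**2)
    - 0.5
)

# Categorical case: wrap probability tensors in distribution objects.
p = torch.tensor([0.7, 0.2, 0.1])
q = torch.tensor([0.6, 0.3, 0.1])
kl_cat = kl_divergence(Categorical(probs=p), Categorical(probs=q))  # KL(P || Q)
kl_manual = torch.sum(p * torch.log(p / q))

print(kl_normal.item(), closed_form)
print(kl_cat.item(), kl_manual.item())
```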

Watch for Zeroes and Log Space

KL divergence becomes numerically fragile when probabilities approach zero. Practical safeguards include:

  • normalizing tensors so they sum to one
  • clamping values away from zero before logs
  • using log-softmax outputs when appropriate

For model outputs, a common pattern is:

python
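import torch
import torch.nn.functional as F
# logits, target_logits: raw scores of shape (batch, num_classes)
log_probs = torch.log_softmax(logits, dim=-1)
target_probs = torch.softmax(target_logits, dim=-1)
kl = F.kl_div(log_probs, target_probs, reduction="batchmean")  # KL(target || model)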

This is common in distillation and variational objectives because log_softmax computes log-probabilities directly from the logits, which is more stable than taking the log of probabilities that are already very small.
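
A minimal distillation-style sketch of that pattern follows. The random logits and the batch shape are placeholders standing in for real model outputs, chosen only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_classes = 4, 5

# Stand-ins for real model outputs (hypothetical student and teacher logits).
logits = torch.randn(batch_size, num_classes)
target_logits = torch.randn(batch_size, num_classes)

log_probs = torch.log_softmax(logits, dim=-1)        # log Q, the model being trained
target_probs = torch.softmax(target_logits, dim=-1)  # P, the reference distribution

# Mean over the batch of KL(P || Q) for each row.
kl = F.kl_div(log_probs, target_probs, reduction="batchmean")

# Identical logits should give a divergence of (numerically) zero.
kl_same = F.kl_div(log_probs, torch.softmax(logits, dim=-1), reduction="batchmean")

print(kl.item(), kl_same.item())
```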

Normalize Before Comparing Arbitrary Tensors

If your starting tensors are scores or counts rather than proper distributions, normalize them first. KL divergence assumes valid probability distributions, so the elements should be non-negative and sum to one across the comparison dimension. If that assumption is broken, the numeric result may still exist, but it no longer has the interpretation you think it has.
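
For example, raw counts can be turned into distributions by dividing by their sum along the comparison dimension. The sketch below assumes hypothetical count tensors and a small additive smoothing constant to keep every probability above zero. Both are illustrative choices.

```python
import torch
import torch.nn.functional as F


def normalize(counts: torch.Tensor, smoothing: float = 1e-6) -> torch.Tensor:
    """Turn non-negative counts into probabilities along the last dim."""
    counts = counts.to(torch.float32) + smoothing  # additive smoothing keeps entries > 0
    return counts / counts.sum(dim=-1, keepdim=True)


counts_p = torch.tensor([70, 20, 10])
counts_q = torch.tensor([6, 3, 1])

p = normalize(counts_p)
q = normalize(counts_q)

kl = F.kl_div(q.log(), p, reduction="sum")  # KL(P || Q)
print(p.sum().item(), q.sum().item(), kl.item())
```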

Common Pitfalls

  • Reversing P and Q and computing the wrong KL direction.
  • Passing raw probabilities as the first argument to F.kl_div instead of log-probabilities.
  • Ignoring reduction choice and comparing losses with incompatible scales.
  • Taking logs of zero or nearly zero probabilities without stabilization.
  • Using tensor-based KL code when distribution objects would be clearer and safer.
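
The second pitfall is easy to reproduce. In the sketch below, passing raw probabilities as the first argument yields a negative number, which a KL divergence between valid distributions can never be, so a negative result is a useful warning sign. The variable names are illustrative.

```python
import torch
import torch.nn.functional as F

p = torch.tensor([0.7, 0.2, 0.1])
q = torch.tensor([0.6, 0.3, 0.1])

wrong = F.kl_div(q, p, reduction="sum")        # bug: q is not in log space
right = F.kl_div(q.log(), p, reduction="sum")  # KL(P || Q)

print(f"raw probabilities as input: {wrong.item():.4f}")
print(f"log-probabilities as input: {right.item():.4f}")
```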

Summary

  • In PyTorch, KL divergence can be computed with tensor APIs or with distribution objects.
  • F.kl_div expects log-probabilities as input and, by default, probabilities as the target.
  • The direction of KL matters and is easy to reverse by mistake.
  • Manual computation is useful for validation and debugging.
  • Stabilize probabilities and choose reduction carefully to avoid misleading results.
