KL Divergence for two probability distributions in PyTorch
Introduction
KL divergence measures how one probability distribution diverges from another. In PyTorch, there are two main ways to compute it: use torch.nn.functional.kl_div when you already have tensors of probabilities or log-probabilities, or use torch.distributions.kl_divergence when you are working with distribution objects such as Normal or Categorical.
Know the Formula and the Direction
For discrete distributions, the KL divergence from P to Q is:
KL(P || Q) = sum over x of P(x) * log(P(x) / Q(x))
The order matters. KL(P || Q) is not the same as KL(Q || P). This is one of the most common mistakes in code because the tensors may look symmetric even though the math is not.
Use torch.nn.functional.kl_div Correctly
PyTorch's functional API expects the first argument to be log-probabilities and the target to be probabilities by default; pass log_target=True if the target is also in log space.
A call such as F.kl_div(log_q, p, reduction="sum") computes KL(P || Q) even though log_q appears first, because the target p is treated as the reference distribution. That argument order surprises a lot of people, so it is worth re-reading the API contract every time.
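A minimal sketch of that call, assuming two small categorical distributions whose values are chosen only for illustration. It also evaluates the reverse direction to show that the two results differ.

```python
import torch
import torch.nn.functional as F

# Illustrative distributions over three outcomes (values are assumptions).
p = torch.tensor([0.5, 0.3, 0.2])  # target / reference distribution P
q = torch.tensor([0.4, 0.4, 0.2])  # approximating distribution Q

# F.kl_div expects its first argument in log space.
log_q = q.log()

# Sums p * (log p - log q) over all elements, i.e. KL(P || Q).
kl_pq = F.kl_div(log_q, p, reduction="sum")

# Swapping the roles gives KL(Q || P), which is generally a different number.
kl_qp = F.kl_div(p.log(), q, reduction="sum")

print(kl_pq.item(), kl_qp.item())
```

The first argument always plays the role of Q (in log space) and the second the role of P, so the call reads "backwards" relative to the notation KL(P || Q).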
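If you want batch averaging, pass reduction="batchmean", which divides the summed divergence by the batch size.
Choose the reduction deliberately because sum, mean, and batchmean can produce very different scales. In particular, mean divides by the total number of elements rather than the batch size, so it does not return the per-sample KL value.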
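To make the scale difference concrete, the following sketch evaluates the same batch under each reduction mode. The batch of two three-class distributions is an assumed example, not data from any real model.

```python
import torch
import torch.nn.functional as F

# Assumed example: a batch of 2 distributions over 3 classes.
p = torch.tensor([[0.5, 0.3, 0.2],
                  [0.1, 0.6, 0.3]])
q = torch.tensor([[0.4, 0.4, 0.2],
                  [0.2, 0.5, 0.3]])
log_q = q.log()

kl_none = F.kl_div(log_q, p, reduction="none")            # per-element terms, same shape as p
kl_sum = F.kl_div(log_q, p, reduction="sum")              # total over every element
kl_batchmean = F.kl_div(log_q, p, reduction="batchmean")  # sum divided by batch size (2)
kl_mean = F.kl_div(log_q, p, reduction="mean")            # sum divided by element count (6)

print(kl_sum.item(), kl_batchmean.item(), kl_mean.item())
```

Here batchmean is the average KL per sample, while mean is smaller by a factor equal to the number of classes; PyTorch may emit a warning when mean is used with this function.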
Compute It Manually for Clarity
When debugging or teaching, a manual calculation is often easier to reason about.
This makes the direction explicit and helps confirm whether a library call is being used correctly.
In practical code, clamp or smooth zero probabilities if needed, because log(0) is undefined.
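A manual version might look like the sketch below. The helper name manual_kl and the floor eps are assumptions made for illustration; the clamp only protects the logarithm, and terms where P(x) is zero contribute zero.

```python
import torch
import torch.nn.functional as F

def manual_kl(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """KL(P || Q) summed over the last dimension (hypothetical helper)."""
    # Clamp only inside the logs so log(0) never occurs; eps is an assumed floor.
    log_p = p.clamp_min(eps).log()
    log_q = q.clamp_min(eps).log()
    # Direction is explicit: weights come from P, the log-ratio is P over Q.
    return (p * (log_p - log_q)).sum(dim=-1)

p = torch.tensor([0.5, 0.3, 0.2])
q = torch.tensor([0.4, 0.4, 0.2])

kl_manual = manual_kl(p, q)
kl_library = F.kl_div(q.log(), p, reduction="sum")

# The two values should agree up to floating-point error.
print(kl_manual.item(), kl_library.item())
```

Comparing the manual value against the library call is a quick way to confirm that the arguments to F.kl_div are in the intended order.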
Use Distribution Objects When You Have Actual Distributions
If you are working with probability distribution classes, use torch.distributions.kl_divergence.
This is the right choice for VAEs and other probabilistic models where the distributions are parameterized objects rather than plain categorical tensors.
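As a sketch of the distribution-object route, the example below computes one categorical KL and one Gaussian KL against a standard normal, the kind of term that appears in a VAE-style regularizer. All parameter values are assumptions chosen for illustration.

```python
import torch
from torch.distributions import Categorical, Normal, kl_divergence

# Categorical case: kl_divergence(p, q) is KL(P || Q), with P as the first argument.
p = Categorical(probs=torch.tensor([0.5, 0.3, 0.2]))
q = Categorical(probs=torch.tensor([0.4, 0.4, 0.2]))
kl_categorical = kl_divergence(p, q)

# Gaussian case: KL(N(mu, sigma) || N(0, 1)), computed per dimension.
mu = torch.tensor([0.0, 1.0])
sigma = torch.tensor([1.0, 2.0])
posterior = Normal(loc=mu, scale=sigma)
prior = Normal(loc=torch.zeros(2), scale=torch.ones(2))
kl_normal = kl_divergence(posterior, prior)  # shape (2,), one value per dimension

print(kl_categorical.item(), kl_normal.tolist())
```

Unlike F.kl_div, the argument order here matches the notation: the first argument is P and the second is Q. No reduction is applied, so sum or average over dimensions yourself as the objective requires.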
Watch for Zeroes and Log Space
KL divergence becomes numerically fragile when probabilities approach zero. Practical safeguards include:
- normalizing tensors so they sum to one
- clamping values away from zero before logs
- using log-softmax outputs when appropriate
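For model outputs, a common pattern is to apply log_softmax to the logits that play the role of Q and softmax to the logits that play the role of P, then pass both to F.kl_div.
This is common in distillation and variational objectives because it is more stable than manually taking logs of already small probabilities.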
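The log-space pattern for model outputs can be sketched as follows. The names student_logits and teacher_logits are hypothetical stand-ins, and the random tensors are placeholders for real model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical raw model outputs: batch of 4 examples, 10 classes each.
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)

# log_softmax computes log-probabilities directly from logits,
# avoiding a separate softmax followed by log.
log_q = F.log_softmax(student_logits, dim=-1)
p = F.softmax(teacher_logits, dim=-1)

# KL(P || Q), averaged over the batch.
loss = F.kl_div(log_q, p, reduction="batchmean")

# Alternative: keep the target in log space as well.
log_p = F.log_softmax(teacher_logits, dim=-1)
loss_log_target = F.kl_div(log_q, log_p, reduction="batchmean", log_target=True)

print(loss.item(), loss_log_target.item())
```

Both calls compute the same quantity; the log_target=True form avoids ever materializing very small target probabilities.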
Normalize Before Comparing Arbitrary Tensors
If your starting tensors are scores or counts rather than proper distributions, normalize them first. KL divergence assumes valid probability distributions, so the elements should be non-negative and sum to one across the comparison dimension. If that assumption is broken, the numeric result may still exist, but it no longer has the interpretation you think it has.
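As a small illustration, assuming the inputs are raw counts (the numbers below are made up), normalization might look like this before the KL call:

```python
import torch
import torch.nn.functional as F

# Assumed raw counts, not yet valid probability distributions.
counts_p = torch.tensor([50.0, 30.0, 20.0])
counts_q = torch.tensor([8.0, 8.0, 4.0])

# Divide by the total along the comparison dimension so each tensor sums to one.
p = counts_p / counts_p.sum(dim=-1, keepdim=True)
q = counts_q / counts_q.sum(dim=-1, keepdim=True)

kl = F.kl_div(q.log(), p, reduction="sum")  # KL(P || Q) on the normalized tensors
print(kl.item())
```

For unbounded scores such as logits, softmax or log_softmax along the same dimension is the usual normalization instead of dividing by the sum.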
Common Pitfalls
- Reversing P and Q and computing the wrong KL direction.
- Passing raw probabilities as the first argument to F.kl_div instead of log-probabilities.
- Ignoring reduction choice and comparing losses with incompatible scales.
- Taking logs of zero or nearly zero probabilities without stabilization.
- Using tensor-based KL code when distribution objects would be clearer and safer.
Summary
- In PyTorch, KL divergence can be computed with tensor APIs or with distribution objects.
- F.kl_div expects log-probabilities as input and usually probabilities as the target.
- The direction of KL matters and is easy to reverse by mistake.
- Manual computation is useful for validation and debugging.
- Stabilize probabilities and choose reduction carefully to avoid misleading results.

