Training Dynamics

Topics Covered

Vanishing and Exploding Gradients

The Core Problem

Why This Matters in Practice

Gradient Clipping, The Exploding Gradient Fix

Weight Initialization

Why Zero Initialization Fails

The Variance Analysis

Kaiming Initialization, For ReLU Networks

Why Getting Initialization Wrong Is Expensive

Initialization Code and Sanity Checks

Normalization Techniques

The Root Problem: Internal Covariate Shift

Batch Normalization

Layer Normalization

RMSNorm, The Modern Simplification

Pre-LN vs Post-LN

Mixed Precision Training

Why Float32 Is Not the Answer

The Float16 Problem: Numerical Underflow

Loss Scaling: The Key Trick

BFloat16, Why It's Better Than Float16 for Training

The Master Weights Pattern

Learning Rate Schedules

Practice: See It In Action

Deep networks fail to train more often than they succeed, and not because the math is wrong. The math is fine. The problem is gradient flow. Every technique covered in this lesson (initialization, normalization, mixed precision) exists specifically to keep gradients well-behaved. Understanding the failure modes first makes every subsequent solution obvious rather than mysterious.

The Core Problem

During backpropagation, gradients are computed by repeatedly applying the chain rule through each layer. For a network with $L$ layers, the gradient at layer 1 is a product of per-layer Jacobian matrices, one factor for each layer between $h_1$ and the loss:

$$\dfrac{\partial \mathcal{L}}{\partial h_1} = \dfrac{\partial \mathcal{L}}{\partial h_L} \cdot \dfrac{\partial h_L}{\partial h_{L-1}} \cdots \dfrac{\partial h_2}{\partial h_1}$$

If each Jacobian has typical singular value $\sigma < 1$, then after roughly $L$ multiplications the gradient shrinks by a factor of about $\sigma^L$, exponentially toward zero. The sigmoid's derivative is at most $0.25$, so with sigmoid activations (and unit-scale weights) $\sigma \approx 0.25$, and a 10-layer network attenuates the gradient by about $0.25^{10} \approx 10^{-6}$.
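To see the $0.25^L$ bound numerically, here is a minimal scalar sketch. It assumes a toy chain of sigmoid layers $h_{k+1} = \mathrm{sigmoid}(w\,h_k)$ with one shared scalar weight `w` (a simplification for illustration, not a real architecture). Backpropagation multiplies the local derivatives $w\,s(1-s)$, and because $s(1-s) \le 0.25$, the product with $w = 1$ can never exceed $0.25^L$.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def forward(num_layers: int, x0: float, w: float = 1.0) -> float:
    """Run a scalar through num_layers sigmoid layers h_{k+1} = sigmoid(w * h_k)."""
    h = x0
    for _ in range(num_layers):
        h = sigmoid(w * h)
    return h


def chain_gradient(num_layers: int, x0: float = 0.0, w: float = 1.0) -> float:
    """Return dh_L/dh_0 as the chain-rule product of the per-layer derivatives."""
    h = x0
    grad = 1.0
    for _ in range(num_layers):
        s = sigmoid(w * h)
        # Local derivative dh_{k+1}/dh_k = w * sigmoid'(w * h_k) = w * s * (1 - s)
        grad *= w * s * (1.0 - s)
        h = s
    return grad


for L in (1, 5, 10, 20):
    print(f"L={L:2d}  dh_L/dh_0 = {chain_gradient(L):.3e}  bound 0.25^L = {0.25 ** L:.3e}")
```

Only the first layer (input exactly 0) hits the 0.25 maximum; every later layer sees a positive input and contributes slightly less, so the 10-layer product lands just under $0.25^{10}$.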

If each Jacobian has typical singular value $\sigma > 1$, the gradient grows by about $\sigma^L$, exponentially toward infinity. The gradient explodes, weights receive enormous updates, and training diverges.
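The same exponential behavior shows up with full matrices. The sketch below assumes a deep linear network (no activations) whose weights are drawn i.i.d. from $\mathcal{N}(0, \text{gain}^2/\text{width})$, where `gain` is a hypothetical knob standing in for $\sigma$. It pushes a unit-norm gradient backward through 20 layers; each vector-Jacobian product $W^\top v$ scales the norm by roughly `gain`, so the final norm behaves like $\text{gain}^{20}$.

```python
import math
import random


def backprop_norm(depth: int, width: int, gain: float, seed: int = 0) -> float:
    """Push a unit-norm gradient backward through `depth` random linear layers
    and return its final L2 norm.

    Weights are i.i.d. N(0, gain**2 / width), so each layer scales the
    gradient norm by roughly `gain` on average.
    """
    rng = random.Random(seed)
    std = gain / math.sqrt(width)
    v = [1.0 / math.sqrt(width)] * width  # unit-norm upstream gradient
    for _ in range(depth):
        W = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        # For h_out = W @ h_in, the backward pass computes grad_in = W^T @ grad_out
        v = [sum(W[i][j] * v[i] for i in range(width)) for j in range(width)]
    return math.sqrt(sum(x * x for x in v))


for gain in (0.5, 1.0, 1.5):
    print(f"gain={gain}: gradient norm after 20 layers = {backprop_norm(20, 64, gain):.3e}")
```

With `gain = 0.5` the norm collapses toward $0.5^{20} \approx 10^{-6}$; with `gain = 1.5` it grows toward $1.5^{20} \approx 3 \times 10^{3}$. Only a gain near 1 keeps the norm roughly stable, which is exactly what careful initialization aims for.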

Figure: Gradient magnitude by layer depth. Log-scale |gradient| (roughly $10^{-9}$ to $1$) versus layer index (1 = closest to input, up to about 20), comparing a plain deep net with a net using layer norm + residuals.
Without normalization, gradients vanish in the early layers of a deep network, so those layers learn nothing. Layer norm and residual connections keep gradient flow even across depth.

Why This Matters in Practice

Vanishing gradients mean the first layers of the network essentially don't train. The parameters that encode low-level features (in a language model, things like basic syntax and token patterns) receive almost no learning signal. The network can memorize patterns in its later layers but cannot build abstractions from scratch.
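A back-of-the-envelope calculation makes "essentially don't train" concrete. Assume, for illustration only, plain SGD with learning rate $\eta = 10^{-3}$ and a first-layer gradient attenuated to about $g \approx 10^{-6}$ (the sigmoid estimate above):

$$\Delta w = \eta \, g \approx 10^{-3} \times 10^{-6} = 10^{-9}, \qquad 10^{5}\ \text{steps} \times 10^{-9} = 10^{-4}.$$

Even if all $10^5$ updates pushed a weight in the same direction, it would move by at most about $10^{-4}$, tiny next to an initial weight of, say, $0.05$.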

Exploding gradients manifest differently: the loss goes to NaN after a few steps, or you observe loss values jumping by orders of magnitude between batches. The training run is simply broken.
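Both symptoms are easy to catch automatically. The helper below is a hypothetical sketch (the name `check_loss` and the 10x jump threshold are assumptions, not a standard API): it flags a non-finite loss immediately and flags any step whose loss grows by at least an order of magnitude over the previous step.

```python
import math
from typing import Optional


def check_loss(prev_loss: Optional[float], loss: float, jump_factor: float = 10.0) -> str:
    """Classify one training step's loss as 'ok', 'nonfinite', or 'jump'.

    Hypothetical helper: jump_factor=10.0 encodes "an order of magnitude".
    """
    if not math.isfinite(loss):
        return "nonfinite"  # NaN or +/-inf: the run is already broken
    if prev_loss is not None and prev_loss > 0 and loss >= jump_factor * prev_loss:
        return "jump"  # loss grew by at least jump_factor in a single step
    return "ok"


# Illustrative loss values, not from a real training run
history = [2.31, 2.18, 2.05, 41.7, float("nan")]
prev = None
for step, loss in enumerate(history):
    print(step, loss, check_loss(prev, loss))
    prev = loss
```

In a real loop you would stop (or roll back to a checkpoint) on `"nonfinite"` and log the gradient norm whenever a `"jump"` appears.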

Gradient Clipping, The Exploding Gradient Fix

For exploding gradients, the standard solution is gradient clipping:

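```python
import torch
# Call after loss.backward() and before optimizer.step(); `model` is your nn.Module.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```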

This rescales the whole gradient vector if its global norm (computed across all parameters together) exceeds max_norm. It's a blunt instrument: it prevents runaway updates without addressing the underlying cause, but it works reliably. Clipping at a max norm of 1.0 is a common default in transformer training recipes.

The intuition: if the gradient points in a sensible direction but is too large, scale it down so its norm equals max_norm (1.0 here). Direction preserved, magnitude controlled.
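The rescaling rule is simple enough to write out by hand. This is a plain-Python sketch of clipping by global norm, using lists of floats as stand-ins for parameter gradient tensors; `clip_by_global_norm` is a hypothetical name, not PyTorch's source. Like `clip_grad_norm_`, it returns the total norm measured before clipping, which is the number worth logging.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    The norm is taken over every entry of every gradient, as if they were
    concatenated into one long vector. Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(x * x for g in grads for x in g))
    # eps guards against division by zero; the coefficient is capped at 1,
    # so gradients that are already small enough are left untouched.
    coef = min(1.0, max_norm / (total_norm + eps))
    for g in grads:
        for i in range(len(g)):
            g[i] *= coef
    return total_norm


grads = [[30.0, 0.0], [0.0, 40.0]]  # two "parameters"; global norm = 50
pre = clip_by_global_norm(grads, max_norm=1.0)
print(pre, grads)  # every entry scaled by the same factor, about 1/50
```

Because every entry is multiplied by the same coefficient, the 30:40 ratio between the two gradients survives, and the cap at 1 means a tiny gradient passes through unchanged rather than being amplified.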

For vanishing gradients, clipping does nothing: it only ever shrinks gradients, so it can't amplify a near-zero one. The real solutions are initialization and normalization.

Common Pitfall

If your training loss goes NaN in the first few steps, check the gradient norm before clipping. A gradient norm of $10^6$ on step 1 points to your initialization, since the learning rate has not touched the weights yet; if norms start reasonable and then blow up over the first few updates, your learning rate is too high. If the loss goes NaN after 10,000 steps, check for a degenerate input batch (all zeros, NaN features) or a numerical issue in your loss function ($\log(0)$, division by zero).
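For the late-run NaN case, two cheap guards are to validate each batch before the forward pass and to clamp probabilities before taking a log. The sketch below is hypothetical: the function names and the `1e-12` floor are assumptions, and a batch is modeled as plain Python lists of feature rows.

```python
import math
from typing import List


def batch_problems(batch: List[List[float]]) -> List[str]:
    """Return human-readable problems found in a batch of feature rows."""
    problems = []
    for i, row in enumerate(batch):
        if any(not math.isfinite(x) for x in row):
            problems.append(f"row {i}: NaN or inf feature")
        elif all(x == 0.0 for x in row):
            problems.append(f"row {i}: all zeros")
    if batch and len(problems) == len(batch):
        problems.append("entire batch is degenerate")
    return problems


def safe_log(p: float, floor: float = 1e-12) -> float:
    """log(p) with p clamped away from 0.

    In pure Python math.log(0.0) raises ValueError; tensor libraries typically
    return -inf, which then turns into NaN downstream. Clamping avoids both.
    """
    return math.log(max(p, floor))


print(batch_problems([[0.1, 0.2], [0.0, 0.0], [float("nan"), 1.0]]))
print(safe_log(0.0))  # about -27.6 instead of an error or -inf
```

Skipping (and logging) flagged batches is usually cheaper than debugging a NaN that surfaced thousands of steps after its cause.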