Keras
TensorBoard
gradient vanishing
gradient explosion
neural network debugging

How to monitor vanishing and exploding gradients in Keras with TensorBoard?


Introduction

TensorBoard can show you whether training is producing gradients that are consistently near zero or wildly large, but Keras does not log gradients for you. The built-in keras.callbacks.TensorBoard callback records losses, metrics, and (with histogram_freq) weight histograms, not gradient values. The practical solution is to compute gradients inside a custom training step and write gradient summaries, usually histograms or norms, to TensorBoard.

Watch Gradient Norms, Not Just Loss

Vanishing and exploding gradients often show up before the loss makes the problem obvious. Gradient norms that stay tiny across many layers, typically worst in the layers closest to the input, mean those layers receive almost no learning signal. Extremely large norms suggest unstable updates that can drive the loss to inf or NaN.

A useful strategy is to log the norm of each gradient tensor during training.
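As a standalone illustration, here is a sketch that runs eagerly outside fit(), assuming TensorFlow 2.x and a hypothetical two-layer model trained on random data. It computes one L2 norm per gradient tensor with tf.norm and a single norm over all of them with tf.linalg.global_norm. The global norm equals the square root of the sum of the squared per-tensor norms.

```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical toy model and random batch, only for illustration.
keras.utils.set_random_seed(0)
model = keras.Sequential([
    keras.Input(shape=(32,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1),
])
x = tf.random.normal((64, 32))
y = tf.random.normal((64, 1))
loss_fn = keras.losses.MeanSquaredError()

with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x, training=True))
grads = tape.gradient(loss, model.trainable_variables)

# One L2 norm per gradient tensor (the kernel and bias of each Dense layer).
# Keras 3 variables expose a unique .path; tf.keras 2 names are already unique.
per_tensor_norms = {
    getattr(v, "path", v.name): float(tf.norm(g))
    for v, g in zip(model.trainable_variables, grads)
}
# A single L2 norm over all gradient tensors combined.
global_norm = float(tf.linalg.global_norm(grads))

for name, norm in per_tensor_norms.items():
    print(f"{name}: {norm:.4f}")
print(f"global norm: {global_norm:.4f}")
```

Per-tensor norms tell you which layer is affected. The global norm is a single number that is cheaper to track, and it is the quantity that global-norm gradient clipping acts on.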

Override train_step and Write Summaries

Here is a minimal example, written against the tf.keras (Keras 2) API in TensorFlow 2.x:

```python
import tensorflow as tf
from tensorflow import keras

# Assumes TensorFlow 2.x with the tf.keras (Keras 2) API; in Keras 3,
# compiled_loss and compiled_metrics are deprecated.
class GradientLoggingModel(keras.Model):
    def __init__(self, *args, log_dir="logs/gradients", **kwargs):
        super().__init__(*args, **kwargs)
        self.summary_writer = tf.summary.create_file_writer(log_dir)

    def train_step(self, data):
        x, y = data  # Assumes (inputs, targets) batches without sample weights.

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Include regularization losses so the logged gradients are the
            # ones the optimizer actually applies.
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)

        # fit() compiles train_step into a tf.function, so a Python counter
        # (self.step_counter += 1) would only change while tracing and every
        # point would land on the same step. The optimizer's own counter is a
        # tf.Variable that advances once per batch.
        step = self.optimizer.iterations
        with self.summary_writer.as_default():
            for variable, grad in zip(self.trainable_variables, gradients):
                if grad is not None:
                    tf.summary.scalar(
                        f"grad_norm/{variable.name}",
                        tf.norm(grad),
                        step=step,
                    )

        return {m.name: m.result() for m in self.metrics}
```

This logs one scalar norm per trainable variable at each step.
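The introduction also mentions histogram summaries. One possible extension, sketched under the same TensorFlow 2.x assumption, is a helper that adds a tf.summary.histogram of each gradient plus one global-norm scalar. The name write_gradient_summaries is hypothetical and is not a Keras API. Inside train_step you could call it as write_gradient_summaries(self.summary_writer, [v.name for v in self.trainable_variables], gradients, self.optimizer.iterations). Below it is exercised eagerly on dummy tensors so it runs on its own.

```python
import os
import tempfile

import tensorflow as tf


def write_gradient_summaries(writer, names, gradients, step):
    """Hypothetical helper: per-variable norms and histograms plus a global norm."""
    # Skip variables that received no gradient (e.g. unused branches).
    present = [(name, grad) for name, grad in zip(names, gradients) if grad is not None]
    with writer.as_default():
        for name, grad in present:
            tf.summary.scalar(f"grad_norm/{name}", tf.norm(grad), step=step)
            # Distribution of raw gradient entries: shows whether most sit near zero.
            tf.summary.histogram(f"grad_hist/{name}", grad, step=step)
        tf.summary.scalar(
            "grad_norm/global",
            tf.linalg.global_norm([grad for _, grad in present]),
            step=step,
        )


# Dummy gradients standing in for the output of tape.gradient(...).
log_dir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(log_dir)
names = ["dense/kernel", "dense/bias", "unused/kernel"]
gradients = [tf.random.normal((32, 1)), tf.random.normal((1,)), None]
for step in range(3):
    write_gradient_summaries(writer, names, gradients, step=step)
writer.flush()
print(sorted(os.listdir(log_dir)))
```

Histograms appear under TensorBoard's Histograms and Distributions tabs. They cost more to write than scalars, so for large models you may prefer to log them less often than every step.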

Train the model as usual and point TensorBoard at the log directory:

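```python
import numpy as np

# Synthetic data so the example runs end to end; replace with your own.
x_train = np.random.rand(1024, 32).astype("float32")
y_train = np.random.rand(1024, 1).astype("float32")
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = GradientLoggingModel(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
model.fit(x_train, y_train, epochs=5, batch_size=32)
```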

Then run:

```bash
tensorboard --logdir logs
```

Look for norms that are persistently tiny across layers or that spike to extreme values during training.
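"Persistently" can be turned into a rule instead of a judgment by eye. The following sketch uses a hypothetical GradientNormMonitor that keeps the last few global norms and flags a problem only when the median of that window crosses a threshold, so a single outlier batch is ignored. The thresholds (1e-6 and 1e2) and the window size are illustrative assumptions; sensible values depend on the model, the loss scale, and the optimizer. The monitor works on plain Python floats, so use it in an eager debugging loop or on exported scalars, not inside the graph-compiled train_step.

```python
import statistics
from collections import deque


class GradientNormMonitor:
    """Hypothetical helper that flags sustained vanishing or exploding norms.

    The default thresholds and window size are illustrative assumptions,
    not standard values; tune them for your own model.
    """

    def __init__(self, low=1e-6, high=1e2, window=5):
        self.low = low
        self.high = high
        self.history = deque(maxlen=window)

    def update(self, norm):
        """Record one global gradient norm and return a status string."""
        self.history.append(float(norm))
        if len(self.history) < self.history.maxlen:
            return "warming_up"  # Too few steps to call it a pattern.
        # The median ignores a single outlier batch inside the window.
        median = statistics.median(self.history)
        if median < self.low:
            return "vanishing"
        if median > self.high:
            return "exploding"
        return "ok"


# One spike among normal values is not flagged...
monitor = GradientNormMonitor()
print([monitor.update(n) for n in [0.5, 0.4, 5e3, 0.6, 0.5]])
# ...but a sustained run of huge norms is.
monitor = GradientNormMonitor()
print([monitor.update(n) for n in [2e3, 5e3, 1e4, 3e4, 8e4]])
```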

Interpret the Signals Carefully

Gradient problems are often caused by a combination of factors:

  • activation choices
  • initialization
  • learning rate
  • model depth
  • normalization strategy
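A small experiment shows how activation choice, initialization, and depth interact. This sketch assumes TensorFlow 2.x; the helper name and the layer sizes are arbitrary. It builds a 10-layer MLP twice, once with sigmoid and the default Glorot initialization and once with ReLU and He initialization, and compares the gradient norm of the first Dense kernel with that of the last. With sigmoid, the ratio is typically many orders of magnitude below 1, because each sigmoid layer scales the backpropagated signal by a derivative of at most 0.25. The exact numbers vary with seed and library version.

```python
import tensorflow as tf
from tensorflow import keras


def kernel_grad_norms(activation, initializer, depth=10, width=64, seed=0):
    """Gradient norm of every Dense kernel, ordered from input to output."""
    keras.utils.set_random_seed(seed)
    hidden = [
        keras.layers.Dense(width, activation=activation, kernel_initializer=initializer)
        for _ in range(depth)
    ]
    model = keras.Sequential([keras.Input(shape=(width,))] + hidden + [keras.layers.Dense(1)])

    x = tf.random.normal((128, width))
    y = tf.random.normal((128, 1))
    kernels = [layer.kernel for layer in model.layers]
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    return [float(tf.norm(g)) for g in tape.gradient(loss, kernels)]


for activation, initializer in [("sigmoid", "glorot_uniform"), ("relu", "he_normal")]:
    norms = kernel_grad_norms(activation, initializer)
    ratio = norms[0] / norms[-1]
    print(f"{activation} + {initializer}: first/last kernel gradient norm = {ratio:.2e}")
```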

TensorBoard helps you observe the symptom. It does not remove the need to reason about the training setup.

If norms explode, consider learning-rate reduction, gradient clipping, or architecture changes. If they vanish, examine activation functions, initialization, and whether the network depth is appropriate.
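For exploding norms, the first two remedies are optimizer settings in Keras. The values below (a learning rate of 1e-4 and a global clipping norm of 1.0) are illustrative assumptions, not recommendations. With global_clipnorm, all gradients are rescaled together whenever their combined norm exceeds the threshold. With clipnorm, each gradient tensor is clipped separately by its own norm. With clipvalue, each element is clipped. Only one of the three can be set at a time.

```python
import numpy as np
from tensorflow import keras

# Lower learning rate than Adam's default (1e-3), plus clipping of the
# combined gradient norm to 1.0 before each update. Values are illustrative.
optimizer = keras.optimizers.Adam(learning_rate=1e-4, global_clipnorm=1.0)
# Per-tensor alternative: keras.optimizers.Adam(learning_rate=1e-4, clipnorm=1.0)

model = keras.Sequential([keras.Input(shape=(32,)), keras.layers.Dense(1)])
model.compile(optimizer=optimizer, loss="mse")

# Synthetic data, only so the sketch runs.
x = np.random.rand(256, 32).astype("float32")
y = np.random.rand(256, 1).astype("float32")
history = model.fit(x, y, epochs=1, batch_size=32, verbose=0)
print(history.history["loss"])
```

If you combine this with the GradientLoggingModel above, note that its train_step logs the gradients before the optimizer clips them, so the logged norms still show the underlying spikes.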

Gradient Monitoring Is a Diagnostic Layer

TensorBoard does not fix gradient problems by itself. Its value is that it lets you see whether training instability is structural and persistent instead of guessing from loss curves alone.

Use the Signal to Drive Training Changes

If TensorBoard shows exploding gradients, likely next steps include lowering the learning rate or adding gradient clipping. If it shows persistent vanishing, inspect initialization, activation choice, normalization, and network depth instead of only tweaking logging.

Common Pitfalls

  • Assuming the default Keras TensorBoard callback logs gradients automatically. It does not: its histogram_freq option records weight histograms, not gradients.
  • Looking only at loss curves and missing the fact that gradients are near zero everywhere.
  • Dumping raw gradient tensors instead of summarizing them as norms or histograms, which quickly becomes unreadable for models with many parameters.
  • Treating one noisy batch as proof of exploding gradients instead of looking for sustained patterns.
  • Trying to fix the problem only in TensorBoard instead of changing the model or optimizer configuration.

Summary

  • To monitor vanishing or exploding gradients in Keras, compute gradients explicitly in a custom training step.
  • Log gradient norms or histograms to TensorBoard.
  • Use TensorBoard trends to spot gradients that are persistently tiny or abnormally large.
  • Interpret those signals together with model architecture and optimizer settings.
  • TensorBoard gives visibility; the actual fix usually comes from training or model changes.

