How can I print the intermediate variables in the loss function in TensorFlow and Keras?

Introduction

When you need to debug a custom loss in TensorFlow or Keras, ordinary Python print is often not enough. The loss may run inside a traced TensorFlow graph, which means Python-side printing can happen at the wrong time or not reflect the values used during the actual training step.

The reliable tool is tf.print, because it executes as part of the TensorFlow graph. For deeper debugging, you can also run the model eagerly so the loss behaves more like regular Python code during inspection.

Use tf.print Inside the Loss

The standard pattern is to compute your intermediate tensors and then print them with tf.print before returning the final loss value.

python
import tensorflow as tf


def custom_loss(y_true, y_pred):
    diff = y_true - y_pred
    squared = tf.square(diff)
    mean_loss = tf.reduce_mean(squared)

    tf.print("diff:", diff, summarize=8)
    tf.print("squared:", squared, summarize=8)
    tf.print("mean_loss:", mean_loss)

    return mean_loss


model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(1,)),
        tf.keras.layers.Dense(1),
    ]
)

model.compile(optimizer="adam", loss=custom_loss)

x = tf.constant([[1.0], [2.0], [3.0]])
y = tf.constant([[2.0], [4.0], [6.0]])

model.fit(x, y, epochs=1, verbose=0)

The summarize argument keeps long tensors from flooding the console. This matters quickly if your loss touches large batches or image tensors.
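
To see how summarize shapes the output without running a training loop, one option is tf.strings.format, which takes the same summarize argument and returns the formatted text as a string tensor. The sketch below is only illustrative: the variable names are made up for this example. It also passes output_stream to tf.print, since tf.print writes to standard error by default.

```python
import sys

import tensorflow as tf

values = tf.range(10)

# summarize=3 keeps only the first and last 3 entries of each dimension.
short_text = tf.strings.format("values: {}", values, summarize=3).numpy().decode()

# summarize=-1 asks for every element.
full_text = tf.strings.format("values: {}", values, summarize=-1).numpy().decode()

print(short_text)
print(full_text)

# tf.print accepts the same argument; output_stream sends it to stdout
# instead of the default stderr.
tf.print("values:", values, summarize=3, output_stream=sys.stdout)
```

For large batches, a small summarize value is usually the safer default, with -1 reserved for tensors you know are tiny.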

Why tf.print Works Better Than print

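TensorFlow often traces the training step into a graph. Plain Python print runs only while the function is being traced, so it typically fires once and shows a symbolic tensor rather than values, instead of firing on every execution of the loss. tf.print, by contrast, becomes an operation in the graph and prints the actual runtime tensor values each time the step runs.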

That distinction is the main reason many developers think their loss is "not printing" even though their code looks correct. The graph is not behaving like ordinary sequential Python.
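
The trace-time versus run-time difference can be reproduced in a small standalone function, independent of Keras. This is a minimal sketch: traced_loss and trace_log are hypothetical names used only to count how often the Python body actually runs.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper: one entry per run of the Python body


@tf.function
def traced_loss(y_true, y_pred):
    diff = y_true - y_pred
    trace_log.append("python body ran")
    print("print sees:", diff)        # trace time only, symbolic tensor
    tf.print("tf.print sees:", diff)  # every call, concrete values
    return tf.reduce_mean(tf.square(diff))


y_true = tf.constant([1.0, 2.0])
y_pred = tf.constant([0.5, 2.5])

# Three calls with identical shapes and dtypes reuse a single trace.
results = [float(traced_loss(y_true, y_pred)) for _ in range(3)]
print("calls:", len(results), "python body runs:", len(trace_log))
```

The tf.print line appears three times while the plain print line appears once, which is the behavior that makes a loss look like it is "not printing".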

Debugging with Eager Execution

If you want easier step-by-step debugging, you can tell Keras to run training eagerly:

python
model.compile(
    optimizer="adam",
    loss=custom_loss,
    run_eagerly=True,
)

With run_eagerly=True, the loss executes in eager mode and regular Python debugging becomes more intuitive. This is very helpful during development, but it is usually slower, so it is best used as a temporary debugging aid rather than as a final training configuration.
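
Before switching a whole model to eager training, it can be enough to call the loss directly on a few hand-written tensors, since TensorFlow 2 runs such calls eagerly by default. The sketch below assumes a stand-in function, simple_loss, and made-up sample values; it is not the model from the earlier example.

```python
import tensorflow as tf


def simple_loss(y_true, y_pred):
    diff = y_true - y_pred
    squared = tf.square(diff)
    # In eager mode tensors hold concrete values, so .numpy() and print work.
    print("diff:", diff.numpy().ravel())
    print("squared:", squared.numpy().ravel())
    return tf.reduce_mean(squared)


y_true = tf.constant([[2.0], [4.0], [6.0]])
y_pred = tf.constant([[1.5], [4.0], [7.0]])

loss_value = simple_loss(y_true, y_pred)
print("eager:", tf.executing_eagerly(), "loss:", float(loss_value))
```

Because nothing is traced here, a regular debugger breakpoint inside simple_loss also stops on real values.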

Printing in a Custom Loss Class

If your loss is implemented as a subclass, the same idea applies:

python
class DebugLoss(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        abs_error = tf.abs(y_true - y_pred)
        tf.print("abs_error:", abs_error, summarize=8)
        return tf.reduce_mean(abs_error)

This is useful when the loss carries configuration or when you want a reusable object instead of a function.
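
As a sketch of what "carrying configuration" might look like, the class below adds a label and an on/off switch for printing. ConfigurableDebugLoss, tag, and debug are hypothetical names chosen for this illustration, not part of Keras.

```python
import tensorflow as tf


class ConfigurableDebugLoss(tf.keras.losses.Loss):
    """Mean absolute error with optional printing (illustrative only)."""

    def __init__(self, tag="abs_error", debug=True, name="configurable_debug_loss"):
        super().__init__(name=name)
        self.tag = tag      # label printed in front of the tensor
        self.debug = debug  # Python-level switch for the tf.print call

    def call(self, y_true, y_pred):
        abs_error = tf.abs(y_true - y_pred)
        if self.debug:
            tf.print(self.tag + ":", abs_error, summarize=8)
        # Return per-sample values; the base class applies its default reduction.
        return tf.reduce_mean(abs_error, axis=-1)


loss_fn = ConfigurableDebugLoss(tag="mae_terms")
value = loss_fn(tf.constant([[1.0], [2.0]]), tf.constant([[1.5], [1.0]]))
print("loss:", float(value))
```

Passing debug=False keeps the same loss object usable once the investigation is over.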

Keep the Output Useful

Loss debugging becomes unreadable very quickly if you print every tensor in full on every batch. A better strategy is to print only the values that help answer a specific question:

  • Do the intermediate tensors have the shape you expect?
  • Are there NaNs or infinities?
  • Is the scaling wildly different from what you intended?
  • Is the loss dominated by one term in a composite objective?
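
One way to answer these questions without dumping whole tensors is a small helper that prints only shape, range, and the number of non-finite entries. This is a sketch under assumptions: debug_summary, composite_loss, and the 0.01 weight are hypothetical and exist only for illustration.

```python
import tensorflow as tf


def debug_summary(name, tensor):
    """Print compact statistics for a tensor and return the non-finite count."""
    finite = tf.math.is_finite(tensor)
    num_bad = tf.reduce_sum(tf.cast(tf.logical_not(finite), tf.int32))
    # Zero out non-finite entries so min and max stay meaningful.
    safe = tf.where(finite, tensor, tf.zeros_like(tensor))
    tf.print(
        name,
        "shape:", tf.shape(tensor),
        "min:", tf.reduce_min(safe),
        "max:", tf.reduce_max(safe),
        "non-finite:", num_bad,
    )
    return num_bad


def composite_loss(y_true, y_pred):
    squared = tf.square(y_true - y_pred)
    mse_term = tf.reduce_mean(squared)
    l1_term = 0.01 * tf.reduce_mean(tf.abs(y_pred))  # assumed penalty weight
    debug_summary("squared error", squared)
    # Printing the terms side by side shows which one dominates.
    tf.print("mse_term:", mse_term, "l1_term:", l1_term)
    return mse_term + l1_term


total = composite_loss(tf.constant([[2.0], [4.0]]), tf.constant([[1.0], [5.0]]))
bad_count = debug_summary("probe", tf.constant([1.0, float("nan"), float("inf")]))
```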

When debugging NaNs, simple checks are often enough:

python
mean_loss = tf.debugging.check_numerics(mean_loss, message="Loss contains NaN or Inf")

That turns a vague training failure into an immediate, localized error.
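
To show what "localized" can mean in practice, the sketch below places the check on an intermediate tensor rather than on the final scalar. checked_log_loss and the sample inputs are hypothetical; the failure is triggered on purpose with a zero prediction.

```python
import tensorflow as tf


def checked_log_loss(y_true, y_pred):
    log_pred = tf.math.log(y_pred)  # log(0) yields -inf
    # Use the returned tensor so the check stays attached to the computation.
    log_pred = tf.debugging.check_numerics(log_pred, message="log(y_pred)")
    return -tf.reduce_mean(y_true * log_pred)


good = checked_log_loss(tf.constant([1.0, 0.0]), tf.constant([0.9, 0.1]))
print("good loss:", float(good))

caught = None
try:
    checked_log_loss(tf.constant([1.0, 0.0]), tf.constant([0.9, 0.0]))
except tf.errors.InvalidArgumentError as err:
    caught = err
    print("check failed at the intermediate tensor")
```

Checking the intermediate value points at the log call itself, whereas a check on the final loss would only report that something upstream went wrong.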

Common Pitfalls

  • Using plain print inside a traced loss and expecting runtime tensor values.
  • Printing huge tensors every step and making training unusably slow.
  • Leaving run_eagerly=True enabled after debugging, which slows down normal training.
  • Debugging only the final scalar loss and ignoring the intermediate terms that actually explain the bug.
  • Assuming a missing print means the loss was not called. In graph mode, it may simply mean Python printing is happening at trace time instead of execution time.

Summary

  • Use tf.print to inspect intermediate tensors inside TensorFlow and Keras losses.
  • tf.print runs as part of graph execution, while ordinary print often does not.
  • Enable run_eagerly=True when you need easier interactive debugging.
  • Print only the most informative tensors to avoid overwhelming the console.
  • Add numeric checks when you suspect NaNs, infinities, or exploding terms.
