How can I print the intermediate variables in the loss function in TensorFlow and Keras?
Introduction
When you need to debug a custom loss in TensorFlow or Keras, ordinary Python print is often not enough. The loss may run inside a traced TensorFlow graph, which means Python-side printing can happen at the wrong time or not reflect the values used during the actual training step.
The reliable tool is tf.print, because it executes as part of the TensorFlow graph. For deeper debugging, you can also run the model eagerly so the loss behaves more like regular Python code during inspection.
Use tf.print Inside the Loss
The standard pattern is to compute your intermediate tensors and then print them with tf.print before returning the final loss value.
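Here is a minimal sketch of that pattern, assuming TensorFlow 2.x with its bundled Keras. The loss name debug_mse_loss, the one-layer model, and the random data are hypothetical placeholders. The loss is plain mean squared error, used only so that there are intermediate tensors to inspect.

```python
import tensorflow as tf


def debug_mse_loss(y_true, y_pred):
    """Hypothetical example loss: mean squared error with tf.print diagnostics."""
    # Match dtypes so the subtraction below is well defined.
    y_true = tf.cast(y_true, y_pred.dtype)

    # Intermediate tensors we want to inspect.
    error = y_pred - y_true
    squared_error = tf.square(error)
    loss = tf.reduce_mean(squared_error)

    # tf.print is a graph op, so it reports the values of the current batch.
    # summarize=5 shows at most the first and last 5 elements per dimension.
    tf.print("error:", error, summarize=5)
    tf.print("loss:", loss)
    return loss


# Hypothetical toy model and data, used only to drive the loss during fit().
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss=debug_mse_loss)

x = tf.random.normal((32, 4))
y = tf.random.normal((32, 1))
model.fit(x, y, epochs=1, batch_size=16, verbose=0)
```

By default tf.print writes to standard error. Its output_stream argument can redirect the output, for example to sys.stdout.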
The summarize argument of tf.print limits how many leading and trailing elements of each dimension are shown. This keeps long tensors from flooding the console, which matters as soon as your loss touches large batches or image tensors.
Why tf.print Works Better Than print
TensorFlow often traces the training step into a graph. Plain Python print runs during tracing, not necessarily during every execution of the loss. tf.print, by contrast, becomes part of the graph and prints the actual runtime tensor values.
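The difference is easy to see outside of Keras with a small tf.function. In this sketch, traced_step is a hypothetical name. The Python print runs only while the function is being traced and shows a symbolic tensor. The tf.print call runs on every invocation and shows the real value.

```python
import tensorflow as tf


@tf.function
def traced_step(x):
    # Python print: executes during tracing only, and prints a symbolic
    # tensor such as Tensor("x:0", shape=(), dtype=float32), not a value.
    print("python print:", x)
    # tf.print: compiled into the graph, executes on every call.
    tf.print("tf.print:", x)
    return x * 2.0


# First call traces the function, so both lines are printed.
traced_step(tf.constant(1.0))
# Same dtype and shape, so the traced graph is reused: only tf.print output appears.
traced_step(tf.constant(2.0))
```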
That distinction is the main reason many developers think their loss is "not printing" even though their code looks correct. The graph is not behaving like ordinary sequential Python.
Debugging with Eager Execution
If you want easier step-by-step debugging, you can tell Keras to run training eagerly by passing run_eagerly=True to model.compile.
With run_eagerly=True, the loss executes in eager mode and regular Python debugging becomes more intuitive. This is very helpful during development, but it is usually slower, so it is best used as a temporary debugging aid rather than as a final training configuration.
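A sketch of this setup is below, assuming TensorFlow 2.x Keras. The name eager_debug_loss and the toy model and data are hypothetical. Because the loss runs eagerly, it can convert tensors to Python floats, use ordinary print, or stop in a debugger with breakpoint(). Without run_eagerly=True, the float() conversion in this loss would typically fail during tracing, because symbolic tensors have no concrete value.

```python
import tensorflow as tf


def eager_debug_loss(y_true, y_pred):
    """Hypothetical loss that relies on eager execution for Python-side inspection."""
    y_true = tf.cast(y_true, y_pred.dtype)
    error = y_pred - y_true
    # Only valid in eager mode: float() needs a concrete value.
    # You could also call breakpoint() here and inspect `error` interactively.
    print("max abs error:", float(tf.reduce_max(tf.abs(error))))
    return tf.reduce_mean(tf.square(error))


model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
# run_eagerly=True is a debugging aid; remove it for normal, faster training.
model.compile(optimizer="adam", loss=eager_debug_loss, run_eagerly=True)

x = tf.random.normal((32, 4))
y = tf.random.normal((32, 1))
history = model.fit(x, y, epochs=1, batch_size=16, verbose=0)
```

For code wrapped in tf.function outside of model.fit, tf.config.run_functions_eagerly(True) has a similar effect globally.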
Printing in a Custom Loss Class
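If your loss is implemented as a subclass of tf.keras.losses.Loss, the same idea applies: call tf.print on the intermediate tensors inside the call method before returning.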
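The following is a minimal sketch, assuming TensorFlow 2.x Keras. DebugWeightedMSE and its weight and summarize constructor arguments are hypothetical. The call method returns one loss value per sample and leaves the batch reduction to the base class, which by default averages over the batch.

```python
import tensorflow as tf


class DebugWeightedMSE(tf.keras.losses.Loss):
    """Hypothetical weighted MSE that prints its intermediate terms."""

    def __init__(self, weight=1.0, summarize=3, name="debug_weighted_mse"):
        super().__init__(name=name)
        self.weight = weight          # configuration carried by the loss object
        self.summarize = summarize    # passed through to tf.print

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        squared_error = tf.square(y_pred - y_true)
        tf.print("squared_error:", squared_error, summarize=self.summarize)

        # One value per sample; the base class applies the batch reduction.
        per_sample = self.weight * tf.reduce_mean(squared_error, axis=-1)
        tf.print("per-sample loss:", per_sample, summarize=self.summarize)
        return per_sample


# Hypothetical usage: the instance is passed to compile like any built-in loss.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss=DebugWeightedMSE(weight=0.5))
```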
This is useful when the loss carries configuration or when you want a reusable object instead of a function.
Keep the Output Useful
Loss debugging becomes unreadable very quickly if you print every tensor in full on every batch. A better strategy is to print only the values that help answer a specific question:
- are the intermediate tensors the shape you expect,
- are there NaNs or infinities,
- is the scaling wildly different from what you intended,
- is the loss dominated by one term in a composite objective.
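One way to map those questions onto a single loss is sketched below. The composite_loss function, its L1 penalty on the predictions, and the l1_weight default are hypothetical stand-ins for your own terms. Each tf.print call answers one of the questions above instead of dumping whole tensors.

```python
import tensorflow as tf


def composite_loss(y_true, y_pred, l1_weight=0.1):
    """Hypothetical two-term loss: MSE plus a small L1 penalty on predictions."""
    y_true = tf.cast(y_true, y_pred.dtype)
    mse_term = tf.reduce_mean(tf.square(y_pred - y_true))
    l1_term = l1_weight * tf.reduce_mean(tf.abs(y_pred))

    # Shapes: do y_true and y_pred look the way you expect?
    tf.print("shapes:", tf.shape(y_true), tf.shape(y_pred))
    # NaNs or infinities: is every prediction finite?
    tf.print("all finite:", tf.reduce_all(tf.math.is_finite(y_pred)))
    # Scaling: print the range rather than every element.
    tf.print("y_pred min/max:", tf.reduce_min(y_pred), tf.reduce_max(y_pred))
    # Composite objective: which term dominates?
    tf.print("mse_term:", mse_term, "l1_term:", l1_term)

    return mse_term + l1_term
```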
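When debugging NaNs, simple checks are often enough. For example, tf.debugging.check_numerics raises an error as soon as a tensor contains a NaN or an infinity, and its message tells you which tensor failed.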
That turns a vague training failure into an immediate, localized error.
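Here is a sketch of such a check. It assumes a hypothetical loss, checked_log_loss, that takes the log of the predictions and therefore yields -inf or NaN for values at or below zero. tf.debugging.check_numerics returns its input unchanged when every value is finite. Otherwise it raises tf.errors.InvalidArgumentError carrying the message you supply, so the failure names the exact tensor.

```python
import tensorflow as tf


def checked_log_loss(y_true, y_pred):
    """Hypothetical loss with a log term that can produce -inf or NaN."""
    y_true = tf.cast(y_true, y_pred.dtype)
    log_pred = tf.math.log(y_pred)  # -inf at 0, NaN for negative inputs

    # Passes the tensor through if it is finite; otherwise raises
    # InvalidArgumentError whose message includes this label.
    log_pred = tf.debugging.check_numerics(log_pred, "log_pred in checked_log_loss")

    return -tf.reduce_mean(y_true * log_pred)
```

To check every op instead of one chosen tensor, you can call tf.debugging.enable_check_numerics() once at startup, at a noticeable cost in speed.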
Common Pitfalls
- Using plain print inside a traced loss and expecting runtime tensor values.
- Printing huge tensors every step and making training unusably slow.
- Forgetting that run_eagerly=True is for debugging and leaving it enabled in normal training without meaning to.
- Debugging only the final scalar loss and ignoring the intermediate terms that actually explain the bug.
- Assuming a missing print means the loss was not called. In graph mode, it may simply mean Python printing is happening at trace time instead of execution time.
Summary
- Use tf.print to inspect intermediate tensors inside TensorFlow and Keras losses.
- tf.print works during graph execution, while ordinary print often does not.
- Enable run_eagerly=True when you need easier interactive debugging.
- Print only the most informative tensors to avoid overwhelming the console.
- Add numeric checks when you suspect NaNs, infinities, or exploding terms.

