Applying callbacks in a custom training loop in TensorFlow 2.0

Introduction

TensorFlow callbacks work automatically with model.fit, but they do not magically wire themselves into a hand-written training loop. In a custom loop, you have to call the callback hooks yourself and provide the metrics that the callbacks expect. Once that is clear, you can still use familiar tools such as early stopping, checkpointing, and progress logging without giving up low-level control.

What Changes in a Custom Loop

With model.fit, TensorFlow owns the training lifecycle. It knows when an epoch starts, when a batch ends, and what logs should be passed to each callback. In a custom loop, you own that lifecycle, so you must drive the callbacks explicitly.

The usual setup looks like this:

  1. Create the model, optimizer, and loss.
  2. Build a CallbackList.
  3. Attach the model with set_model.
  4. Call lifecycle methods such as on_train_begin, on_epoch_begin, and on_epoch_end.
  5. Pass metric values in the logs dictionary.

If you skip those steps, the callbacks may exist but they will not do anything useful.
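To make the point concrete, here is a minimal sketch that uses a hypothetical HookRecorder callback. It does nothing except record which lifecycle methods it receives. The sketch drives a CallbackList by hand for one epoch of two batches with no real training. The hook names are the standard Keras callback methods, while HookRecorder and its event strings are illustrative only.

```python
import tensorflow as tf


class HookRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that only records which lifecycle hooks fire."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        self.events.append(f"batch_begin:{batch}")

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end:{batch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


# A callback that is created but never driven receives nothing.
idle = HookRecorder()

# The same kind of callback, driven by hand through a CallbackList
# for one epoch of two batches (no actual training happens here).
recorder = HookRecorder()
callback_list = tf.keras.callbacks.CallbackList([recorder])
callback_list.on_train_begin()
for epoch in range(1):
    callback_list.on_epoch_begin(epoch)
    for step in range(2):
        callback_list.on_train_batch_begin(step)
        callback_list.on_train_batch_end(step, {"loss": 0.0})
    callback_list.on_epoch_end(epoch, {"loss": 0.0})
callback_list.on_train_end()

print(idle.events)  # []
print(recorder.events)
```

The idle instance never records anything. The driven instance sees the train, epoch, and batch hooks nested in the order a training run fires them.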

A Minimal Working Example

The example below trains a tiny regression model with a custom loop and a normal Keras callback list.

```python
import numpy as np
import tensorflow as tf

EPOCHS = 5

# Toy regression data: y = 3x + 2
x = np.random.rand(256, 1).astype("float32")
y = 3.0 * x + 2.0

dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1)
])

loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor="loss", patience=2),
    tf.keras.callbacks.CSVLogger("training.log")
]

# add_history / add_progbar append a History callback and a progress bar
callback_list = tf.keras.callbacks.CallbackList(
    callbacks,
    add_history=True,
    add_progbar=True
)

# Give every callback a handle to the model and the run parameters
callback_list.set_model(model)
callback_list.set_params({
    "epochs": EPOCHS,
    "steps": int(dataset.cardinality()),  # batches per epoch, without iterating the dataset
    "verbose": 1,
    "metrics": ["loss"]
})

model.stop_training = False  # reset in case the model was trained before
callback_list.on_train_begin()

for epoch in range(EPOCHS):
    callback_list.on_epoch_begin(epoch)
    epoch_loss = tf.keras.metrics.Mean()  # running mean of this epoch's batch losses

    for step, (batch_x, batch_y) in enumerate(dataset):
        callback_list.on_train_batch_begin(step)

        with tf.GradientTape() as tape:
            predictions = model(batch_x, training=True)
            loss = loss_fn(batch_y, predictions)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        epoch_loss.update_state(loss)
        callback_list.on_train_batch_end(step, {"loss": float(epoch_loss.result())})

    callback_list.on_epoch_end(epoch, {"loss": float(epoch_loss.result())})

    # EarlyStopping requests termination by setting this flag
    if model.stop_training:
        break

callback_list.on_train_end()
```

This is the core pattern. The callbacks remain standard Keras callbacks, but the loop decides when each event fires.
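Because the loop owns the lifecycle, it can help to package the pattern once. The sketch below is a hypothetical helper, fit_with_callbacks, not a TensorFlow API. It assumes that only the gradient step needs graph compilation, so it wraps that step in tf.function and calls every callback hook from eager Python. Python code inside a tf.function runs only while the function is being traced, so hooks placed there would not fire on every step. The helper attaches its own History callback so the caller gets the per-epoch logs back.

```python
import numpy as np
import tensorflow as tf


def fit_with_callbacks(model, loss_fn, optimizer, dataset, callbacks, epochs):
    """Hypothetical helper: a reusable custom loop that drives Keras callbacks."""

    @tf.function  # compile only the numeric step; callbacks stay in eager Python
    def train_step(batch_x, batch_y):
        with tf.GradientTape() as tape:
            predictions = model(batch_x, training=True)
            loss = loss_fn(batch_y, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    history = tf.keras.callbacks.History()
    callback_list = tf.keras.callbacks.CallbackList(
        list(callbacks) + [history],
        model=model,
        epochs=epochs,
        steps=int(dataset.cardinality()),
        verbose=0,
    )

    model.stop_training = False
    callback_list.on_train_begin()
    for epoch in range(epochs):
        callback_list.on_epoch_begin(epoch)
        epoch_loss = tf.keras.metrics.Mean()
        for step, (batch_x, batch_y) in enumerate(dataset):
            callback_list.on_train_batch_begin(step)
            epoch_loss.update_state(train_step(batch_x, batch_y))
            callback_list.on_train_batch_end(step, {"loss": float(epoch_loss.result())})
        callback_list.on_epoch_end(epoch, {"loss": float(epoch_loss.result())})
        if model.stop_training:  # set by EarlyStopping and similar callbacks
            break
    callback_list.on_train_end()
    return history


x = np.random.rand(64, 1).astype("float32")
y = 3.0 * x + 2.0
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)
model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])

history = fit_with_callbacks(
    model,
    tf.keras.losses.MeanSquaredError(),
    tf.keras.optimizers.Adam(learning_rate=0.05),
    dataset,
    callbacks=[tf.keras.callbacks.EarlyStopping(monitor="loss", mode="min", patience=2)],
    epochs=3,
)
print(history.history["loss"])
```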

Feed the Right Metrics to the Callbacks

Callbacks such as EarlyStopping and ReduceLROnPlateau look up their monitored value in logs by name. If you monitor val_loss but never compute or pass val_loss, the callback has nothing to compare, so it logs a warning and takes no action.

For example, if you want validation-based early stopping, run a validation pass over held-out data at the end of each epoch and include the result:

```python
# Inside the epoch loop, after the training batches and before on_epoch_end.
# val_dataset is assumed to be a separate held-out tf.data.Dataset,
# not the training data.
val_loss_metric = tf.keras.metrics.Mean()

for batch_x, batch_y in val_dataset:
    predictions = model(batch_x, training=False)
    val_loss_metric.update_state(loss_fn(batch_y, predictions))

epoch_logs = {
    "loss": float(epoch_loss.result()),
    "val_loss": float(val_loss_metric.result())
}
callback_list.on_epoch_end(epoch, epoch_logs)
```

The metric names must match the callback configuration exactly.
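One way to see that a name mismatch quietly disables a callback is to replay a fixed sequence of epoch logs through EarlyStopping. In this sketch, the hypothetical stops_early helper does exactly that and reports whether model.stop_training was set. The loss values are made up for illustration, and mode="min" is passed explicitly so the result does not depend on how a given Keras version infers the direction.

```python
import tensorflow as tf


def stops_early(monitor, epoch_logs_sequence):
    """Hypothetical helper: replay epoch logs through EarlyStopping, report if it stopped."""
    model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])
    model.stop_training = False
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor=monitor, mode="min", patience=1)
    callback_list = tf.keras.callbacks.CallbackList([early_stopping], model=model)

    callback_list.on_train_begin()
    for epoch, logs in enumerate(epoch_logs_sequence):
        callback_list.on_epoch_end(epoch, logs)
        if model.stop_training:
            return True
    return False


# Made-up logs in which validation loss gets worse after the first epoch.
with_val = [
    {"loss": 1.0, "val_loss": 1.0},
    {"loss": 0.9, "val_loss": 2.0},
    {"loss": 0.8, "val_loss": 3.0},
]
# The same run, but the loop forgot to put val_loss into logs.
without_val = [{"loss": logs["loss"]} for logs in with_val]

print(stops_early("val_loss", with_val))     # True
print(stops_early("val_loss", without_val))  # False: nothing to monitor
```

When the key is missing, EarlyStopping emits a warning and skips the epoch rather than raising an error, which is why this mistake is easy to overlook.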

Custom Callbacks Still Work Well

You can also write your own callback class and use it in the same loop. That is often useful for printing domain-specific diagnostics or saving artifacts that no built-in callback produces.

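```python
import tensorflow as tf

class WeightNormLogger(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # self.model is available because CallbackList.set_model() attached it
        kernel = self.model.layers[0].kernel  # assumes the first layer is Dense
        norm = tf.norm(kernel).numpy()
        print(f"epoch={epoch} weight_norm={norm:.4f}")
```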

Add it to the callbacks list and it will receive the same lifecycle events as built-in callbacks.
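As a sketch of that usage, here is a hypothetical WeightNormRecorder, a variant of the logger above that stores the norms in a list instead of printing them so the result can be inspected. It shares a CallbackList with a built-in EarlyStopping callback and relies on set_model (done here through the model= argument) to make self.model available. The epoch losses are made-up values.

```python
import tensorflow as tf


class WeightNormRecorder(tf.keras.callbacks.Callback):
    """Hypothetical variant of WeightNormLogger that stores norms instead of printing."""

    def __init__(self):
        super().__init__()
        self.norms = []

    def on_epoch_end(self, epoch, logs=None):
        # self.model is populated by CallbackList.set_model (here via model=...)
        kernel = self.model.layers[0].kernel  # assumes layers[0] is a built Dense layer
        self.norms.append(float(tf.norm(kernel)))


model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(4),
    tf.keras.layers.Dense(1),
])
recorder = WeightNormRecorder()

# The custom callback sits in the same list as a built-in one.
callback_list = tf.keras.callbacks.CallbackList(
    [recorder, tf.keras.callbacks.EarlyStopping(monitor="loss", mode="min", patience=2)],
    model=model,
)
callback_list.on_train_begin()
for epoch in range(2):
    callback_list.on_epoch_end(epoch, {"loss": 1.0 / (epoch + 1)})
callback_list.on_train_end()

print(recorder.norms)
```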

Know When a Custom Loop Is Worth It

Custom loops are best when you need manual gradient accumulation, multiple optimizers, reinforcement learning updates, or non-standard batch logic. If the training flow is ordinary supervised learning, model.fit is still simpler and gives you callback support for free. The goal is control where you actually need it, not replacing higher-level APIs by habit.
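For comparison, here is a sketch of the same kind of regression trained with model.fit. The callbacks need no manual wiring because fit builds the callback list, fires every hook, and fills in loss and val_loss itself. The validation_split value and epoch count are arbitrary choices for illustration.

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(256, 1).astype("float32")
y = 3.0 * x + 2.0

model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss="mse")

# fit() creates the CallbackList, sets model and params, and fires every hook.
history = model.fit(
    x, y,
    batch_size=32,
    epochs=5,
    validation_split=0.2,  # fit computes val_loss, so EarlyStopping can monitor it
    callbacks=[tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)],
    verbose=0,
)
print(sorted(history.history))
```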

Common Pitfalls

  • Creating callbacks but never wrapping them in a CallbackList and never calling lifecycle methods.
  • Monitoring val_loss or another metric name that is never included in the logs dictionary.
  • Forgetting to call set_model and set_params, which leaves some callbacks without required context.
  • Ignoring model.stop_training after EarlyStopping requests termination.
  • Rewriting callback behavior manually when standard Keras callbacks would already handle it cleanly.
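One more piece of context is easy to miss. Callbacks that adjust the learning rate, such as ReduceLROnPlateau, read and write self.model.optimizer, which model.fit users get from compile(). In a custom loop that never calls compile(), the model typically has no optimizer attached, so such a callback fails. One workaround, sketched here under that assumption, is to compile the model with only the optimizer while the loop keeps calling optimizer.apply_gradients itself. The loss values fed in are made up to force one reduction.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Attach the optimizer to the model only; the custom loop still calls
# optimizer.apply_gradients itself and never calls model.fit.
model.compile(optimizer=optimizer)

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="loss", mode="min", factor=0.5, patience=1
)
callback_list = tf.keras.callbacks.CallbackList([reduce_lr], model=model)

callback_list.on_train_begin()
# Made-up epoch losses: epoch 1 is worse than epoch 0, so patience=1 is exhausted.
for epoch, epoch_loss in enumerate([1.0, 2.0]):
    callback_list.on_epoch_end(epoch, {"loss": epoch_loss})
callback_list.on_train_end()

new_lr = float(np.asarray(model.optimizer.learning_rate))
print(new_lr)  # roughly 0.005
```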

Summary

  • Callbacks work in custom TensorFlow loops, but you must drive them explicitly.
  • Use CallbackList plus lifecycle hooks such as on_train_begin and on_epoch_end.
  • Pass metric values through logs using names that match the callback configuration.
  • Built-in and custom callbacks can both be reused in this pattern.
  • Choose a custom loop only when you need training behavior that model.fit cannot express clearly.
