Get Keras model input from inside a custom callback

Introduction

In Keras, callbacks can access the model, metrics, and training lifecycle hooks, but they do not automatically receive the current batch input tensors. That surprises many users because the callback clearly runs during training, yet the batch data itself is not part of the default callback API.

If you need model input inside a callback, the practical answer is to pass the relevant data into the callback yourself or to expose it through a custom train_step. Which option is best depends on whether you need fixed reference data or the live batch currently being trained.

What Callbacks Get by Default

A callback hook such as on_epoch_end or on_train_batch_end receives:

  • the current epoch or batch index
  • a logs dictionary with metrics
  • access to self.model

What it does not receive is the current x and y batch as arguments.

That means code like "read the model input directly inside the callback" only works if at least one of the following is true:

  • you already stored that data elsewhere
  • you pass it into the callback constructor
  • your model explicitly exposes it during training
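To see for yourself what the hooks are handed, a small recording callback can help. The following is a minimal sketch, assuming a toy model and random data; HookArgsRecorder is a hypothetical name used only for illustration. It stores the index and the keys of the logs dictionary on every call, so you can inspect them after training.

```python
import numpy as np
import tensorflow as tf


class HookArgsRecorder(tf.keras.callbacks.Callback):
    """Hypothetical helper: records what each hook is actually given."""

    def __init__(self):
        super().__init__()
        self.batch_calls = []  # (batch index, sorted keys of logs)
        self.epoch_calls = []  # (epoch index, sorted keys of logs)

    def on_train_batch_end(self, batch, logs=None):
        # Only an index and a logs dict arrive here; there is no x or y argument.
        self.batch_calls.append((batch, sorted((logs or {}).keys())))

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_calls.append((epoch, sorted((logs or {}).keys())))


# Toy data and model, assumed purely for the demonstration.
x = np.random.rand(32, 4).astype("float32")
y = np.random.randint(0, 2, size=(32, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

recorder = HookArgsRecorder()
model.fit(x, y, epochs=1, batch_size=8, callbacks=[recorder], verbose=0)

print(recorder.batch_calls)
print(recorder.epoch_calls)
```

The recorded entries should contain batch or epoch indices and metric names such as loss, which is a quick way to confirm that the input batch is not among the hook arguments.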

Pass Fixed Data into the Callback

If you only need to inspect predictions for a known sample or validation slice, the simplest pattern is to inject that data when the callback is created.

python
import numpy as np
import tensorflow as tf


class InspectInputsCallback(tf.keras.callbacks.Callback):
    def __init__(self, sample_x):
        super().__init__()
        # Keep a small, fixed reference sample on the callback itself.
        self.sample_x = sample_x

    def on_epoch_end(self, epoch, logs=None):
        # self.model is set by Keras before training starts.
        preds = self.model.predict(self.sample_x, verbose=0)
        print(f"epoch={epoch} sample_shape={self.sample_x.shape}")
        print(preds[:1])


# Random toy data: 32 samples with 4 features and binary labels.
x = np.random.rand(32, 4).astype("float32")
y = np.random.randint(0, 2, size=(32, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

model.compile(optimizer="adam", loss="binary_crossentropy")
# Pass only the first three rows as the reference sample.
model.fit(x, y, epochs=2, callbacks=[InspectInputsCallback(x[:3])], verbose=0)

This does not expose the current live batch, but it solves many real monitoring tasks cleanly.

Access the Current Batch with a Custom Model

If you truly need the exact input batch flowing through training, the callback API alone is not enough. A common pattern is to subclass tf.keras.Model, override train_step, and store the current batch on the model before returning metrics.

python
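import tensorflow as tf


class DebugModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data
        # Expose the live batch so a callback can reach it through self.model.
        self.current_batch_x = x

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # compiled_loss and compiled_metrics are the older Keras API; newer
            # Keras releases deprecate them in favor of self.compute_loss(...).
            loss = self.compiled_loss(y, y_pred)

        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}


class BatchPeekCallback(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        batch_x = getattr(self.model, "current_batch_x", None)
        if batch_x is not None:
            # The shape is always available; values need eager execution.
            print(f"batch={batch} shape={batch_x.shape}")


# Usage sketch: build the model functionally, then compile with
# run_eagerly=True so current_batch_x holds concrete values instead of a
# symbolic graph tensor.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs)
model = DebugModel(inputs, outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", run_eagerly=True)
# model.fit(x, y, callbacks=[BatchPeekCallback()])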

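This gives the callback a way to inspect the active training batch indirectly through the model instance. Note that fit normally compiles train_step into a graph, so the stored attribute is a symbolic tensor: its shape is available, but reading the actual values typically requires compiling the model with run_eagerly=True.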
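If the default loss and metric handling is fine and you only want to capture the batch, a lighter variant is to store the input and then delegate to the built-in train_step. The sketch below assumes toy data and uses hypothetical names (BatchCapturingModel, BatchStatsCallback); it compiles with run_eagerly=True so the callback can read real values, at the cost of slower training.

```python
import numpy as np
import tensorflow as tf


class BatchCapturingModel(tf.keras.Model):
    """Hypothetical variant: remember the batch, keep the stock training logic."""

    def train_step(self, data):
        # data is the (x, y) tuple that fit() builds for each batch.
        self.current_batch_x = data[0]
        return super().train_step(data)


class BatchStatsCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback: collects simple statistics of each live batch."""

    def __init__(self):
        super().__init__()
        self.batch_shapes = []
        self.batch_means = []

    def on_train_batch_end(self, batch, logs=None):
        batch_x = getattr(self.model, "current_batch_x", None)
        if batch_x is None:
            return
        # Converting to NumPy only works because the model runs eagerly here.
        values = np.asarray(batch_x)
        self.batch_shapes.append(values.shape)
        self.batch_means.append(float(values.mean()))


# Toy data, assumed purely for the demonstration.
x = np.random.rand(32, 4).astype("float32")
y = np.random.randint(0, 2, size=(32, 1)).astype("float32")

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs)
model = BatchCapturingModel(inputs, outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", run_eagerly=True)

stats = BatchStatsCallback()
model.fit(x, y, epochs=1, batch_size=8, callbacks=[stats], verbose=0)

print(stats.batch_shapes)
```

Delegating to super().train_step keeps the example independent of how a particular Keras version computes the loss, which is why it may be easier to maintain than a fully hand-written training step.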

Choose the Simpler Approach First

Most use cases do not need the exact in-flight batch. They need one of:

  • a reference input to visualize predictions
  • access to validation samples
  • monitoring of model outputs over time

For those cases, explicitly passing data into the callback is simpler and less fragile than customizing the training loop.

Use train_step only when you really need batch-level internals.
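As an illustration of the simpler route, the sketch below tracks how predictions on a fixed reference input change from epoch to epoch. It assumes random toy data, and PredictionHistory is a hypothetical name; in practice the reference input could just as well be a slice of your validation set.

```python
import numpy as np
import tensorflow as tf


class PredictionHistory(tf.keras.callbacks.Callback):
    """Hypothetical callback: stores predictions on a fixed input per epoch."""

    def __init__(self, reference_x):
        super().__init__()
        self.reference_x = reference_x
        self.history = []  # one array of predictions per epoch

    def on_epoch_end(self, epoch, logs=None):
        preds = self.model.predict(self.reference_x, verbose=0)
        self.history.append(np.array(preds))


# Toy data, assumed purely for the demonstration.
x = np.random.rand(32, 4).astype("float32")
y = np.random.randint(0, 2, size=(32, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

tracker = PredictionHistory(x[:5])
model.fit(x, y, epochs=3, callbacks=[tracker], verbose=0)

# Stack into an array indexed as (epoch, sample, output).
trajectory = np.stack(tracker.history)
print(trajectory.shape)
```

Because the callback holds only five rows, the memory and runtime overhead stays small, and no change to the training loop is needed.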

Common Pitfalls

  • Expecting logs to contain the raw input tensors. It usually contains only scalar metrics.
  • Trying to inspect the current batch in a standard callback without modifying the training path.
  • Passing huge datasets into the callback constructor when only a tiny reference sample is needed.
  • Using model.predict inside very frequent batch hooks and slowing training dramatically.
  • Storing large batch tensors on the model without understanding the memory cost.
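One way to soften the performance and memory pitfalls is to inspect only every few batches and to keep the reference sample tiny. The following is a sketch under those assumptions; ThrottledInspector and its every_n_batches parameter are hypothetical names. It calls the model directly on a small tensor instead of using model.predict, which is generally the lighter option for very small inputs.

```python
import numpy as np
import tensorflow as tf


class ThrottledInspector(tf.keras.callbacks.Callback):
    """Hypothetical callback: inspects a tiny sample every N batches only."""

    def __init__(self, sample_x, every_n_batches=50):
        super().__init__()
        # Convert once so each inspection reuses the same small tensor.
        self.sample_x = tf.convert_to_tensor(sample_x)
        self.every_n_batches = every_n_batches
        self.snapshots = []  # (batch index, predictions as a NumPy array)

    def on_train_batch_end(self, batch, logs=None):
        # Skip most batches to keep the overhead low.
        if batch % self.every_n_batches != 0:
            return
        preds = self.model(self.sample_x, training=False)
        self.snapshots.append((batch, np.asarray(preds)))


# Toy data, assumed purely for the demonstration.
x = np.random.rand(64, 4).astype("float32")
y = np.random.randint(0, 2, size=(64, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

inspector = ThrottledInspector(x[:2], every_n_batches=4)
model.fit(x, y, epochs=2, batch_size=8, callbacks=[inspector], verbose=0)

print([batch for batch, _ in inspector.snapshots])
```

The batch index restarts at zero each epoch, so the throttle fires at the same positions in every epoch; a running counter on the callback would be needed for a global schedule.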

Summary

  • Standard Keras callbacks do not receive the current input batch automatically.
  • For fixed sample inspection, pass the needed input data into the callback yourself.
  • For live batch access, expose the batch through a custom train_step.
  • Use the simple callback-constructor pattern unless batch-level internals are truly required.
  • Be careful about performance and memory when inspecting inputs during training.
