Keras model.fit with the tf.data API and validation_data

Introduction

model.fit works well with tf.data.Dataset, including for validation, but the dataset pipeline has to match what Keras expects. Most issues come from batching, repeating, and step-count configuration rather than from Keras itself.

Basic Pattern with Training and Validation Datasets

When you pass a tf.data.Dataset into model.fit, each element (one batch) should usually have one of these structures:

  • (features, labels)
  • (features, labels, sample_weights)

Here is a minimal working example:

python
import tensorflow as tf

x_train = tf.random.normal([1000, 8])
y_train = tf.random.uniform([1000], maxval=2, dtype=tf.int32)

x_val = tf.random.normal([200, 8])
y_val = tf.random.uniform([200], maxval=2, dtype=tf.int32)

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(32)

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(8,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.fit(train_ds, validation_data=val_ds, epochs=3)

That is the normal setup. Keras iterates over the training dataset during each epoch, then evaluates the model on the separate validation dataset at the end of that epoch.
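The three-element form listed earlier can be illustrated with a small sketch. The weighting rule here (class 1 counts twice as much as class 0) and the names `w` and `weighted_ds` are assumptions made up for the example, not part of the original setup.

```python
import tensorflow as tf

x = tf.random.normal([64, 8])
y = tf.random.uniform([64], maxval=2, dtype=tf.int32)

# Hypothetical weighting rule: class 1 examples weigh 2.0, class 0 examples 1.0.
w = tf.where(y == 1, 2.0, 1.0)

# Each element is (features, labels, sample_weights).
weighted_ds = tf.data.Dataset.from_tensor_slices((x, y, w)).batch(16)

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(8,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Keras reads the third tuple component as per-example loss weights.
history = model.fit(weighted_ds, epochs=1, verbose=0)
```

Because the weights travel inside the dataset, there is no separate sample_weight argument to pass to fit in this form.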

Add the Right Pipeline Stages

For training, common tf.data stages are:

  • shuffle
  • batch
  • prefetch

python
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

Validation data is usually simpler. You normally do not shuffle validation input because the goal is stable measurement, not training randomness.

python
val_ds = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

Keep preprocessing identical between training and validation unless the difference is intentional and well understood.
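One way to keep the two pipelines aligned is to build both from a single function and toggle only the shuffle step. The sketch below assumes in-memory tensors; `make_dataset` and its `training` flag are hypothetical names introduced for illustration.

```python
import tensorflow as tf

def make_dataset(x, y, training, batch_size=32):
    """Build a batched, prefetched dataset; shuffle only when training."""
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if training:
        # A buffer as large as the data gives a full shuffle of the examples.
        ds = ds.shuffle(buffer_size=x.shape[0])
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

x_train = tf.random.normal([1000, 8])
y_train = tf.random.uniform([1000], maxval=2, dtype=tf.int32)
x_val = tf.random.normal([200, 8])
y_val = tf.random.uniform([200], maxval=2, dtype=tf.int32)

train_ds = make_dataset(x_train, y_train, training=True)
val_ds = make_dataset(x_val, y_val, training=False)
```

Any shared step added later (casting, normalization) then applies to both datasets by construction.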

When repeat() Changes the Rules

If you call .repeat() on the training dataset, it becomes effectively infinite. In that case, Keras cannot infer when an epoch should end, so you must supply steps_per_epoch.

python
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1000)
    .batch(32)
    .repeat()
)

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=3,
    steps_per_epoch=30,
)

The same idea applies to validation. If validation_data repeats forever, specify validation_steps.

python
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(32).repeat()

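Step counts have to match the data. With an infinite dataset and no step count, the first epoch never finishes. With a finite dataset and a step count larger than the number of available batches, the dataset runs out partway through and training is cut short.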
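A sketch of the full call with both datasets repeating is shown here. It assumes the example counts are known up front (1000 training and 200 validation examples, as in the earlier snippets) and derives the step counts so that each epoch covers roughly one pass over the data.

```python
import math
import tensorflow as tf

batch_size = 32
n_train, n_val = 1000, 200

x_train = tf.random.normal([n_train, 8])
y_train = tf.random.uniform([n_train], maxval=2, dtype=tf.int32)
x_val = tf.random.normal([n_val, 8])
y_val = tf.random.uniform([n_val], maxval=2, dtype=tf.int32)

train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(n_train)
    .batch(batch_size)
    .repeat()
)
val_ds = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .batch(batch_size)
    .repeat()
)

# A repeated dataset reports infinite cardinality, so Keras cannot infer epoch length.
is_infinite = int(train_ds.cardinality()) == tf.data.INFINITE_CARDINALITY

# Number of batches in one pass, counting the final partial batch.
steps_per_epoch = math.ceil(n_train / batch_size)
validation_steps = math.ceil(n_val / batch_size)

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(8,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=2,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    verbose=0,
)
```

Rounding up with math.ceil includes the last partial batch; rounding down would instead skip it in each pass.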

Validation Dataset Structure Must Match the Model

Keras does not treat validation data as a special format. It expects the same logical structure as training data.

For example, this works for multi-input models only if the dataset yields the correct feature structure:

python
features = {
    "age": tf.random.normal([100, 1]),
    "score": tf.random.normal([100, 1]),
}
labels = tf.random.uniform([100], maxval=2, dtype=tf.int32)

ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(16)

If the model expects named inputs, the validation dataset must provide the same names and shapes.
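To make the name matching concrete, here is a minimal sketch of a two-input functional model whose input names line up with the dictionary keys. The architecture and the `make_ds` helper are assumptions chosen for brevity; only the "age" and "score" keys come from the snippet above.

```python
import tensorflow as tf

def make_ds(n):
    """Hypothetical helper: random dict features plus integer labels."""
    features = {
        "age": tf.random.normal([n, 1]),
        "score": tf.random.normal([n, 1]),
    }
    labels = tf.random.uniform([n], maxval=2, dtype=tf.int32)
    return tf.data.Dataset.from_tensor_slices((features, labels)).batch(16)

train_ds = make_ds(100)
val_ds = make_ds(40)

# Input names match the dictionary keys produced by the datasets.
age = tf.keras.Input(shape=(1,), name="age")
score = tf.keras.Input(shape=(1,), name="score")
x = tf.keras.layers.Concatenate()([age, score])
x = tf.keras.layers.Dense(8, activation="relu")(x)
out = tf.keras.layers.Dense(2, activation="softmax")(x)

model = tf.keras.Model(inputs={"age": age, "score": score}, outputs=out)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

history = model.fit(train_ds, validation_data=val_ds, epochs=1, verbose=0)
```

Building the training and validation datasets with the same helper is one simple way to guarantee identical keys and shapes.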

Use Datasets to Avoid Memory Blowups

One reason to prefer tf.data is that it scales beyond in-memory NumPy arrays. File-based pipelines, record parsing, and augmentation can all live in the dataset graph.

python
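def preprocess(x, y):
    # Scale integer pixel values to floats in [0, 1].
    x = tf.cast(x, tf.float32)
    return x / 255.0, y

# raw_train_ds and raw_val_ds are assumed to be existing, unbatched datasets.
train_ds = raw_train_ds.map(preprocess).batch(64).prefetch(tf.data.AUTOTUNE)
val_ds = raw_val_ds.map(preprocess).batch(64).prefetch(tf.data.AUTOTUNE)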

The key point is consistency. Validation should run through the same normalization and shape logic as training, minus training-only randomness such as random data augmentation.
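The split between shared preprocessing and training-only randomness might look like the following sketch. The image shape, the random stand-in data, and the choice of a horizontal flip as the augmentation are all assumptions for illustration.

```python
import tensorflow as tf

def normalize(x, y):
    # Shared step: applied to both training and validation data.
    return tf.cast(x, tf.float32) / 255.0, y

def augment(x, y):
    # Training-only step: a random horizontal flip leaves the label valid.
    return tf.image.random_flip_left_right(x), y

# Stand-in data: 64 grayscale 28x28 images with integer pixel values.
images = tf.random.uniform([64, 28, 28, 1], maxval=256, dtype=tf.int32)
labels = tf.random.uniform([64], maxval=10, dtype=tf.int32)

raw_train_ds = tf.data.Dataset.from_tensor_slices((images[:48], labels[:48]))
raw_val_ds = tf.data.Dataset.from_tensor_slices((images[48:], labels[48:]))

train_ds = (
    raw_train_ds.map(normalize).map(augment).batch(16).prefetch(tf.data.AUTOTUNE)
)
val_ds = raw_val_ds.map(normalize).batch(16).prefetch(tf.data.AUTOTUNE)
```

Keeping normalize and augment as separate functions makes it harder to apply the random step to validation data by accident.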

Debugging Strategy

Before calling fit, inspect a single batch from both datasets.

python
for xb, yb in train_ds.take(1):
    print(xb.shape, yb.shape, xb.dtype, yb.dtype)

for xb, yb in val_ds.take(1):
    print(xb.shape, yb.shape, xb.dtype, yb.dtype)

That catches most mistakes quickly:

  • wrong label shape
  • missing batch dimension
  • wrong dtype
  • unexpected dictionary keys

If fit still behaves strangely, remove fancy pipeline steps and get a minimal batched dataset working first.
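A complementary check that does not iterate the data at all is to compare the element_spec of the two datasets. The sketch below plants a deliberate mistake in the validation labels (float dtype and an extra axis) to show what a mismatch looks like; the data is made up for the example.

```python
import tensorflow as tf

train_ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([100, 8]),
     tf.random.uniform([100], maxval=2, dtype=tf.int32))
).batch(32)

# Deliberate mistake for illustration: float labels with shape [20, 1].
val_ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([20, 8]), tf.zeros([20, 1]))
).batch(32)

# element_spec describes structure, shapes, and dtypes without reading any data.
print("train:", train_ds.element_spec)
print("val:  ", val_ds.element_spec)

specs_match = train_ds.element_spec == val_ds.element_spec
print("specs match:", specs_match)
```

Strict equality can be too tight in some setups, for example when only one pipeline uses drop_remainder and so has a fixed batch dimension, so treat a mismatch as a prompt to look closer rather than as proof of a bug.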

Common Pitfalls

  • Using .repeat() without setting steps_per_epoch or validation_steps.
  • Shuffling or augmenting validation data in ways that make metrics unstable.
  • Returning dataset elements that do not match the model input structure.
  • Applying different normalization or shape logic to validation data than to training data.
  • Debugging at full pipeline complexity instead of inspecting one batch directly.

Summary

  • validation_data can be another tf.data.Dataset with the same feature-label structure as training.
  • Training datasets often use shuffle, batch, and prefetch.
  • Validation datasets are usually batched and prefetched, but not shuffled.
  • If either dataset repeats indefinitely, explicit step counts are required.
  • Inspect one batch from each dataset before training to catch shape and dtype errors early.