
Could Keras prefetch data like a TensorFlow Dataset?

Introduction

Yes. Keras does not need a prefetching mechanism of its own, because model.fit() accepts a tf.data.Dataset directly as its input. You build the input pipeline with tf.data transformations such as .shuffle(), .batch(), and .prefetch(), then pass the resulting dataset to model.fit(). There is no separate "Keras prefetch": prefetching comes from TensorFlow's tf.data API, and Keras consumes the dataset natively.

Basic Prefetching with tf.data

```python
import tensorflow as tf

# Assumes x_train, y_train are in-memory NumPy arrays and model is a compiled Keras model
# Create a dataset from in-memory arrays
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

# Apply transformations, ending with prefetch
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

# Pass the dataset directly to model.fit()
model.fit(dataset, epochs=10)
```

.prefetch(tf.data.AUTOTUNE) overlaps data loading and preprocessing with model training: while the GPU trains on batch N, the CPU is already preparing batch N+1 (and possibly further batches) in a background buffer. With AUTOTUNE, TensorFlow tunes the size of that buffer at runtime instead of using a fixed value.
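
To make the overlap concrete, here is a conceptual sketch in plain Python (no TensorFlow) of what a prefetch buffer does: a background thread produces items into a bounded queue while the consumer takes them out. The `prefetch` generator and its `buffer_size` parameter are hypothetical names for illustration; tf.data implements prefetching in its C++ runtime, not like this. Because of Python's GIL, this sketch only overlaps work that releases the GIL (file I/O, sleeps, many NumPy calls).

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def prefetch(source: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Yield items from `source` while a background thread reads ahead.

    Conceptual sketch, not tf.data's implementation: up to `buffer_size`
    items are produced ahead of the consumer, so producing item N+1 can
    overlap with the consumer's work on item N. If the consumer stops
    early, the daemon thread is simply left blocked.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    buf: queue.Queue = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in source:
                buf.put(("item", item))  # blocks while the buffer is full
        except Exception as exc:  # hand errors to the consumer instead of hanging
            buf.put(("error", exc))
            return
        buf.put(("done", None))

    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buf.get()  # blocks only when the producer has fallen behind
        if kind == "item":
            yield payload
        elif kind == "error":
            raise payload
        else:
            return


# Usage: items still arrive in their original order.
print(list(prefetch(range(5), buffer_size=2)))  # [0, 1, 2, 3, 4]
```

The bounded queue is the key design choice: it lets the producer run ahead, but only by `buffer_size` items, which caps memory use in the same way a prefetch buffer does.
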

How Prefetch Works

 
```
Without prefetch:
  [Load Batch 1] [Train Batch 1] [Load Batch 2] [Train Batch 2] ...
     GPU idle        CPU idle        GPU idle       CPU idle

With prefetch:
  [Load Batch 1] [Load Batch 2]  [Load Batch 3]  ...   CPU
                 [Train Batch 1] [Train Batch 2] ...   GPU
  Loading overlaps training, so the idle time is hidden
```

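Prefetching hides most of the idle time between data loading and training. The speedup is largest when producing a batch is expensive (disk I/O, image decoding, augmentation) but no slower than training on it. If loading takes longer than a training step, the GPU still waits, and the remedy is to make loading itself faster, for example with parallel .map() calls or .cache().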
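
A rough way to estimate the benefit is an idealized model, assumed here purely for illustration: every batch takes `load` seconds to produce and `train` seconds to consume, with no other overheads. Without prefetching the two costs add up for every batch; with prefetching, after the first load the pipeline runs at the pace of the slower stage. The function names are hypothetical.

```python
def epoch_time_sequential(n_batches: int, load: float, train: float) -> float:
    """Without prefetch, loading and training strictly alternate."""
    return n_batches * (load + train)


def epoch_time_prefetched(n_batches: int, load: float, train: float) -> float:
    """Idealized two-stage pipeline: loading batch i+1 overlaps training batch i.

    The first load and the last training step overlap with nothing; every
    other step runs at the pace of the slower stage.
    """
    if n_batches <= 0:
        return 0.0
    return load + (n_batches - 1) * max(load, train) + train


# Balanced stages: prefetching nearly halves the epoch time.
print(epoch_time_sequential(10, 1, 1), epoch_time_prefetched(10, 1, 1))  # 20 11
# Loading-bound: prefetching helps little; loading itself must get faster.
print(epoch_time_sequential(10, 4, 1), epoch_time_prefetched(10, 4, 1))  # 50 41
```

The second case is why prefetching is usually paired with parallel preprocessing: prefetch hides the faster stage behind the slower one, but cannot make the slower one faster.
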

Complete Keras Pipeline with Prefetch

```python
import tensorflow as tf
from tensorflow import keras

# Example data: MNIST digits flattened to 784 features and scaled to [0, 1];
# the MNIST test split stands in for a validation set here
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_val = x_val.reshape(-1, 784).astype('float32') / 255.0

# Build model
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Build efficient input pipeline
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(10000).batch(64).prefetch(tf.data.AUTOTUNE)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(64).prefetch(tf.data.AUTOTUNE)

# Train: Keras consumes tf.data.Dataset natively
model.fit(train_ds, validation_data=val_ds, epochs=10)
```

Image Data Pipeline with Augmentation

```python
import tensorflow as tf

def load_and_preprocess(file_path, label):
    image = tf.io.read_file(file_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224])  # returns float32
    image = image / 255.0                        # scale to [0, 1]
    return image, label

def augment(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    image = tf.clip_by_value(image, 0.0, 1.0)    # keep pixels in [0, 1] after jitter
    return image, label

# Assumes file_paths (list of JPEG paths), labels (matching integer labels),
# and a compiled model are defined elsewhere
train_ds = tf.data.Dataset.from_tensor_slices((file_paths, labels))
train_ds = (train_ds
    .shuffle(len(file_paths))  # shuffling cheap path strings allows a full-size buffer
    .map(load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

model.fit(train_ds, epochs=20)
```

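Setting num_parallel_calls=tf.data.AUTOTUNE on .map() runs the preprocessing function on several elements at once across CPU cores, while still returning elements in their original order by default. Combined with .prefetch(), this keeps the input pipeline from becoming the training bottleneck.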
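
The effect of `num_parallel_calls` can be sketched in plain Python with a thread pool: several elements are processed concurrently, yet results come back in input order, matching tf.data's default deterministic ordering for a parallel `map`. The `preprocess` and `parallel_map` functions are hypothetical, and the sleep is an assumed stand-in for I/O-bound work such as reading and decoding files.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def preprocess(x: int) -> int:
    """Hypothetical stand-in for reading and decoding one file."""
    time.sleep(0.05)  # like real file I/O, sleeping releases the GIL
    return x * x


def parallel_map(fn, items, num_parallel_calls: int = 4):
    """Apply fn to up to num_parallel_calls items at once; keep input order."""
    with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
        return list(pool.map(fn, items))


start = time.perf_counter()
result = parallel_map(preprocess, range(8), num_parallel_calls=4)
elapsed = time.perf_counter() - start
print(result)  # [0, 1, 4, 9, 16, 25, 36, 49]
# 8 items on 4 workers take about two rounds of 0.05 s instead of eight
print(f"{elapsed:.2f} s")
```
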

Using keras.utils.image_dataset_from_directory

Keras provides a high-level utility that builds a batched tf.data.Dataset directly from a directory of images, with one subdirectory per class:

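```python
import tensorflow as tf
from tensorflow import keras

# Expects one subdirectory per class, e.g. data/train/cats/ and data/train/dogs/
train_ds = keras.utils.image_dataset_from_directory(
    'data/train',
    image_size=(224, 224),
    batch_size=32,
    label_mode='categorical'  # one-hot labels: compile with categorical_crossentropy
)

# The returned dataset is already batched (pixel values are in [0, 255]);
# adding prefetch explicitly is cheap
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

model.fit(train_ds, epochs=10)  # assumes a compiled model with 224x224x3 inputs
```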
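
`image_dataset_from_directory` infers labels from the folder layout: each subdirectory is one class, and by default classes are indexed in alphanumeric order of the folder names. The sketch below mimics only that labeling convention in plain Python and loads no images; the helper `infer_labels` is hypothetical, and its extension list is taken from the formats the Keras utility documents (.jpeg, .jpg, .png, .bmp, .gif).

```python
import os
import tempfile
from pathlib import Path

# Extensions accepted by keras.utils.image_dataset_from_directory
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}


def infer_labels(root: str):
    """Hypothetical helper mimicking directory-based labeling.

    Each subdirectory of `root` is a class; classes are indexed in sorted
    (alphanumeric) order, and labels are one-hot as with label_mode='categorical'.
    """
    class_names = sorted(d.name for d in Path(root).iterdir() if d.is_dir())
    samples = []
    for index, name in enumerate(class_names):
        one_hot = tuple(1.0 if i == index else 0.0 for i in range(len(class_names)))
        for path in sorted((Path(root) / name).rglob("*")):
            if path.is_file() and path.suffix.lower() in IMAGE_EXTS:
                samples.append((str(path), one_hot))
    return class_names, samples


# Usage: a tiny fake layout like data/train/cats/ and data/train/dogs/
with tempfile.TemporaryDirectory() as root:
    layout = {"dogs": ["a.jpg", "b.png", "notes.txt"], "cats": ["c.jpeg"]}
    for cls, files in layout.items():
        os.makedirs(os.path.join(root, cls))
        for f in files:
            Path(root, cls, f).touch()
    names, samples = infer_labels(root)

print(names)                            # ['cats', 'dogs']
print([label for _, label in samples])  # [(1.0, 0.0), (0.0, 1.0), (0.0, 1.0)]
```

Note that `notes.txt` is skipped, and that `cats` gets index 0 even though `dogs` was created first: the class order comes from the folder names, not from creation order.
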

Keras Sequence vs tf.data.Dataset

Before tf.data integration, Keras used keras.utils.Sequence for custom data loading:

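```python
import math
from tensorflow import keras

# Older approach: a Keras Sequence (still works, but usually slower than tf.data)
class DataGenerator(keras.utils.Sequence):
    def __init__(self, x, y, batch_size=32):
        super().__init__()  # expected by Keras 3 (where Sequence is PyDataset); harmless in Keras 2
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        # Batches per epoch; ceil keeps the final partial batch instead of dropping it
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = start + self.batch_size
        return self.x[start:end], self.y[start:end]

# Worker processes fill a queue of ready batches, which gives some prefetching.
# This fit() signature is from TF 2.x (Keras 2); in Keras 3, workers and
# use_multiprocessing are passed to the Sequence/PyDataset constructor instead.
model.fit(DataGenerator(x_train, y_train), workers=4, use_multiprocessing=True)
```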

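A tf.data.Dataset with .prefetch() is usually faster and more flexible than a Sequence: as long as its transformations are TensorFlow ops, the pipeline runs in TensorFlow's C++ runtime, where parallel work is not limited by the Python GIL. Pipelines that call back into Python (for example through Dataset.from_generator or tf.py_function) lose part of that advantage.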

Performance Optimization Tips

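```python
import tensorflow as tf

# Full optimized pipeline (x, y: in-memory arrays; augment: an augmentation
# function such as the one in the image pipeline above)
# Caveat: random_brightness/random_contrast draw one random value per call,
# so when mapped after .batch() every image in a batch gets the same adjustment
dataset = (tf.data.Dataset.from_tensor_slices((x, y))
    .cache()                                            # Filled in epoch 1, read from memory afterwards
    .shuffle(buffer_size=len(x))                        # Full shuffle, reshuffled each epoch
    .batch(64)                                          # Batch
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)  # Vectorized: one call per batch
    .prefetch(tf.data.AUTOTUNE)                         # Prefetch, always last
)
```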

Ordering matters:

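  1. .cache(): stores elements in memory after the first pass (place it after expensive deterministic steps such as decoding, but before shuffling and random augmentation)
  2. .shuffle(): randomizes element order (before batching, so batch composition changes every epoch)
  3. .batch(): groups elements into batches
  4. .map(): applies augmentation (after batching, each call processes a whole batch at once; random ops that draw a single value per call then apply it to the whole batch)
  5. .prefetch(): always last in the chain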
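
Why .cache() must come before random augmentation can be shown without TensorFlow. In the plain-Python sketch below, `augment`, `cache`, and the two pipelines are hypothetical stand-ins for the tf.data transformations: caching after augmentation freezes the first epoch's random values, while caching before it lets every epoch draw fresh ones.

```python
import random


def augment(x: float, rng: random.Random) -> float:
    """Hypothetical random augmentation: add a small random offset."""
    return x + rng.uniform(-0.2, 0.2)


def cache(fn):
    """Compute fn() once, then replay the stored list (like .cache())."""
    stored = None

    def cached():
        nonlocal stored
        if stored is None:
            stored = list(fn())
        return stored

    return cached


data = [1.0, 2.0, 3.0]
rng = random.Random(0)

# Wrong order: augment, then cache -> epoch 2 replays epoch 1's augmentations.
aug_then_cache = cache(lambda: [augment(x, rng) for x in data])

# Right order: cache the raw data, then augment on every epoch.
raw_cache = cache(lambda: data)


def cache_then_aug():
    return [augment(x, rng) for x in raw_cache()]


epochs_wrong = [aug_then_cache() for _ in range(2)]
epochs_right = [cache_then_aug() for _ in range(2)]
print(epochs_wrong[0] == epochs_wrong[1])  # True: identical "random" images
print(epochs_right[0] == epochs_right[1])  # False: fresh augmentation each epoch
```
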

Common Pitfalls

  • Forgetting .prefetch() at the end of the pipeline: Without prefetch, the GPU waits idle while the CPU loads the next batch. Always add .prefetch(tf.data.AUTOTUNE) as the last transformation in your pipeline.
  • Placing .cache() after augmentation: If you cache after random augmentation, the same augmented images are used every epoch. Place .cache() before augmentation so each epoch gets fresh random transforms.
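  • Using Sequence with use_multiprocessing=True on Windows: Windows has no fork(), so Python starts worker processes with spawn instead. The Sequence and everything it references must then be picklable, and the training script needs an if __name__ == "__main__": guard; otherwise the workers fail or hang. tf.data.Dataset with .prefetch() avoids worker processes altogether and works on all platforms.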
  • Setting num_parallel_calls to a fixed high number: Hardcoding num_parallel_calls=16 may over-allocate CPU resources. Use tf.data.AUTOTUNE to let TensorFlow choose the optimal parallelism based on available hardware.
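  • Shuffling with a small buffer size: .shuffle(buffer_size=100) on a dataset of 50,000 samples randomizes poorly, because each output element is drawn from only the 100 elements currently in the buffer. If the source is ordered (for example, sorted by class), early batches can contain only the first class. Use buffer_size=len(dataset) for a full shuffle when it fits in memory; otherwise use the largest buffer memory allows, and shuffle cheap items such as file paths before loading the data.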
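
The shuffle-buffer pitfall is easiest to see with a small simulation. TensorFlow documents Dataset.shuffle as filling a buffer with buffer_size elements, then randomly sampling from that buffer and replacing each selected element with a new one. The function below is a plain-Python sketch of that algorithm; its name and `seed` parameter are illustrative, not TensorFlow API.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> List[T]:
    """Bounded-buffer shuffle in the style described for Dataset.shuffle.

    Fill a buffer with `buffer_size` elements, then repeatedly emit a random
    buffered element and put the next input element in its slot.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    it = iter(items)
    buffer: List[T] = []
    for item in it:  # initial fill
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    out: List[T] = []
    for item in it:  # steady state: emit one element, refill its slot
        i = rng.randrange(len(buffer))
        out.append(buffer[i])
        buffer[i] = item
    rng.shuffle(buffer)  # input exhausted: drain what is left in random order
    out.extend(buffer)
    return out


# A dataset sorted by class, as happens when files are listed folder by folder.
labels = [0] * 500 + [1] * 500
small = buffered_shuffle(labels, buffer_size=100)
full = buffered_shuffle(labels, buffer_size=len(labels))
# With buffer_size=100, the first 32 outputs come from the first 131 inputs,
# so the first batch cannot contain class 1.
print(set(small[:32]))  # {0}
# With a full buffer, the first batch is a uniform sample of the whole dataset.
print(sorted(set(full[:32])))
```

In general, output position j can only hold one of the first j + buffer_size input elements, which is exactly why a buffer that is small relative to an ordered dataset leaves early batches unmixed.
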

Summary

  • Keras uses tf.data.Dataset directly — call .prefetch(tf.data.AUTOTUNE) to overlap data loading with training
  • .prefetch() should always be the last transformation in your data pipeline
  • Use num_parallel_calls=tf.data.AUTOTUNE on .map() for parallel preprocessing
  • keras.utils.image_dataset_from_directory creates a batched dataset from image folders
  • .cache() stores data in memory to avoid re-reading from disk each epoch
  • tf.data.Dataset with prefetch is typically faster than keras.utils.Sequence because its TensorFlow-op transformations run in the C++ runtime rather than in Python
