TensorFlow
imgaug
data augmentation
tf.data.Dataset
TensorFlow 2.0

How to Apply imgaug Augmentation to tf.data.Dataset in TensorFlow 2.0

Introduction

Using imgaug with tf.data.Dataset in TensorFlow 2.0 is a practical way to combine advanced augmentations with scalable input pipelines. The main challenge is that imgaug operates on NumPy arrays, while tf.data traces map functions into a TensorFlow graph whose values are symbolic tensors. This guide shows a robust bridge pattern and highlights performance and correctness checks.

Bridge TensorFlow Tensors to imgaug

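tf.numpy_function wraps a Python function so it can run inside a dataset map stage: TensorFlow passes the input tensors to it as NumPy arrays and converts the returned array back into a tensor. Because that tensor has no static shape, set the shape after the call, then normalize values so the model's input signature stays stable.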

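```python
import imgaug.augmenters as iaa
import numpy as np
import tensorflow as tf

IMAGE_SIZE = (224, 224)

# Build the imgaug pipeline once at module level, not once per example.
seq = iaa.Sequential(
    [
        iaa.Fliplr(0.5),
        iaa.Affine(rotate=(-10, 10), scale=(0.95, 1.05)),
        iaa.Multiply((0.8, 1.2)),
    ]
)


def _aug_np(image_np):
    # Most imgaug augmenters expect uint8 in [0, 255]; clip before casting
    # so out-of-range floats cannot wrap around.
    image_np = np.clip(image_np, 0, 255).astype(np.uint8)
    out = seq(image=image_np)
    return out.astype(np.float32)


def augment_tf(image, label):
    # Run the imgaug code as a graph op; its output has no static shape.
    image = tf.numpy_function(_aug_np, [image], tf.float32)
    # Restore the static shape that numpy_function discards.
    image.set_shape([IMAGE_SIZE[0], IMAGE_SIZE[1], 3])
    # Scale pixel values from [0, 255] to [0, 1] for the model.
    image = tf.clip_by_value(image / 255.0, 0.0, 1.0)
    return image, label
```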

If shape metadata is missing, later layers can fail during tracing or batching.
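
The effect is easy to observe on a dataset's element_spec. The sketch below uses a hypothetical `_identity_np` function as a stand-in for the imgaug call: any Python function wrapped in tf.numpy_function yields a tensor of unknown rank until set_shape restores it.

```python
import tensorflow as tf


def _identity_np(image_np):
    # Hypothetical stand-in for an imgaug call; returns its input unchanged.
    return image_np


base = tf.data.Dataset.from_tensors(tf.zeros([224, 224, 3], tf.float32))

# Without set_shape, the static shape is lost entirely.
no_shape = base.map(lambda x: tf.numpy_function(_identity_np, [x], tf.float32))
print(no_shape.element_spec.shape)  # unknown rank


def _with_shape(x):
    y = tf.numpy_function(_identity_np, [x], tf.float32)
    y.set_shape([224, 224, 3])  # restore the known static shape
    return y


with_shape = base.map(_with_shape)
print(with_shape.element_spec.shape)  # (224, 224, 3)
```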

Assemble the Dataset Pipeline

Keep decoding and resizing with TensorFlow ops for performance, then apply imgaug in a map stage.

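```python
import tensorflow as tf

# IMAGE_SIZE and augment_tf are defined in the previous snippet.
try:
    AUTOTUNE = tf.data.AUTOTUNE
except AttributeError:  # older TensorFlow 2.x releases
    AUTOTUNE = tf.data.experimental.AUTOTUNE


def decode_example(path, label):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, IMAGE_SIZE)  # float32 values in [0, 255]
    image = tf.cast(image, tf.float32)
    return image, label


path_list = ["data/img1.jpg", "data/img2.jpg", "data/img3.jpg"]
paths = tf.constant(path_list)
labels = tf.constant([0, 1, 0], dtype=tf.int32)

ds = tf.data.Dataset.from_tensor_slices((paths, labels))
# Shuffle the cheap (path, label) pairs before decoding, rather than
# holding a buffer of decoded and augmented images in memory.
ds = ds.shuffle(buffer_size=len(path_list), reshuffle_each_iteration=True)
ds = ds.map(decode_example, num_parallel_calls=AUTOTUNE)
ds = ds.map(augment_tf, num_parallel_calls=AUTOTUNE)
ds = ds.batch(16)
ds = ds.prefetch(AUTOTUNE)

for images, y in ds.take(1):
    print(images.shape, y.numpy())
```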

This pattern keeps the pipeline readable and suitable for most image classification projects.

Keep Augmentation Deterministic During Debugging

For reproducible experiments, seed every random source.

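```python
import random

import imgaug as ia
import numpy as np
import tensorflow as tf

SEED = 123

# Seed at program start, before constructing `seq`, so every augmenter
# starts from the seeded global RNG.
ia.seed(SEED)             # imgaug's global RNG
np.random.seed(SEED)      # legacy NumPy global RNG
random.seed(SEED)         # Python's random module
tf.random.set_seed(SEED)  # TF ops such as Dataset.shuffle and tf.image.random_*

# With num_parallel_calls > 1, concurrent numpy_function calls can draw from
# the RNG in a different order each run. While debugging, map augment_tf with
# num_parallel_calls=None to keep augmentation repeatable.
```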

Determinism is especially useful when validating whether a metric change comes from code edits or random augmentation variance.

Synchronize Transforms for Masks and Bounding Boxes

If your task uses segmentation masks or bounding boxes, augmenting only the image is incorrect: geometric transforms such as flips, rotations, and scaling must be applied identically to every paired annotation.

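For object detection, transform boxes with the same augmentation parameters used for the image. For segmentation, warp masks with nearest-neighbor interpolation so they keep valid integer class ids; bilinear interpolation blends neighboring ids into values that belong to no class.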

Design this as one function that consumes the image plus its annotations and returns the full synchronized result. Splitting image and annotation transforms across separate map stages lets each stage sample its own random parameters, which silently corrupts the training labels.

Throughput Tuning Checklist

Once correctness is stable, profile pipeline throughput before scaling model complexity.

  • Measure steps per second with and without augmentation to quantify its overhead.
  • Increase num_parallel_calls, use larger batch sizes where memory allows, and keep prefetch enabled. Expect limited gains from parallelism in the imgaug stage, because Python code inside tf.numpy_function runs under the Python GIL.
  • If the CPU becomes the bottleneck, move simple color and flip operations to tf.image and keep only specialized transforms in imgaug.
  • Cache decoded data when the source files are small enough, placing cache() before the augmentation map. Caching augmented outputs replays the same augmented images every epoch, which defeats the purpose of augmentation.

Because these adjustments change how fast examples arrive rather than which augmentations are applied, they can cut training time without affecting model accuracy.
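
A rough throughput measurement needs only a timer around dataset iteration. In this sketch, `examples_per_second` is a hypothetical helper, and synthetic tensors with a tf.image flip stand in for the real pipeline; in practice, pass the dataset built with and without the augment_tf map. Note that cache() sits before the augmentation map, so each epoch still draws fresh random flips.

```python
import time

import tensorflow as tf


def examples_per_second(ds, num_batches=20):
    """Rough throughput of a batched (images, labels) dataset."""
    it = iter(ds.repeat())
    next(it)  # warm-up: the first batch pays for tracing and buffer filling
    count = 0
    start = time.perf_counter()
    for _ in range(num_batches):
        images, _ = next(it)
        count += int(images.shape[0])
    return count / (time.perf_counter() - start)


AUTOTUNE = tf.data.experimental.AUTOTUNE

# Synthetic stand-in data: 64 images already decoded to float32 in [0, 255].
images = tf.random.uniform([64, 224, 224, 3], 0.0, 255.0)
labels = tf.zeros([64], dtype=tf.int32)
base = tf.data.Dataset.from_tensor_slices((images, labels)).cache()


def flip(image, label):
    return tf.image.random_flip_left_right(image), label


plain = base.batch(16).prefetch(AUTOTUNE)
augmented = base.map(flip, num_parallel_calls=AUTOTUNE).batch(16).prefetch(AUTOTUNE)
print(f"plain:     {examples_per_second(plain):.0f} examples/s")
print(f"augmented: {examples_per_second(augmented):.0f} examples/s")
```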

Common Pitfalls

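A frequent mistake is feeding float tensors to imgaug without checking their value range. Many augmenters assume uint8 values in the range 0 to 255, and casting an image normalized to [0, 1] straight to uint8 truncates almost every pixel to 0 or 1.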
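
An explicit conversion avoids the problem. The helper below, `to_uint8`, is a hypothetical sketch: the caller states the input range through `max_value` rather than the code guessing it, and values are rounded and clipped before the cast.

```python
import numpy as np


def to_uint8(image_np, max_value=255.0):
    """Map values in [0, max_value] onto uint8 [0, 255].

    max_value is supplied by the caller: 255.0 for decoded or resized pixels,
    1.0 if the pipeline has already normalized the image.
    """
    scaled = image_np.astype(np.float32) * (255.0 / max_value)
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


normalized = np.array([0.0, 0.5, 1.0], dtype=np.float32)
print(normalized.astype(np.uint8))          # naive cast: 0, 0, 1
print(to_uint8(normalized, max_value=1.0))  # 0, 128, 255
```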

Another issue is a throughput drop caused by heavy Python code in tf.numpy_function. Use native tf.image ops for simple transforms and keep imgaug for features TensorFlow lacks.
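
As one possible split, the flip and brightness scaling from the earlier `seq` can move to native TensorFlow ops, leaving only `iaa.Affine` inside the imgaug Sequential. This is a sketch under the assumption that images are float32 in [0, 255], as produced by decode_example; `tf_flip_and_multiply` is a hypothetical name. It applies the brightness factor before the affine step instead of after, which is roughly equivalent because the affine step pads with black either way.

```python
import tensorflow as tf


def tf_flip_and_multiply(image, label):
    """Native-TF stand-in for iaa.Fliplr(0.5) plus iaa.Multiply((0.8, 1.2)).

    Assumes `image` is float32 in [0, 255].
    """
    image = tf.image.random_flip_left_right(image)
    # One brightness factor per image, like iaa.Multiply with per_channel=False.
    factor = tf.random.uniform([], minval=0.8, maxval=1.2)
    image = tf.clip_by_value(image * factor, 0.0, 255.0)
    return image, label
```

In the pipeline, map `tf_flip_and_multiply` before the numpy_function stage so the Python code only handles the affine transform.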

People also forget to restore the output shape after tf.numpy_function, which causes shape inference failures in later layers.

Finally, even with fixed seeds, results can still vary slightly across hardware backends. Use seeds as a debugging aid, not as an absolute guarantee of identical training curves.

Summary

  • Use tf.numpy_function to bridge imgaug into tf.data pipelines.
  • Restore tensor shape and normalize ranges immediately after augmentation.
  • Keep decode and resize in TensorFlow ops for better throughput.
  • Seed all random sources to reduce experiment variance.
  • Apply synchronized transforms to images and their annotations (boxes or masks) for detection or segmentation tasks.
