How to Load Image Mask Labels for Image Segmentation in Keras

Introduction

For image segmentation, the input image is only half of each training example. The target is the mask, in which each pixel stores the class label of the corresponding input pixel. Loading masks correctly matters even more than loading labels in image classification, because a single resizing, interpolation, or channel-handling mistake can silently corrupt every label in the training set.

Organize Images and Masks as Paired Files

The safest layout is one image file and one mask file with matching names. For example:

```text
dataset/
  images/
    0001.png
    0002.png
  masks/
    0001.png
    0002.png
```

With that structure, you can derive the mask path directly from the image path and keep the pairing deterministic.
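The pairing rule can be made explicit with a small helper. This is a sketch that assumes the images/ and masks/ layout above with identical filenames; `paired_paths` is a hypothetical name, not part of Keras or TensorFlow. It fails loudly on unpaired files instead of letting one missing mask shift every later pair.

```python
from __future__ import annotations

from pathlib import Path


def paired_paths(root: str, suffix: str = ".png") -> tuple[list[str], list[str]]:
    """Return sorted, matched image and mask path lists under root/images and root/masks.

    Assumes every mask has exactly the same filename as its image.
    Raises ValueError if any image lacks a mask or any mask lacks an image.
    """
    image_dir = Path(root) / "images"
    mask_dir = Path(root) / "masks"

    image_names = {p.name for p in image_dir.glob(f"*{suffix}")}
    mask_names = {p.name for p in mask_dir.glob(f"*{suffix}")}

    # Compare filenames as sets so a gap anywhere is reported, not silently skipped.
    missing_masks = sorted(image_names - mask_names)
    orphan_masks = sorted(mask_names - image_names)
    if missing_masks or orphan_masks:
        raise ValueError(
            f"unpaired files: images without masks={missing_masks}, "
            f"masks without images={orphan_masks}"
        )

    # One sorted name list drives both outputs, so index i always refers to the same pair.
    names = sorted(image_names)
    image_paths = [str(image_dir / name) for name in names]
    mask_paths = [str(mask_dir / name) for name in names]
    return image_paths, mask_paths


# Example: image_paths, mask_paths = paired_paths("dataset")
```

The two returned lists can be passed directly to tf.data.Dataset.from_tensor_slices.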

Build a tf.data Pipeline

In Keras and TensorFlow, a tf.data pipeline is the standard way to load image and mask pairs. The important part is to decode masks as integer labels and resize them with nearest-neighbor interpolation.

```python
from pathlib import Path

import tensorflow as tf

IMAGE_SIZE = (256, 256)
NUM_CLASSES = 3

image_paths = sorted(str(p) for p in Path("dataset/images").glob("*.png"))
# Build each mask path from the image's filename; unlike replacing "/images/"
# inside the path string, this also works with Windows path separators.
mask_paths = [str(Path("dataset/masks") / Path(p).name) for p in image_paths]


def load_example(image_path: tf.Tensor, mask_path: tf.Tensor):
    image_bytes = tf.io.read_file(image_path)
    mask_bytes = tf.io.read_file(mask_path)

    # Images as 3-channel RGB; masks as a single channel of class IDs.
    image = tf.image.decode_png(image_bytes, channels=3)
    mask = tf.image.decode_png(mask_bytes, channels=1)

    # The default (bilinear) is fine for images; masks need nearest-neighbor
    # so that no new, in-between label values are created.
    image = tf.image.resize(image, IMAGE_SIZE)
    mask = tf.image.resize(mask, IMAGE_SIZE, method="nearest")

    image = tf.cast(image, tf.float32) / 255.0  # scale pixels to [0, 1]
    mask = tf.cast(mask, tf.int32)              # shape (256, 256, 1), integer labels

    return image, mask


dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
dataset = dataset.map(load_example, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(8).prefetch(tf.data.AUTOTUNE)
```

Using method="nearest" for the mask is essential, because tf.image.resize defaults to bilinear interpolation. Segmentation labels are discrete categories, not continuous image intensities.
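To see why, here is a pure-Python illustration on a single row of labels that upsamples a boundary between class 0 and class 2. The helper names are hypothetical, and the linear version aligns the endpoints, which is simpler than the sampling grid tf.image.resize actually uses; the point is only that blending labels invents values.

```python
def upsample_row_nearest(row, factor):
    """Nearest-neighbor upsampling by an integer factor: repeat each label."""
    return [value for value in row for _ in range(factor)]


def upsample_row_linear(row, out_len):
    """Linear interpolation of a 1-D row to out_len samples (endpoints aligned)."""
    scale = (len(row) - 1) / (out_len - 1)
    out = []
    for i in range(out_len):
        x = i * scale
        left = int(x)
        right = min(left + 1, len(row) - 1)
        frac = x - left
        out.append((1 - frac) * row[left] + frac * row[right])
    return out


labels = [0, 2]  # one pixel of class 0 next to one pixel of class 2
print(upsample_row_nearest(labels, 2))  # [0, 0, 2, 2]
print(upsample_row_linear(labels, 3))   # [0.0, 1.0, 2.0]
```

Nearest-neighbor only copies existing labels, while linear interpolation produces 1.0, which reads as class 1 even though no class 1 pixel exists in the source mask.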

Choose Between Integer Masks and One-Hot Masks

A segmentation model that outputs one logit per class at every pixel can train directly on integer masks when paired with a sparse loss such as SparseCategoricalCrossentropy. That is simpler and more memory-efficient than expanding each mask to one-hot format.

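```python
# Assumes `model` outputs raw logits of shape (height, width, NUM_CLASSES),
# i.e. no softmax on its last layer; use from_logits=False if it has one.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```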
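To put "more memory-efficient" in numbers, the arithmetic below uses this article's settings: 256×256 masks, 3 classes, batch size 8, int32 integer masks, and float32 one-hot masks (tf.one_hot's default output dtype). It is a back-of-the-envelope sketch for the label tensors only; model activations are not included.

```python
HEIGHT, WIDTH, NUM_CLASSES = 256, 256, 3
BATCH_SIZE = 8
INT32_BYTES = FLOAT32_BYTES = 4

# Integer mask: one int32 class ID per pixel, shape (H, W, 1).
sparse_bytes = BATCH_SIZE * HEIGHT * WIDTH * 1 * INT32_BYTES
# One-hot mask: one float32 value per class per pixel, shape (H, W, NUM_CLASSES).
one_hot_bytes = BATCH_SIZE * HEIGHT * WIDTH * NUM_CLASSES * FLOAT32_BYTES

print(f"{sparse_bytes / 2**20:.1f} MiB vs {one_hot_bytes / 2**20:.1f} MiB")  # 2.0 MiB vs 6.0 MiB
```

The one-hot cost grows linearly with the number of classes, so the gap widens quickly for datasets with many classes.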

If your model or loss expects one-hot masks (for example, CategoricalCrossentropy), convert them inside the dataset pipeline:

```python
def one_hot_encode(image, mask):
    # int32 mask (..., H, W, 1) -> (..., H, W) -> float32 (..., H, W, NUM_CLASSES)
    mask = tf.squeeze(mask, axis=-1)
    mask = tf.one_hot(mask, depth=NUM_CLASSES)
    return image, mask


one_hot_dataset = dataset.map(one_hot_encode, num_parallel_calls=tf.data.AUTOTUNE)
```

The key is to know which label representation your loss function expects before you start training.
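Both representations lead to the same per-pixel loss, so the choice is about format, not math. The pure-Python sketch below (helper names are hypothetical) computes softmax cross-entropy for one pixel from an integer label and from the equivalent one-hot vector. This is the per-pixel quantity that SparseCategoricalCrossentropy(from_logits=True) and CategoricalCrossentropy(from_logits=True) compute before Keras averages over pixels and the batch.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one pixel's class logits."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return [z - log_sum for z in logits]


def sparse_ce(logits, class_id):
    """Cross-entropy with an integer label: take the log-probability at class_id."""
    return -log_softmax(logits)[class_id]


def categorical_ce(logits, one_hot):
    """Cross-entropy with a one-hot label: weight every log-probability by the target."""
    return -sum(t * lp for t, lp in zip(one_hot, log_softmax(logits)))


pixel_logits = [2.0, -1.0, 0.5]  # raw model outputs for one pixel, 3 classes
label = 2
one_hot = [1.0 if c == label else 0.0 for c in range(len(pixel_logits))]

print(sparse_ce(pixel_logits, label), categorical_ce(pixel_logits, one_hot))
```

The two printed values match; a mismatch in training usually means the mask format does not match what the chosen loss expects.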

Keep Augmentation Synchronized

If you augment images, apply the same geometric transform, with the same random parameters, to the mask. Rotating an image without rotating its mask by the same angle creates invalid training pairs.

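```python
def random_flip(image, mask):
    # One random decision per example, shared by the image and its mask.
    should_flip = tf.random.uniform(()) > 0.5
    image = tf.cond(should_flip, lambda: tf.image.flip_left_right(image), lambda: image)
    mask = tf.cond(should_flip, lambda: tf.image.flip_left_right(mask), lambda: mask)
    return image, mask


# Augment before .batch() so that each example gets its own flip decision.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    .map(load_example, num_parallel_calls=tf.data.AUTOTUNE)
    .map(random_flip, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(8)
    .prefetch(tf.data.AUTOTUNE)
)
```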

Color augmentations usually apply only to the image, while spatial transforms must stay synchronized between image and mask.
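The rule can be sketched without TensorFlow, using nested lists as stand-ins for an image and a mask. This is an illustration under assumptions: `augment_pair` is a hypothetical helper, the image is a 2-D grid of floats in [0, 1], and the mask is a 2-D grid of integer class IDs. In a tf.data pipeline the same split would use tf.image.flip_left_right on both tensors and a color op such as tf.image.random_brightness on the image alone.

```python
import random


def flip_left_right(rows):
    """Mirror a 2-D grid (a list of rows) horizontally."""
    return [list(reversed(row)) for row in rows]


def augment_pair(image, mask, rng):
    """Apply one shared spatial decision to both grids and color jitter to the image only."""
    if rng.random() > 0.5:  # spatial: the same decision for image and mask
        image = flip_left_right(image)
        mask = flip_left_right(mask)
    delta = rng.uniform(-0.1, 0.1)  # color: image only, labels stay untouched
    image = [[min(1.0, max(0.0, px + delta)) for px in row] for row in image]
    return image, mask


image, mask = augment_pair([[0.2, 0.8]], [[0, 1]], random.Random(0))
```

Passing an explicit random.Random instance keeps the sketch reproducible; the mask values are never modified, only moved.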

Common Pitfalls

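The most damaging mistake is resizing masks with bilinear interpolation, which is the default in tf.image.resize. It creates pixel values that never existed in the label map: at a boundary between class 0 and class 2, blended pixels can come out as 1, a class that was never there, and fractional values are silently truncated when cast back to integers.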

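Another issue is decoding masks with the wrong number of channels. If the mask stores class IDs directly as grayscale pixel values, decode it with a single channel. If the dataset uses color-coded RGB masks, decode all three channels and add a step that maps each color to a class ID; decoding such masks with channels=1 converts the colors to grayscale intensities, which are not class IDs. Palette-mode PNGs, such as the Pascal VOC masks, need the same care, because tf.image.decode_png expands the palette to colors instead of returning the palette indices.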
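A minimal sketch of the color-to-class conversion, assuming an RGB mask and a hypothetical three-color palette; COLOR_TO_CLASS and rgb_mask_to_class_ids are illustrative names, and the colors must be replaced with the ones your dataset documents. It works on nested lists of (R, G, B) tuples to show the mapping logic only; inside tf.data you would express the same lookup with tensor operations.

```python
# Hypothetical palette: replace with the colors your dataset actually uses.
COLOR_TO_CLASS = {
    (0, 0, 0): 0,    # class 0 (e.g. background)
    (255, 0, 0): 1,  # class 1
    (0, 255, 0): 2,  # class 2
}


def rgb_mask_to_class_ids(rgb_rows, color_to_class=COLOR_TO_CLASS):
    """Map a 2-D grid of (R, G, B) tuples to a 2-D grid of integer class IDs.

    Raises KeyError for any color that is not in the palette.
    """
    class_rows = []
    for y, row in enumerate(rgb_rows):
        class_row = []
        for x, color in enumerate(row):
            key = tuple(color)
            if key not in color_to_class:
                raise KeyError(f"unknown mask color {key} at (row={y}, col={x})")
            class_row.append(color_to_class[key])
        class_rows.append(class_row)
    return class_rows
```

Failing on unknown colors is deliberate: a stray color often means a lossy format such as JPEG, or an interpolating resize, has touched the mask.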

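Shape mismatches are also common. Sparse losses expect integer masks of shape (height, width) or (height, width, 1), while one-hot training expects masks of shape (height, width, num_classes); in both cases the model should output (height, width, num_classes). Check the final tensor shapes, and the set of label values in a few masks, before starting a long training run.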
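A quick check on one decoded mask catches both shape errors and corrupted labels. The sketch below assumes the mask has already been converted to a 2-D nested list of class IDs (for an unbatched eager tensor, something like tf.squeeze(mask, axis=-1).numpy().tolist()); check_mask is a hypothetical helper. Non-integer or out-of-range values are exactly what bilinear resizing or grayscale decoding of a color-coded mask tends to produce.

```python
def check_mask(mask, num_classes, expected_hw):
    """Sanity-check one mask given as a 2-D grid of class IDs.

    Returns the sorted class IDs present. Raises ValueError on a wrong spatial
    size, ragged rows, or any label that is not an int in [0, num_classes).
    """
    height = len(mask)
    width = len(mask[0]) if height else 0
    if (height, width) != tuple(expected_hw):
        raise ValueError(f"mask is {height}x{width}, expected {expected_hw[0]}x{expected_hw[1]}")
    if any(len(row) != width for row in mask):
        raise ValueError("mask rows have inconsistent widths")

    labels = sorted({value for row in mask for value in row})
    bad = [v for v in labels if not (isinstance(v, int) and 0 <= v < num_classes)]
    if bad:
        raise ValueError(f"labels outside [0, {num_classes}) or non-integer: {bad}")
    return labels
```

Seeding the check with NUM_CLASSES and IMAGE_SIZE from the loading code keeps it in sync with the pipeline.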

Finally, verify that image and mask filenames are paired correctly. A perfectly valid pipeline can still train on garbage if 0007.png is accidentally matched with the mask for 0008.png.

Summary

  • Store each image and mask as a deterministic pair.
  • Load segmentation data with tf.data so reading, resizing, batching, and prefetching stay explicit.
  • Resize masks with nearest-neighbor interpolation, not bilinear interpolation.
  • Match the mask representation to the loss function you plan to use.
  • Apply geometric augmentations to images and masks together so labels remain aligned.
