
Can I fine-tune DeepLab on a custom dataset in TensorFlow?

Introduction

Yes, you can fine-tune DeepLab on a custom dataset in TensorFlow, and that is usually the right approach when you already have a segmentation problem with limited labeled data. The important part is not just loading a new image folder; you need a trainable DeepLab implementation, masks encoded with stable class IDs, and an output layer sized for your classes.

What fine-tuning actually means

DeepLab is a semantic segmentation architecture, not a single frozen model file. In practice, fine-tuning means starting from a DeepLab variant that already has pretrained backbone weights, then retraining the segmentation head and optionally part of the backbone on your own masks.

Your custom dataset must provide one mask per image, where each pixel contains a class index such as 0 for background, 1 for road, 2 for car, and so on. A normal color image mask is not enough unless you convert its colors into integer class IDs consistently.
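If your annotation tool exports color-coded masks, you can convert them once, offline, before building the tf.data pipeline. The NumPy sketch below uses a hypothetical three-color palette (`PALETTE`) and a hypothetical helper name (`rgb_mask_to_ids`). Substitute the exact colors your tool uses. The function raises an error on any color it does not recognize, so unknown colors are never mapped to background without your noticing.

```python
import numpy as np

# Hypothetical palette: replace with the exact colors your annotation tool uses.
PALETTE = {
    (0, 0, 0): 0,       # background
    (128, 64, 128): 1,  # road
    (0, 0, 142): 2,     # car
}

UNMAPPED = 255  # temporary marker; the palette itself must not use ID 255


def rgb_mask_to_ids(rgb_mask, palette=PALETTE):
    """Convert an (H, W, 3) uint8 color mask into an (H, W) uint8 class-ID mask."""
    ids = np.full(rgb_mask.shape[:2], UNMAPPED, dtype=np.uint8)
    for color, class_id in palette.items():
        # True where all three channels match this palette color.
        matches = np.all(rgb_mask == np.array(color, dtype=rgb_mask.dtype), axis=-1)
        ids[matches] = class_id
    if (ids == UNMAPPED).any():
        unknown = np.unique(rgb_mask[ids == UNMAPPED].reshape(-1, 3), axis=0)
        raise ValueError(f"unmapped colors in mask: {unknown.tolist()}")
    return ids
```

Save the result as a single-channel 8-bit image. That keeps it compatible with a pipeline that decodes masks with one channel.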

For example, if you need three classes, your training target should contain only the values 0, 1, and 2, and the final segmentation layer must output three logits per pixel. If a mask contains an ID the head cannot represent, training may fail with an error or, depending on the device, produce NaN or meaningless losses instead of a useful signal.
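Checks inside a tf.data pipeline stop at the first bad mask. Scanning the whole dataset once up front lists every offending file at the same time. Here is a minimal NumPy sketch. It assumes the masks are already loaded as integer arrays in a dict keyed by file name, and `find_bad_masks` is a hypothetical helper:

```python
import numpy as np


def find_bad_masks(masks, num_classes):
    """Return {name: sorted out-of-range IDs} for masks with values outside [0, num_classes)."""
    bad = {}
    for name, mask in masks.items():
        values = np.unique(mask)  # sorted distinct IDs present in this mask
        out_of_range = values[(values < 0) | (values >= num_classes)]
        if out_of_range.size:
            bad[name] = out_of_range.tolist()
    return bad
```

A value such as 255 often shows up in this check because some public datasets, Pascal VOC among them, use it to mark void or boundary pixels. Such pixels need to be remapped or explicitly excluded from the loss rather than left in the targets.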

Build a reliable input pipeline first

Before touching the model, get the data pipeline correct. That is where many segmentation projects fail.

```python
import tensorflow as tf

IMAGE_SIZE = (256, 256)
NUM_CLASSES = 3

def load_pair(image_path, mask_path):
    # Decode the RGB image, resize it, and scale pixel values to [0, 1].
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, IMAGE_SIZE)  # bilinear; returns float32
    image = tf.cast(image, tf.float32) / 255.0

    # Decode the single-channel mask: every pixel holds an integer class ID.
    # Nearest-neighbor resizing never blends IDs and keeps the uint8 dtype.
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)
    mask = tf.image.resize(mask, IMAGE_SIZE, method="nearest")
    mask = tf.cast(mask, tf.int32)

    # Fail fast if any pixel holds an ID outside [0, NUM_CLASSES).
    tf.debugging.assert_less(tf.reduce_max(mask), NUM_CLASSES)
    tf.debugging.assert_greater_equal(tf.reduce_min(mask), 0)
    return image, mask

# Placeholder paths: image i must correspond to mask i.
image_paths = tf.constant([
    "data/images/0001.jpg",
    "data/images/0002.jpg",
])
mask_paths = tf.constant([
    "data/masks/0001.png",
    "data/masks/0002.png",
])

train_ds = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
# Shuffle the (image, mask) path pairs together, before the expensive decode step.
train_ds = train_ds.shuffle(buffer_size=image_paths.shape[0])
train_ds = train_ds.map(load_pair, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.batch(4).prefetch(tf.data.AUTOTUNE)
```

Two details matter here. Masks must be resized with nearest interpolation so class IDs are not blended into invalid fractional values. Also, it is worth validating class IDs inside the pipeline early, because a single bad mask can waste hours of debugging.
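To see why blending is harmful, consider one row of a mask where class 0 sits next to class 2. The sketch below is a 1-D simplification using `np.interp`. Bilinear resizing interpolates linearly along each axis, although TensorFlow's exact sample positions differ. Blending produces fractional IDs, and rounding them back can invent a class that was never annotated at that boundary.

```python
import numpy as np

# One row of a mask: class 0 directly next to class 2.
row = np.array([0.0, 2.0])
positions = np.linspace(0.0, 1.0, 4)  # upsample 2 pixels to 4 samples

# Linear interpolation blends neighboring IDs, as bilinear resizing does per axis.
blended = np.interp(positions, [0.0, 1.0], row)
blended_ids = np.rint(blended).astype(int)

# Nearest-neighbor copies an existing ID into every output sample.
nearest_ids = row[np.rint(positions).astype(int)].astype(int)

print(blended)      # fractional values between 0 and 2
print(blended_ids)  # rounding invents class 1 at the 0/2 boundary
print(nearest_ids)  # only 0 and 2 appear
```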

Start from a pretrained DeepLab configuration

The easiest TensorFlow path is usually TensorFlow Model Garden, which provides pretrained segmentation experiment configs. A common approach is to start from a MobileNetV2 DeepLabV3 preset and change the dataset paths and class count for your own task.

```python
import tensorflow_models as tfm

# Start from a registered Model Garden experiment preset.
exp_config = tfm.core.exp_factory.get_exp_config("mnv2_deeplabv3_pascal")

# Size the segmentation head for your label set.
exp_config.task.model.num_classes = 3

exp_config.task.train_data.global_batch_size = 8
exp_config.task.validation_data.global_batch_size = 8

# Model Garden reads TFRecord shards; these bucket paths are placeholders.
exp_config.task.train_data.input_path = "gs://my-bucket/segmentation/train*"
exp_config.task.validation_data.input_path = "gs://my-bucket/segmentation/val*"

# In Model Garden the optimizer and learning-rate schedule are set through
# exp_config.trainer.optimizer_config, not through a standalone Keras optimizer.
```

The exact training driver depends on the implementation you use, but the important idea is stable across all DeepLab variants: keep the pretrained backbone, reduce the learning rate, and change the segmentation head to match your dataset.

If your dataset is small, start by training only the head or using a very low learning rate on the backbone. If the domain shift is large, such as medical images versus street scenes, you may eventually unfreeze more of the encoder after the head begins to converge.
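In Keras, a common way to train only the head first is to set `trainable = False` on the backbone and recompile. Afterwards you unfreeze it and recompile with a smaller learning rate. The sketch below uses a hypothetical toy model (`build_toy_segmenter`, with a two-layer `backbone` and a 1x1 `head`) in place of a real DeepLab, whose layer names depend on the implementation. The learning rates are illustrative, not tuned.

```python
import tensorflow as tf

NUM_CLASSES = 3


def build_toy_segmenter(num_classes=NUM_CLASSES, size=64):
    """Hypothetical stand-in for DeepLab: a small 'backbone' plus a 1x1 conv head."""
    inputs = tf.keras.Input(shape=(size, size, 3))
    backbone = tf.keras.Sequential(
        [
            tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu"),
            tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu"),
        ],
        name="backbone",
    )
    # The head emits raw logits: one value per class per pixel.
    logits = tf.keras.layers.Conv2D(num_classes, 1, name="head")(backbone(inputs))
    return tf.keras.Model(inputs, logits), backbone


def compile_phase(model, backbone, train_backbone, learning_rate):
    """Freeze or unfreeze the backbone, then recompile so the change takes effect."""
    backbone.trainable = train_backbone
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )


model, backbone = build_toy_segmenter()

# Phase 1: train only the head.
compile_phase(model, backbone, train_backbone=False, learning_rate=1e-3)
# model.fit(train_ds, epochs=...)

# Phase 2: unfreeze the backbone with a much smaller learning rate.
compile_phase(model, backbone, train_backbone=True, learning_rate=1e-5)
# model.fit(train_ds, epochs=...)
```

Keras documents that changes to `trainable` only take effect after `compile()` is called again, which is why each phase recompiles. DeepLab backbones typically contain BatchNormalization layers. In TF2 Keras, setting `trainable = False` also makes those layers run in inference mode.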

Evaluate segmentation, not just loss

Segmentation quality is better measured with class-aware metrics such as mean Intersection over Union than with accuracy alone. Pixel accuracy can look deceptively high when the background class dominates the image. Always inspect predicted masks visually and compute per-class metrics when possible.
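A small NumPy sketch makes the gap between accuracy and IoU concrete. `per_class_iou` is a hypothetical helper that builds a confusion matrix from integer label maps. In the example, a model that predicts background everywhere scores 95% pixel accuracy but only 0.475 mean IoU, because the minority class gets an IoU of zero.

```python
import numpy as np


def per_class_iou(y_true, y_pred, num_classes):
    """IoU per class from integer label maps; NaN for classes absent from both."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    # Confusion matrix: rows are true classes, columns are predicted classes.
    cm = np.bincount(num_classes * y_true + y_pred, minlength=num_classes ** 2)
    cm = cm.reshape(num_classes, num_classes)
    tp = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp  # TP + FP + FN
    with np.errstate(invalid="ignore", divide="ignore"):
        return tp / union


# 95 background pixels and 5 "car" pixels; the model predicts background everywhere.
y_true = np.array([0] * 95 + [1] * 5)
y_pred = np.zeros(100, dtype=int)

pixel_accuracy = (y_true == y_pred).mean()           # 0.95
ious = per_class_iou(y_true, y_pred, num_classes=2)  # [0.95, 0.0]
mean_iou = np.nanmean(ious)                          # 0.475
```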

A simple compile step for a trainable segmentation model looks like this:

```python
# Assumes `model` outputs raw logits of shape (batch, height, width, NUM_CLASSES).
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```

That loss assumes the masks contain integer class IDs, not one-hot encoded vectors. If your pipeline emits one-hot targets, the loss choice must change accordingly.
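The compile step above reports only pixel accuracy. `tf.keras.metrics.MeanIoU` expects class indices rather than logits. One option is a small subclass that applies `argmax` before updating the metric. The sketch below assumes the model outputs raw logits with classes on the last axis, and the name `ArgmaxMeanIoU` is hypothetical:

```python
import tensorflow as tf


class ArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU that accepts logits by taking the argmax over the class axis first."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Convert per-pixel logits into predicted class indices.
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)
```

You could then pass `metrics=["accuracy", ArgmaxMeanIoU(num_classes=NUM_CLASSES, name="mean_iou")]` to `model.compile`. Recent Keras versions also accept `sparse_y_pred=False` on `MeanIoU`, which performs the argmax internally. Check that your installed version supports it before relying on it.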

Common Pitfalls

The biggest mistake is using RGB mask colors directly as labels. DeepLab expects class indices, so convert colors into numeric IDs before training.

Another common issue is resizing masks with bilinear interpolation. That creates blended label values and corrupts the target tensor. Use nearest-neighbor resizing for masks every time.

Developers also forget to update the number of output classes. Loading a pretrained checkpoint with the old segmentation head and then training against a different class count will produce shape mismatches or incorrect predictions.
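With a plain Keras functional model, you can avoid a mismatched head by cutting the pretrained model at its last feature layer. Then attach a new, randomly initialized head sized for your classes. The sketch assumes a functional model with a named feature layer. `build_segmenter`, `replace_head`, and the layer names `features` and `head` are hypothetical stand-ins for whatever your DeepLab implementation uses. The Model Garden config route handles this through `task.model.num_classes` instead.

```python
import tensorflow as tf


def build_segmenter(num_classes, size=64):
    """Hypothetical stand-in for a pretrained DeepLab-style functional model."""
    inputs = tf.keras.Input(shape=(size, size, 3))
    features = tf.keras.layers.Conv2D(
        16, 3, padding="same", activation="relu", name="features"
    )(inputs)
    logits = tf.keras.layers.Conv2D(num_classes, 1, name="head")(features)
    return tf.keras.Model(inputs, logits)


def replace_head(pretrained, num_classes, feature_layer="features"):
    """Reuse everything up to `feature_layer` and attach a freshly initialized head."""
    features = pretrained.get_layer(feature_layer).output
    logits = tf.keras.layers.Conv2D(num_classes, 1, name="custom_head")(features)
    return tf.keras.Model(pretrained.inputs, logits)


pretrained = build_segmenter(num_classes=21)  # e.g. a 21-class Pascal VOC head
model = replace_head(pretrained, num_classes=3)
```

The new `custom_head` starts from random weights. The feature layer is the same layer object as in the pretrained model, so its weights carry over.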

Finally, do not assume pretrained inference models from model hubs are automatically ready for fine-tuning. Some are exported mainly for serving. Use a training-oriented implementation, such as Model Garden or another DeepLab codebase that exposes the full trainable graph.

Summary

  • You can fine-tune DeepLab on a custom TensorFlow dataset, and it is usually preferable to training from scratch.
  • Each mask must contain stable integer class IDs that match the model output classes.
  • Build and validate the tf.data pipeline before debugging the model.
  • Start from pretrained weights, keep the learning rate low, and adjust the segmentation head to NUM_CLASSES.
  • Evaluate with segmentation-aware metrics and visual mask inspection, not loss alone.
