Correct Implementation of Dice Loss in TensorFlow / Keras

Introduction

Dice loss is popular in image segmentation because it measures overlap directly instead of treating every pixel independently. A correct TensorFlow or Keras implementation should work on probabilities, handle batch dimensions carefully, and include smoothing so empty masks do not cause divide-by-zero problems.

What Dice Loss Measures

The Dice coefficient compares the overlap between a predicted mask and a target mask. Higher overlap means a better score, so Dice loss is usually defined as 1 - dice_coefficient.
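As a small worked example, the soft Dice coefficient for one sample is 2 * sum(g * p) / (sum(g) + sum(p)), where g is the target mask and p holds the predicted values. The plain-Python sketch below evaluates it on a hypothetical four-pixel mask, once with hard 0/1 predictions and once with probabilities. The helper name dice_coefficient is illustrative and is not a library function.

```python
def dice_coefficient(target, pred, smooth=1e-6):
    # Soft Dice: twice the overlap divided by the total target plus predicted mass.
    intersection = sum(t * p for t, p in zip(target, pred))
    denominator = sum(target) + sum(pred)
    return (2.0 * intersection + smooth) / (denominator + smooth)


target = [1, 1, 0, 0]
hard_pred = [1, 0, 1, 0]          # one hit, one miss, one false positive
soft_pred = [0.9, 0.8, 0.1, 0.2]  # mostly correct probabilities

print(round(dice_coefficient(target, hard_pred), 4))        # 0.5
print(round(dice_coefficient(target, soft_pred), 4))        # 0.85
print(round(1.0 - dice_coefficient(target, soft_pred), 4))  # 0.15, the Dice loss
```

The smooth term only matters when both masks are empty or nearly empty. At 1e-6 it does not visibly change these results.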

For binary segmentation, the most common implementation uses soft probabilities rather than hard thresholded predictions. That matters because gradients need to flow through the loss. If you round or threshold y_pred inside the loss, its gradient with respect to the predictions is zero or undefined, so the optimizer receives no useful signal and training stalls.
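The difference is easy to see with tf.GradientTape. In the sketch below, a tf.Variable stands in for the network's sigmoid output, and soft_dice_loss is a hypothetical single-sample helper written for this illustration. The only difference between the two losses is the thresholding step.

```python
import tensorflow as tf

y_true = tf.constant([1.0, 1.0, 0.0, 0.0])
y_pred = tf.Variable([0.9, 0.4, 0.3, 0.1])  # stand-in for sigmoid outputs


def soft_dice_loss(t, p, smooth=1e-6):
    # Single-sample soft Dice loss on probabilities.
    intersection = tf.reduce_sum(t * p)
    total = tf.reduce_sum(t) + tf.reduce_sum(p)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)


with tf.GradientTape(persistent=True) as tape:
    soft_loss = soft_dice_loss(y_true, y_pred)
    # Thresholding inside the loss: the comparison op is not differentiable.
    hard_loss = soft_dice_loss(y_true, tf.cast(y_pred > 0.5, tf.float32))

soft_grad = tape.gradient(soft_loss, y_pred)
hard_grad = tape.gradient(hard_loss, y_pred)
del tape

print(soft_grad)  # non-zero gradient for every pixel
print(hard_grad)  # None: nothing flows back through the threshold
```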

A Correct Binary Dice Loss

This implementation works for binary segmentation with tensors shaped (batch, height, width, channels):

```python
import tensorflow as tf


def dice_loss(y_true, y_pred, smooth=1e-6):
    # Work in float32 so integer or boolean masks multiply cleanly with probabilities.
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # Flatten each sample separately: (batch, H, W, C) -> (batch, H*W*C).
    y_true = tf.reshape(y_true, [tf.shape(y_true)[0], -1])
    y_pred = tf.reshape(y_pred, [tf.shape(y_pred)[0], -1])

    # Per-sample overlap and total mass of both masks.
    intersection = tf.reduce_sum(y_true * y_pred, axis=1)
    denominator = tf.reduce_sum(y_true + y_pred, axis=1)

    # smooth keeps the ratio defined when both masks are empty.
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    # Average the per-sample Dice scores, then turn the score into a loss.
    return 1.0 - tf.reduce_mean(dice)
```

A few details matter here:

  • both tensors are cast to floating point
  • spatial dimensions are flattened per sample, not across the whole batch
  • the final loss averages sample-level Dice scores
  • smooth keeps the ratio finite when both the target and the prediction are empty

Flattening the entire batch into one vector lets images with large masks dominate the sums. A small object missed in another image then barely moves the score or its gradients.
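A quick NumPy comparison makes the effect concrete. The sketch assumes a hypothetical batch of two 4x4 masks: one image has a large object that is predicted perfectly, and the other has a single-pixel object that is missed entirely. The helper names dice_per_sample and dice_whole_batch are illustrative only.

```python
import numpy as np


def dice_per_sample(y_true, y_pred, smooth=1e-6):
    # Flatten each sample on its own: shape (batch, pixels).
    t = y_true.reshape(y_true.shape[0], -1)
    p = y_pred.reshape(y_pred.shape[0], -1)
    dice = (2.0 * (t * p).sum(axis=1) + smooth) / ((t + p).sum(axis=1) + smooth)
    return dice.mean()


def dice_whole_batch(y_true, y_pred, smooth=1e-6):
    # Anti-pattern: one vector for the entire batch.
    t, p = y_true.ravel(), y_pred.ravel()
    return (2.0 * (t * p).sum() + smooth) / ((t + p).sum() + smooth)


y_true = np.zeros((2, 4, 4, 1))
y_true[0, :2, :, 0] = 1.0  # image 0: large object, 8 pixels
y_true[1, 0, 0, 0] = 1.0   # image 1: tiny object, 1 pixel

y_pred = np.zeros_like(y_true)
y_pred[0] = y_true[0]      # image 0 predicted perfectly, image 1 missed

print(round(dice_per_sample(y_true, y_pred), 3))   # 0.5
print(round(dice_whole_batch(y_true, y_pred), 3))  # 0.941
```

Per-sample averaging reports that half the batch failed, while the batch-flattened score (16/17) hides the missed object almost completely.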

Using the Loss in a Keras Model

For binary segmentation, the last layer usually emits one channel with a sigmoid activation:

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(128, 128, 1))
x = tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu")(inputs)
x = tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu")(x)
# One output channel with sigmoid: a foreground probability per pixel.
outputs = tf.keras.layers.Conv2D(1, 1, activation="sigmoid")(x)

model = tf.keras.Model(inputs, outputs)
# dice_loss is the function defined in the previous section.
model.compile(optimizer="adam", loss=dice_loss)
```

In this setup, y_true should contain masks with values such as 0 and 1, and y_pred should be the model probabilities produced by sigmoid.

Multiclass Dice Is Different

A common mistake is reusing binary Dice code for multiclass segmentation without adjusting it. For multiclass output, the model often produces one channel per class with softmax, and the loss should usually compute Dice per class and then average across classes.

That means axis handling changes. If your masks are one-hot encoded, each class channel should contribute separately. If your labels are integer class ids, convert them to one-hot form before computing a multiclass Dice score.
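A minimal multiclass sketch follows. It assumes integer class-id labels shaped (batch, H, W) or (batch, H, W, 1), softmax outputs shaped (batch, H, W, num_classes), channels-last 2D images, and a hypothetical function name, multiclass_dice_loss. Dice is computed per sample and per class, then averaged over both.

```python
import tensorflow as tf


def multiclass_dice_loss(y_true, y_pred, num_classes, smooth=1e-6):
    # y_true: integer class ids; y_pred: softmax probabilities (channels last).
    y_true = tf.cast(y_true, tf.int32)
    if y_true.shape.rank == 4:  # drop a trailing channel axis of size 1
        y_true = tf.squeeze(y_true, axis=-1)
    # One-hot encode so each class gets its own channel, matching y_pred.
    y_true = tf.one_hot(y_true, depth=num_classes, dtype=tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # Sum over the spatial axes only, keeping a (batch, class) grid.
    spatial_axes = [1, 2]
    intersection = tf.reduce_sum(y_true * y_pred, axis=spatial_axes)
    denominator = tf.reduce_sum(y_true + y_pred, axis=spatial_axes)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)

    # Average over classes and samples, then convert to a loss.
    return 1.0 - tf.reduce_mean(dice)
```

To use it with model.compile, bind the class count, for example loss=lambda t, p: multiclass_dice_loss(t, p, num_classes=3). Because every class counts equally in the mean, a rare class weighs as much as the background, which is usually the intent.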

In practice, many teams combine Dice loss with cross-entropy. Cross-entropy gives strong pixel-wise gradients early in training, while Dice helps optimize overlap when the class distribution is imbalanced.
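One common way to combine the two is to add them. The sketch below sums Keras' BinaryCrossentropy with a per-sample soft Dice term for the binary case. The names soft_dice_loss and bce_dice_loss and the equal default weighting are assumptions for illustration, not a prescribed recipe.

```python
import tensorflow as tf

# Built-in binary cross-entropy on probabilities (from_logits=False by default).
bce = tf.keras.losses.BinaryCrossentropy()


def soft_dice_loss(y_true, y_pred, smooth=1e-6):
    # Same per-sample soft Dice as the binary implementation earlier in the article.
    batch = tf.shape(y_true)[0]
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [batch, -1])
    y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [batch, -1])
    intersection = tf.reduce_sum(y_true * y_pred, axis=1)
    denominator = tf.reduce_sum(y_true + y_pred, axis=1)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - tf.reduce_mean(dice)


def bce_dice_loss(y_true, y_pred, dice_weight=1.0):
    # Cross-entropy gives dense per-pixel gradients; Dice targets overlap.
    # dice_weight is a hypothetical tuning knob, not a recommended value.
    y_true = tf.cast(y_true, tf.float32)
    return bce(y_true, y_pred) + dice_weight * soft_dice_loss(y_true, y_pred)
```

It plugs into the earlier model unchanged: model.compile(optimizer="adam", loss=bce_dice_loss).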

Why Dice Loss Helps With Imbalanced Masks

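Segmentation datasets often contain a large background and a very small foreground region. Pixel-wise losses such as cross-entropy average over every pixel, so the many easy background pixels can dominate them. Dice loss reduces that effect because correctly predicted background pixels (true negatives) add nothing to its numerator or denominator; the score depends only on the true foreground and the predicted foreground mass.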
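A toy NumPy comparison illustrates this. The sketch assumes a hypothetical 100x100 mask with 1% foreground and a lazy model that predicts 1% foreground probability everywhere. Mean binary cross-entropy looks small, while the Dice loss stays close to its maximum.

```python
import numpy as np

# A 100x100 mask whose foreground covers 100 of 10,000 pixels (1%).
y_true = np.zeros((100, 100))
y_true[45:55, 45:55] = 1.0

# A lazy model that predicts 1% foreground probability everywhere.
y_pred = np.full_like(y_true, 0.01)

# Mean per-pixel binary cross-entropy, dominated by 9,900 easy background pixels.
bce_value = -np.mean(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))

# True negatives add nothing to either Dice sum, so background cannot dilute it.
smooth = 1e-6
dice = (2.0 * np.sum(y_true * y_pred) + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)
dice_loss_value = 1.0 - dice

print(round(bce_value, 3))        # 0.056
print(round(dice_loss_value, 3))  # 0.99
```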

That is why Dice is especially common in medical imaging, defect detection, and satellite segmentation where the object of interest may occupy only a small part of the image.

Common Pitfalls

The biggest mistake is thresholding predictions inside the loss with something like y_pred > 0.5. That destroys differentiability.

Another issue is applying Dice directly to logits instead of probabilities. If the model output is raw logits, wrap them with sigmoid or softmax before the Dice computation, or let the model output probabilities directly.
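If the final layer outputs logits, a minimal sketch of a binary wrapper applies tf.sigmoid inside the loss before computing Dice. The function name dice_loss_from_logits is hypothetical. For multiclass logits, the equivalent step is tf.nn.softmax(logits, axis=-1) followed by a per-class Dice rather than this binary version.

```python
import tensorflow as tf


def dice_loss_from_logits(y_true, logits, smooth=1e-6):
    # Dice assumes values in [0, 1], so map raw logits through sigmoid first.
    # This sketch covers the binary (single-channel sigmoid) case only.
    y_pred = tf.sigmoid(tf.cast(logits, tf.float32))
    y_true = tf.cast(y_true, tf.float32)

    # Flatten per sample, as in the main implementation.
    batch = tf.shape(y_true)[0]
    y_true = tf.reshape(y_true, [batch, -1])
    y_pred = tf.reshape(y_pred, [batch, -1])

    intersection = tf.reduce_sum(y_true * y_pred, axis=1)
    denominator = tf.reduce_sum(y_true + y_pred, axis=1)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - tf.reduce_mean(dice)
```

With this variant, the last Conv2D layer would drop activation="sigmoid" so that it emits logits.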

Developers also flatten the entire batch into one vector and unknowingly change the optimization behavior. Compute Dice per sample, then average.

Summary

  • Dice loss should operate on soft probabilities, not thresholded predictions.
  • Flatten spatial dimensions per sample so batch averaging stays meaningful.
  • Add a small smoothing term to avoid divide-by-zero cases.
  • Use sigmoid for binary segmentation and adjust the implementation for multiclass output.
  • Combining Dice with cross-entropy is often a strong practical choice when classes are imbalanced.

All Rights Reserved.