Class weights for balancing data in TensorFlow Object Detection API

Introduction

Class imbalance is common in object detection datasets. A model may see thousands of examples of one class and only a few dozen of another, which often leads to poor recall on the rare classes.

Why class_weight Is Not a Drop-In Feature Here

In plain Keras classification, class_weight works because each training example usually has one label. TensorFlow Object Detection API models do not train that way.

A detector typically produces:

  • box coordinates
  • objectness or class scores
  • predictions for many anchors or proposals per image
  • scores for a background class that dominates the training targets

That means the loss is built from many classification terms plus localization terms. Because of that structure, the API does not expose a simple per-class weighting switch like model.fit(..., class_weight=...).

What the API Usually Supports

The stock configuration files let you tune overall loss behavior, but not arbitrary per-class weights in a single obvious field. Depending on the model family, you may see options such as:

  • global classification versus localization loss weights
  • focal loss settings such as alpha and gamma
  • sampling and data augmentation choices

Those settings help with imbalance, but they are not the same as saying "class cat should count four times more than class dog."
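To make the difference concrete, here is a minimal sketch in plain Python. The function name class_shares and the loss values are assumptions made up for illustration, not API internals. It only shows that multiplying the whole classification term by one scalar leaves the balance between classes unchanged.

```python
# Illustrative sketch only: the names and numbers are assumptions, not API internals.

def class_shares(cls_losses, classification_weight=1.0):
    """Return each class's share of the classification term.

    cls_losses maps a class name to its summed classification loss.
    classification_weight is one scalar applied to every class alike.
    """
    scaled = {name: classification_weight * loss for name, loss in cls_losses.items()}
    total = sum(scaled.values())
    return {name: value / total for name, value in scaled.items()}


# Hypothetical per-class contributions to the classification loss.
cls_losses = {"cat": 0.25, "dog": 1.75}

shares_default = class_shares(cls_losses, classification_weight=1.0)
shares_scaled = class_shares(cls_losses, classification_weight=4.0)

# Both calls report the same shares: the global weight rescales the whole
# term but does not make "cat" count more than "dog".
print(shares_default)
print(shares_scaled)
```

A global weight changes how classification trades off against localization, which is a different question from how classes trade off against each other.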

Practical Ways to Rebalance Training

In most projects, you should try these options before patching framework internals:

1. Rebalance the Dataset

Oversample images containing rare classes, undersample very common classes, or add more examples for the minority classes. This is usually the most robust fix because it changes what the model sees, not just how the loss is scaled.

2. Use Focal Loss If the Model Supports It

Focal loss reduces the impact of easy negatives and helps with severe class imbalance, especially when background examples dominate.
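The effect is easy to see on single predictions. The following is a minimal plain-Python sketch of the binary focal loss formula with the commonly cited defaults alpha = 0.25 and gamma = 2.0; it illustrates the formula only and is not the API's own implementation, which may differ in detail.

```python
import math


def binary_focal_loss(p, target, alpha=0.25, gamma=2.0):
    """Focal loss for one sigmoid output.

    p is the predicted probability of the positive class; target is 0 or 1.
    """
    p_t = p if target == 1 else 1.0 - p              # probability of the true label
    alpha_t = alpha if target == 1 else 1.0 - alpha  # class-balancing factor
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# An easy background anchor: the model is already confident it is negative.
easy_negative = binary_focal_loss(p=0.02, target=0)

# A hard positive: the model assigns low probability to the true object.
hard_positive = binary_focal_loss(p=0.10, target=1)

# Plain cross-entropy for the same easy negative, for comparison.
plain_ce_easy = -math.log(1.0 - 0.02)

print(easy_negative, hard_positive, plain_ce_easy)
```

The easy negative contributes several orders of magnitude less than it would under plain cross-entropy, so thousands of such anchors no longer swamp the few hard examples.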

3. Add Per-Class Weights in Custom Training Code

If you really need explicit class weighting, you usually have to modify the classification loss computation so each anchor or matched box is multiplied by a class-specific factor.

Example of Per-Class Weighting

The following TensorFlow example is not a full detector. It shows the core idea behind class weighting: compute the usual classification loss, gather a weight for each target class, and multiply before reducing.

python
import tensorflow as tf

# Three foreground classes with different importance.
class_weights = tf.constant([1.0, 3.0, 5.0], dtype=tf.float32)

# True class id for each matched training example.
labels = tf.constant([0, 2, 1, 2], dtype=tf.int32)

# Logits produced by a classifier head.
logits = tf.constant([
    [2.2, 0.1, -0.5],
    [0.3, 0.2, 1.7],
    [0.4, 1.3, 0.8],
    [0.1, 0.6, 1.2],
], dtype=tf.float32)

# Per-example cross-entropy, before any weighting.
base_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels,
    logits=logits,
)

# Look up the weight of each example's true class, then scale and reduce.
weights = tf.gather(class_weights, labels)
weighted_loss = tf.reduce_mean(base_loss * weights)

print(base_loss.numpy())
print(weights.numpy())
print(float(weighted_loss.numpy()))

Inside the Object Detection API, the same idea would be applied to the detector's classification targets after matching ground-truth boxes to anchors or proposals. That usually means customizing the loss builder or the model code rather than changing only the config file.
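As a rough sketch of what that could look like at the anchor level, the plain-Python example below weights each anchor by its matched class and normalizes by the number of foreground anchors. The convention that index 0 is background, the function names, and the logits are assumptions for illustration; none of this is the API's actual code.

```python
import math

# Assumption for illustration: index 0 is background, 1..3 are foreground classes.
CLASS_WEIGHTS = [1.0, 1.0, 3.0, 5.0]


def softmax_cross_entropy(logits, label):
    """Numerically stable softmax cross-entropy for one anchor."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum - logits[label]


def weighted_anchor_loss(anchor_logits, anchor_labels, class_weights):
    """Weight each anchor's loss by its matched class.

    The sum is divided by the number of foreground anchors, so the scale
    does not depend on how many background anchors an image has.
    """
    total = 0.0
    for logits, label in zip(anchor_logits, anchor_labels):
        total += class_weights[label] * softmax_cross_entropy(logits, label)
    num_foreground = sum(1 for label in anchor_labels if label != 0)
    return total / max(1, num_foreground)


# Hypothetical matching result: most anchors are background (label 0).
anchor_labels = [0, 0, 0, 3, 0, 2]
anchor_logits = [
    [3.0, 0.1, 0.2, 0.1],
    [2.5, 0.3, 0.1, 0.0],
    [2.8, 0.0, 0.4, 0.2],
    [0.5, 0.2, 0.1, 1.9],
    [3.1, 0.2, 0.0, 0.1],
    [0.9, 0.3, 1.2, 0.4],
]

unweighted = weighted_anchor_loss(anchor_logits, anchor_labels, [1.0] * 4)
weighted = weighted_anchor_loss(anchor_logits, anchor_labels, CLASS_WEIGHTS)
print(unweighted, weighted)
```

Keeping the background weight at 1.0 here is a deliberate choice: only the matched foreground anchors are scaled up, and the background term stays as it was.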

Where People Get Confused

There are two especially common misunderstandings.

First, developers try to pass class_weight into a training script that does not call standard model.fit in the usual classification sense. The argument has no effect on the loss because the Object Detection API uses its own input and loss pipeline.

Second, people weight images instead of object instances. In detection, one image can contain many boxes from different classes, so image-level weighting is often too coarse.

When Class Weights Are Worth It

Per-class weighting is useful when:

  • a minority class is important and costly to miss
  • you cannot gather more data quickly
  • focal loss and resampling are not enough

It is less useful when the minority class is simply under-annotated or mislabeled. In that case, weighting bad labels more heavily usually makes the model worse.

Common Pitfalls

Be careful with the background class. Many detector losses include a large number of background anchors, and weighting foreground classes without understanding the background term can destabilize training.

Do not assume that global classification_weight settings are per-class weights. They usually scale the whole classification loss term, not one label versus another.

Avoid extreme weight ratios. If one class is weighted too aggressively, the model may overfit that class and produce many false positives.
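One simple guard is to derive weights from instance counts and cap them. The sketch below is an assumption-laden illustration: the function name capped_class_weights, the max_ratio value of 5.0, and the counts are all hypothetical, and the right cap depends on the dataset.

```python
def capped_class_weights(instance_counts, max_ratio=5.0):
    """Inverse-frequency class weights with an upper bound.

    The most common class gets weight 1.0. Rarer classes get larger
    weights, but never more than max_ratio.
    """
    if any(count <= 0 for count in instance_counts.values()):
        raise ValueError("every class needs at least one annotated instance")

    most_common = max(instance_counts.values())
    weights = {}
    for name, count in instance_counts.items():
        raw = most_common / count          # plain inverse-frequency ratio
        weights[name] = min(raw, max_ratio)
    return weights


# Hypothetical box counts per class in the training set.
counts = {"car": 12000, "pedestrian": 3000, "cyclist": 150}

weights = capped_class_weights(counts)
print(weights)
```

Without the cap, the cyclist class in this example would receive a weight of 80, which is the kind of ratio that tends to produce many false positives.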

Finally, evaluate per-class precision and recall after every change. Weighted training that improves overall loss can still degrade the rare class you were trying to help.
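A per-class check can be as small as the sketch below. It assumes you already have true positive, false positive, and false negative counts per class at a fixed score and IoU threshold; the counts shown are invented for illustration and the function name is hypothetical.

```python
def per_class_precision_recall(counts):
    """Compute (precision, recall) per class.

    counts maps a class name to a dict with "tp", "fp", and "fn" entries.
    """
    report = {}
    for name, c in counts.items():
        predicted = c["tp"] + c["fp"]   # everything the model detected
        actual = c["tp"] + c["fn"]      # everything that was annotated
        precision = c["tp"] / predicted if predicted else 0.0
        recall = c["tp"] / actual if actual else 0.0
        report[name] = (precision, recall)
    return report


# Invented counts before and after reweighting the rare class.
before = {
    "car": {"tp": 900, "fp": 100, "fn": 100},
    "cyclist": {"tp": 10, "fp": 10, "fn": 30},
}
after = {
    "car": {"tp": 900, "fp": 100, "fn": 100},
    "cyclist": {"tp": 30, "fp": 90, "fn": 10},
}

report_before = per_class_precision_recall(before)
report_after = per_class_precision_recall(after)

for name in before:
    print(name, report_before[name], report_after[name])
```

In this made-up case, cyclist recall rises from 0.25 to 0.75 while precision falls from 0.5 to 0.25, a trade-off that an aggregate loss or a single averaged metric would hide.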

Summary

  • TensorFlow Object Detection API does not provide a simple universal class_weight knob like basic Keras classification.
  • The most practical fixes are dataset rebalancing, focal loss, and careful sampling.
  • True per-class weighting usually requires customizing the detector's classification loss.
  • Weight boxes or anchors, not just whole images.
  • Measure per-class metrics after reweighting so you can see whether the change actually helped.
