How does TensorFlow/Keras's class_weight parameter of the fit function work?

Introduction

class_weight is Keras' built-in way to tell training that some classes should matter more than others. It is usually used for imbalanced classification, where a model can get a low loss by mostly predicting the majority class and ignoring the rare one.

What class_weight Changes

When you pass class_weight to model.fit, Keras multiplies the loss contribution of each training example by the weight associated with that example's class. The weighting applies during training only. A larger weight means a mistake on that class produces a larger loss and therefore a larger gradient update.

Conceptually, if the unweighted loss for one example is loss_i, the weighted version becomes weight_for_class * loss_i.

A typical dictionary looks like class_weight={0: 1.0, 1: 4.0}. In that case, errors on class 1 count four times as much as errors on class 0.

This does not duplicate examples in memory. It only changes how strongly each example influences optimization.
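To make the multiplication concrete, here is a minimal NumPy sketch of the idea for binary cross-entropy. The labels and probabilities are made up for illustration, and the final plain mean is an assumption: the exact way Keras reduces weighted losses over a batch can differ between versions.

```python
import numpy as np

# Hypothetical labels and predicted probabilities for four examples.
y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.1, 0.4, 0.8, 0.3])

class_weight = {0: 1.0, 1: 4.0}

# Per-example binary cross-entropy (unweighted): this is loss_i.
loss_i = -(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))

# Look up each example's weight from its class, then scale its loss.
weights = np.array([class_weight[int(c)] for c in y_true])
weighted_loss_i = weights * loss_i

print(loss_i)
print(weighted_loss_i)
# Plain mean for illustration only; Keras' exact reduction may differ.
print(weighted_loss_i.mean())
```

The class 0 losses are unchanged, while each class 1 loss is four times larger, so the poorly predicted positive (probability 0.3) dominates the batch loss.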

A Small Example

The example below builds a binary classifier on an imbalanced dataset. The labels contain many more zeros than ones, so we weight class 1 more heavily.

```python
import numpy as np
import tensorflow as tf

# Tiny imbalanced dataset: four examples of class 0, two of class 1.
x = np.array([
    [0.0, 0.0],
    [0.1, 0.2],
    [0.2, 0.1],
    [0.3, 0.4],
    [1.0, 1.1],
    [1.2, 1.0],
], dtype="float32")

y = np.array([0, 0, 0, 0, 1, 1], dtype="int32")

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(2,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

history = model.fit(
    x,
    y,
    epochs=10,
    verbose=0,
    # Errors on class 1 count three times as much as errors on class 0.
    class_weight={0: 1.0, 1: 3.0},
)

# Final training loss (already includes the class weighting).
print(history.history["loss"][-1])
```

The code is ordinary Keras training. The only difference is the weight mapping.

How to Choose the Weights

A common starting point is inverse frequency. If class 1 appears half as often as class 0, you can give class 1 roughly twice the weight. That is not a law; it is a baseline.

Here is a simple manual calculation.

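```python
import numpy as np

labels = np.array([0, 0, 0, 0, 1, 1])
classes, counts = np.unique(labels, return_counts=True)
total = counts.sum()

# Inverse-frequency weights: total / (number of classes * class count).
# Cast to plain Python types so the dict prints cleanly and is easy to pass around.
weights = {
    int(cls): float(total / (len(classes) * count))
    for cls, count in zip(classes, counts)
}

print(weights)
```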

In practice, you still tune the weights with validation metrics. If recall on the minority class is poor, increase its weight carefully and watch precision, calibration, and overall stability.

class_weight Versus sample_weight

class_weight applies one shared weight per class. sample_weight is more granular and lets you weight individual rows differently.

Use class_weight when the problem is class imbalance. Use sample_weight when some individual samples are more important, noisier, or should count less for domain-specific reasons.
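The relationship between the two can be sketched by expanding a class mapping into one weight per row and then adjusting individual rows. This is an illustration under assumptions: the labels are made up, and the "noisy fourth row" rule is a hypothetical domain decision.

```python
import numpy as np

y = np.array([0, 0, 0, 0, 1, 1])
class_weight = {0: 1.0, 1: 3.0}

# One weight per row, looked up from the row's class.
sample_weight = np.array(
    [class_weight[int(label)] for label in y], dtype="float32"
)

# Hypothetical domain rule: the fourth row has a noisy label,
# so halve its influence. class_weight alone cannot express this.
sample_weight[3] *= 0.5

print(sample_weight)
```

An array like this would be passed as model.fit(x, y, sample_weight=sample_weight) in place of class_weight.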

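If you train from a tf.data.Dataset or a Python generator, the sample_weight argument of fit is not supported. Instead, the input yields (inputs, targets, sample_weights) tuples and Keras reads the weights from the third element.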
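As a rough sketch of that pattern, the generator below yields batches as three-element tuples, with the third element holding per-sample weights derived from a class mapping. The helper name weighted_batches and the data are hypothetical, not part of Keras.

```python
import numpy as np

def weighted_batches(x, y, class_weight, batch_size):
    """Yield (inputs, targets, sample_weights) batches forever (hypothetical helper)."""
    n = len(x)
    while True:
        for start in range(0, n, batch_size):
            xb = x[start:start + batch_size]
            yb = y[start:start + batch_size]
            # Third tuple element: one weight per row in this batch.
            wb = np.array([class_weight[int(c)] for c in yb], dtype="float32")
            yield xb, yb, wb

x = np.arange(12, dtype="float32").reshape(6, 2)
y = np.array([0, 0, 0, 0, 1, 1], dtype="int32")

batches = weighted_batches(x, y, {0: 1.0, 1: 3.0}, batch_size=3)
first = next(batches)
second = next(batches)
print(second[2])
```

Because this generator never ends, a call such as model.fit(batches, steps_per_epoch=2, epochs=10) would need steps_per_epoch to define an epoch; that call assumes a compiled model and is not run here.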

Label Format Matters

Keras expects the keys in class_weight to be class indices. For sparse labels, that means integer class IDs such as 0, 1, and 2.

One detail from the TensorFlow API matters here: when targets have rank 2 or greater, labels must either be one-hot encoded or include an explicit final dimension of 1 for sparse labels. If the label shape does not match what the loss expects, weighting will not save the run; training will fail earlier because the target format is wrong.
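The two accepted rank-2 layouts can be illustrated with NumPy alone. This is only a shape sketch with made-up labels for three classes; it does not call Keras.

```python
import numpy as np

# Hypothetical sparse labels for six examples and three classes: shape (6,).
y_sparse = np.array([0, 2, 1, 0, 1, 2])
num_classes = 3

# Rank-2 option 1: sparse labels with an explicit final dimension of 1.
y_sparse_2d = y_sparse[:, np.newaxis]

# Rank-2 option 2: one-hot encoding, one column per class.
y_one_hot = np.eye(num_classes, dtype="float32")[y_sparse]

print(y_sparse_2d.shape)
print(y_one_hot.shape)
```

In both layouts the class index for each row is recoverable, which is what the class_weight lookup needs: directly from the single column in the first case, and from the position of the 1 in the second.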

Common Pitfalls

The biggest mistake is using class_weight and then evaluating only accuracy. Accuracy can still look fine while minority-class recall remains poor. Check precision, recall, F1, PR-AUC, or confusion matrices.
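A small NumPy sketch shows how misleading accuracy can be. The validation labels are hypothetical, and the "model" is simply a constant majority-class predictor.

```python
import numpy as np

# Hypothetical validation set: 18 negatives, 2 positives.
y_true = np.array([0] * 18 + [1] * 2)
# A "model" that always predicts the majority class.
y_pred = np.zeros_like(y_true)

accuracy = float((y_true == y_pred).mean())

tp = int(np.sum((y_true == 1) & (y_pred == 1)))
fp = int(np.sum((y_true == 0) & (y_pred == 1)))
fn = int(np.sum((y_true == 1) & (y_pred == 0)))

# Guard against division by zero when there are no positive predictions.
recall = tp / (tp + fn) if (tp + fn) else 0.0
precision = tp / (tp + fp) if (tp + fp) else 0.0

print(accuracy, precision, recall)
```

Here accuracy is 18 out of 20 even though the predictor never finds a single positive, which is exactly the failure that recall exposes.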

Another mistake is assigning extremely large weights. That can make training unstable and cause the model to overcorrect toward the minority class.

A third issue is using class weighting when resampling would be simpler. If the dataset is tiny, modest oversampling plus good validation can be easier to reason about.

Finally, do not confuse weighting with threshold tuning. class_weight changes training. Decision thresholds such as 0.5 versus 0.2 change inference. They solve related but different problems.
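To see the inference-side knob in isolation, the sketch below applies two thresholds to the same predicted probabilities. The probabilities are made up, and predict_labels is a hypothetical helper, not a Keras function.

```python
import numpy as np

# Hypothetical class 1 probabilities from an already-trained model.
probs = np.array([0.05, 0.15, 0.25, 0.45, 0.55, 0.90])

def predict_labels(probs, threshold):
    """Turn probabilities into 0/1 labels at a given decision threshold."""
    return (probs >= threshold).astype(int)

labels_at_05 = predict_labels(probs, 0.5)
labels_at_02 = predict_labels(probs, 0.2)

print(labels_at_05)
print(labels_at_02)
```

Lowering the threshold flags more examples as class 1 without retraining anything, whereas changing class_weight requires fitting the model again.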

Summary

  • class_weight multiplies each example's loss according to its class.
  • It helps when class imbalance causes the model to ignore rare classes.
  • A mapping such as class_weight={0: 1.0, 1: 3.0} makes class 1 errors count more.
  • Start with inverse-frequency weights, then tune using validation metrics.
  • Use sample_weight when per-example control is needed.
  • Measure minority-class performance directly instead of relying only on accuracy.
