Introduction
Weighted binary crossentropy assigns different penalties to errors on the positive and negative classes during training. This matters for imbalanced datasets where one class is much rarer than the other, for example fraud detection (1% fraud, 99% legitimate) or diagnosis of rare diseases. Without weighting, a model can minimize the loss by predicting the majority class almost everywhere and largely ignoring the minority class.
Standard Binary Crossentropy
The standard (unweighted) loss treats both classes equally:
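import tensorflow as tf

# Standard binary crossentropy: every sample contributes equally
loss_fn = tf.keras.losses.BinaryCrossentropy()

# Float tensors avoid any dtype ambiguity between labels and predictions
y_true = tf.constant([1.0, 0.0, 1.0, 0.0])
y_pred = tf.constant([0.9, 0.1, 0.8, 0.2])
loss = loss_fn(y_true, y_pred)
print(f"Loss: {loss.numpy():.4f}")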
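To see what the loss above computes, the following minimal sketch recomputes it by hand in plain Python. It assumes the usual definition of binary crossentropy, the mean of -[y log p + (1 - y) log(1 - p)], and ignores the small epsilon clipping Keras applies. `binary_crossentropy` is a hypothetical helper name used only for this illustration, not a Keras function.

```python
import math

def binary_crossentropy(y_true, y_pred):
    """Mean unweighted binary crossentropy over paired labels and probabilities."""
    per_sample = [
        -(y * math.log(p) + (1 - y) * math.log(1 - p))
        for y, p in zip(y_true, y_pred)
    ]
    return sum(per_sample) / len(per_sample)

y_true = [1, 0, 1, 0]
y_pred = [0.9, 0.1, 0.8, 0.2]
manual_loss = binary_crossentropy(y_true, y_pred)
print(f"Manual loss: {manual_loss:.4f}")
```

Positives and negatives enter the average with the same weight here, which is exactly what the weighted variants change.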
Method 1: class_weight in model.fit()
The simplest approach — pass class weights directly to fit():
import numpy as np
from tensorflow import keras

# Count samples per class (y_train is assumed to be a NumPy array of 0/1 labels)
n_positive = np.sum(y_train == 1)
n_negative = np.sum(y_train == 0)
total = len(y_train)

# Compute weights inversely proportional to class frequency
class_weight = {
    0: total / (2 * n_negative),
    1: total / (2 * n_positive),
}

print(f"Class weights: {class_weight}")
# For a 95/5 split: {0: 0.526, 1: 10.0}

# model is assumed to be an already-built Keras model with a sigmoid output
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, class_weight=class_weight, epochs=10)
This scales each sample's loss by the weight of its class before the batch loss is averaged, so errors on the rare class contribute proportionally more to the gradient.
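As a rough illustration of why this formula is called "balanced", the sketch below applies it to a hypothetical 95/5 split of 1000 samples and checks how much total weight each class carries. `balanced_class_weights` is a helper name invented for this example.

```python
def balanced_class_weights(n_negative, n_positive):
    """Weights inversely proportional to class frequency."""
    total = n_negative + n_positive
    return {0: total / (2 * n_negative), 1: total / (2 * n_positive)}

# Hypothetical 95/5 split of 1000 samples
n_negative, n_positive = 950, 50
class_weight = balanced_class_weights(n_negative, n_positive)

# Total weight carried by each class after weighting
negative_mass = n_negative * class_weight[0]
positive_mass = n_positive * class_weight[1]
print(class_weight, negative_mass, positive_mass)
```

Under these assumptions both classes end up with the same total weight (half of the dataset each), even though the positives are 19 times rarer.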
Method 2: Custom Weighted Loss Function
For more control, define a custom loss:
import tensorflow as tf

def weighted_binary_crossentropy(pos_weight):
    """
    pos_weight: weight for the positive class (class 1).
    Higher pos_weight = larger penalty for missed positives (false negatives).
    """
    def loss_fn(y_true, y_pred):
        # Cast labels so integer targets can be multiplied with float predictions
        y_true = tf.cast(y_true, y_pred.dtype)
        # Clip probabilities away from 0 and 1 to avoid log(0)
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)

        # Binary crossentropy with the positive term scaled by pos_weight
        per_sample = -(pos_weight * y_true * tf.math.log(y_pred) +
                       (1 - y_true) * tf.math.log(1 - y_pred))

        return tf.reduce_mean(per_sample)

    return loss_fn

# Usage: positives are roughly 20x rarer than negatives
model.compile(
    optimizer='adam',
    loss=weighted_binary_crossentropy(pos_weight=20.0),
    metrics=['accuracy']
)
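To make the effect of pos_weight concrete, here is a small plain-Python sketch of the same per-sample formula. It compares two equally confident mistakes, one on a positive and one on a negative. `weighted_bce_sample` is a hypothetical helper, and the probabilities are made up for illustration.

```python
import math

def weighted_bce_sample(y_true, p, pos_weight):
    """Per-sample weighted binary crossentropy for a label and a predicted probability."""
    return -(pos_weight * y_true * math.log(p) + (1 - y_true) * math.log(1 - p))

pos_weight = 20.0
missed_positive = weighted_bce_sample(1, 0.1, pos_weight)  # positive predicted as unlikely
false_alarm = weighted_bce_sample(0, 0.9, pos_weight)      # negative predicted as likely
print(f"missed positive: {missed_positive:.3f}, false alarm: {false_alarm:.3f}")
```

Both predictions are wrong by the same margin, but the missed positive costs pos_weight times as much as the false alarm.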
Method 3: Using sample_weight
Apply per-sample weights for fine-grained control:
# Compute per-sample weights: positives count 10x as much as negatives
sample_weights = np.where(y_train == 1, 10.0, 1.0)

model.fit(
    X_train, y_train,
    sample_weight=sample_weights,
    epochs=10
)
This is useful when different samples within the same class should have different weights (e.g., based on confidence or importance).
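One way to get weights that differ within a class is to multiply a class-level weight by a per-sample factor. The sketch below assumes a hypothetical `confidences` list (for instance, annotator confidence in each label), which is not part of the original example.

```python
def build_sample_weights(labels, confidences, pos_weight):
    """Combine a class-level weight with a per-sample confidence in [0, 1]."""
    return [
        (pos_weight if y == 1 else 1.0) * c
        for y, c in zip(labels, confidences)
    ]

labels = [1, 0, 1, 0]
# Hypothetical annotator confidence for each label
confidences = [1.0, 1.0, 0.5, 0.8]
sample_weights = build_sample_weights(labels, confidences, pos_weight=10.0)
print(sample_weights)
```

The resulting list can be converted with np.asarray and passed as sample_weight to model.fit().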
Method 4: tf.nn.weighted_cross_entropy_with_logits
TensorFlow provides a built-in weighted crossentropy that works on logits (pre-sigmoid values):
import tensorflow as tf
from tensorflow import keras

def weighted_loss(pos_weight):
    def loss_fn(y_true, y_pred):
        # y_pred holds raw logits here; labels must share their float dtype
        y_true = tf.cast(y_true, y_pred.dtype)
        return tf.reduce_mean(
            tf.nn.weighted_cross_entropy_with_logits(
                labels=y_true,
                logits=y_pred,
                pos_weight=pos_weight
            )
        )
    return loss_fn

# Important: remove the sigmoid from the last layer when using logits
# (n_features is assumed to be the number of input columns)
model = keras.Sequential([
    keras.Input(shape=(n_features,)),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(1)  # No sigmoid: outputs logits
])

model.compile(optimizer='adam', loss=weighted_loss(pos_weight=20.0))
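This is numerically more stable than applying a sigmoid first and then taking the log, since the loss is computed directly from the logits. The model now outputs logits, so pass its predictions through tf.sigmoid before interpreting them as probabilities or applying a threshold.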
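The stability claim can be checked with a minimal plain-Python sketch. The stable form below follows the formula given in the TensorFlow documentation for tf.nn.weighted_cross_entropy_with_logits. The function names are hypothetical, and the real op works on tensors rather than scalars.

```python
import math

def naive_weighted_bce(label, logit, pos_weight):
    """Sigmoid first, then log: breaks down for large-magnitude logits."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(pos_weight * label * math.log(p) + (1 - label) * math.log(1 - p))

def stable_weighted_bce(label, logit, pos_weight):
    """Same loss computed directly from the logit, without forming the sigmoid."""
    scale = 1.0 + (pos_weight - 1.0) * label
    softplus_neg = math.log1p(math.exp(-abs(logit))) + max(-logit, 0.0)
    return (1 - label) * logit + scale * softplus_neg

# The two agree for moderate logits
print(naive_weighted_bce(1, 2.0, 20.0), stable_weighted_bce(1, 2.0, 20.0))

# For an extreme logit the naive version fails, the stable one does not
try:
    naive_weighted_bce(1, -800.0, 20.0)
    naive_failed = False
except (OverflowError, ValueError):
    naive_failed = True
print(naive_failed, stable_weighted_bce(1, -800.0, 20.0))
```

In float32 the naive form runs into trouble at much smaller logit magnitudes than in this float64 sketch, which is why clipping with epsilon is needed in the probability-based losses.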
Method 5: Focal Loss
Focal loss down-weights easy examples and focuses on hard ones — useful for extreme imbalance:
import tensorflow as tf

def focal_loss(gamma=2.0, alpha=0.25):
    """
    gamma: focusing parameter (higher = more focus on hard examples)
    alpha: weight for the positive class; negatives get 1 - alpha
    """
    def loss_fn(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)

        # Per-sample cross entropy
        ce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))

        # Focal weight: p_t is the probability assigned to the true class
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        focal_weight = tf.pow(1 - p_t, gamma)

        # Alpha weight
        alpha_weight = y_true * alpha + (1 - y_true) * (1 - alpha)

        return tf.reduce_mean(alpha_weight * focal_weight * ce)

    return loss_fn

# alpha > 0.5 weights the positive class more heavily than the negative class
model.compile(optimizer='adam', loss=focal_loss(gamma=2.0, alpha=0.75))
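To illustrate how strongly the focusing term separates easy from hard examples, the sketch below evaluates the same per-sample formula in plain Python for two hypothetical positives. `focal_loss_sample` is a helper name made up for this example.

```python
import math

def focal_loss_sample(y_true, p, gamma=2.0, alpha=0.25):
    """Per-sample focal loss for a label and a predicted positive-class probability."""
    p_t = p if y_true == 1 else 1 - p            # probability given to the true class
    alpha_t = alpha if y_true == 1 else 1 - alpha
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

easy = focal_loss_sample(1, 0.9)   # correct and confident
hard = focal_loss_sample(1, 0.1)   # wrong and confident
plain_ratio = math.log(0.1) / math.log(0.9)  # hard/easy ratio under plain crossentropy
print(f"focal ratio: {hard / easy:.0f}, plain ratio: {plain_ratio:.1f}")
```

With gamma=2 the hard example outweighs the easy one by an extra factor of (0.9/0.1)^2 = 81 compared with plain crossentropy; alpha cancels out of the ratio because both samples are positives.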
Choosing the Right Weight
import numpy as np

# Option 1: Inverse frequency
pos_weight = n_negative / n_positive

# Option 2: scikit-learn utility
from sklearn.utils.class_weight import compute_class_weight
weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=y_train)
class_weight = {0: weights[0], 1: weights[1]}

# Option 3: Manual tuning (start with inverse frequency, then adjust)
# If too many false positives: decrease pos_weight
# If too many false negatives: increase pos_weight
| Imbalance Ratio (positive:negative) | Suggested pos_weight | Notes |
|---|---|---|
| 1:2 | 2.0 | Mild imbalance |
| 1:10 | 10.0 | Moderate imbalance |
| 1:100 | 50-100 | Severe; consider focal loss |
| 1:1000 | 100-500 | Extreme; combine with resampling |
Evaluation Metrics for Imbalanced Data
Accuracy is misleading with imbalanced data. Use these instead:
from sklearn.metrics import classification_report, roc_auc_score

# Flatten the (n, 1) output; apply a sigmoid first if the model outputs logits
y_pred_proba = model.predict(X_test).ravel()
y_pred = (y_pred_proba > 0.5).astype(int)

print(classification_report(y_test, y_pred))
print(f"ROC AUC: {roc_auc_score(y_test, y_pred_proba):.4f}")

# Lower the threshold to trade precision for recall
y_pred_low_threshold = (y_pred_proba > 0.3).astype(int)  # More sensitive
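The effect of moving the threshold can be sketched without any library. The labels and scores below are made up for illustration, and `precision_recall` is a hypothetical helper; in practice sklearn.metrics provides equivalent functions.

```python
def precision_recall(y_true, y_score, threshold):
    """Precision and recall of the positive class at a given decision threshold."""
    y_pred = [1 if s > threshold else 0 for s in y_score]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Hypothetical labels and predicted probabilities
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_score = [0.9, 0.4, 0.35, 0.45, 0.2, 0.1, 0.05, 0.02]

for threshold in (0.5, 0.3):
    precision, recall = precision_recall(y_true, y_score, threshold)
    print(f"threshold={threshold}: precision={precision:.2f}, recall={recall:.2f}")
```

On this toy data, lowering the threshold from 0.5 to 0.3 recovers the two positives that were missed, at the cost of one false positive.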
Common Pitfalls
Using accuracy as a metric: A model that predicts all-negative achieves about 99% accuracy on a 1:100 dataset while detecting no positives at all. Use precision, recall, F1, and AUC instead.
Too high pos_weight: Overweighting the positive class causes too many false positives. Start with the inverse class ratio and tune down if precision drops.
Sigmoid with logits loss: tf.nn.weighted_cross_entropy_with_logits expects raw logits. If your model has a sigmoid output layer, the loss computation is wrong (double sigmoid). Remove the final sigmoid when using logits-based losses.
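class_weight vs sample_weight: class_weight applies the same weight to all samples of a class, while sample_weight allows a different weight per sample. How model.fit() handles both at once depends on the Keras version, so use one or the other; if you need both effects, fold the class weight into the sample weights yourself.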
Ignoring threshold tuning: The default 0.5 classification threshold is often suboptimal for imbalanced data. Use precision-recall curves to find the best threshold for your use case.
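The double-sigmoid pitfall listed above can be illustrated with a short plain-Python sketch. A logits-based loss applies the sigmoid internally, so if the model already ends in a sigmoid, the loss effectively sees sigmoid(sigmoid(x)). The `sigmoid` helper and the sample logits are assumptions made for this example.

```python
import math

def sigmoid(x):
    """Logistic function mapping a real number to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# Logits from strongly negative to strongly positive
logits = [-10.0, -2.0, 0.0, 2.0, 10.0]
once = [sigmoid(x) for x in logits]
twice = [sigmoid(p) for p in once]   # what a logits-based loss sees after a sigmoid layer

print([round(p, 3) for p in once])
print([round(p, 3) for p in twice])
```

After the second sigmoid every value is squeezed into the interval between 0.5 and sigmoid(1), roughly 0.731, so the loss can never register a confident negative no matter what the network outputs.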
Summary
Use class_weight in model.fit() for the simplest weighted binary crossentropy approach
Define a custom loss function for more control over the weighting formula
Use tf.nn.weighted_cross_entropy_with_logits for numerically stable computation on logits
Consider focal loss for extreme class imbalance
Evaluate with precision, recall, F1, and AUC — never accuracy alone