Keras weighted binary crossentropy

Introduction

Weighted binary crossentropy assigns different penalties during training to errors on the positive class (false negatives) and errors on the negative class (false positives). This matters for imbalanced datasets where one class is much rarer than the other, for example fraud detection (1% fraud, 99% legitimate) or medical diagnosis of rare diseases. Without weighting, the model often learns to favor the majority class and largely ignore the minority class, because always predicting the majority class already yields a low average loss.
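Written out for N samples with labels y_i in {0, 1} and predicted probabilities p_i, one common form of the weighted loss scales the positive and negative terms separately. The symbols w_+ and w_- are notation introduced here: setting both to 1 gives standard binary crossentropy, and the custom loss in Method 2 below corresponds to w_+ = pos_weight and w_- = 1.

$$
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\left[ w_{+}\, y_i \log p_i + w_{-}\,(1 - y_i)\log(1 - p_i) \right]
$$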

Standard Binary Crossentropy

The standard (unweighted) loss treats both classes equally:

```python
import tensorflow as tf

# Standard binary crossentropy
loss_fn = tf.keras.losses.BinaryCrossentropy()

y_true = [1, 0, 1, 0]
y_pred = [0.9, 0.1, 0.8, 0.2]
loss = loss_fn(y_true, y_pred)
print(f"Loss: {loss.numpy():.4f}")
```
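To see what the built-in loss computes, here is a NumPy-only sketch of the same formula, with probabilities clipped by a small epsilon so that log(0) cannot occur. The helper name `binary_crossentropy_np` is hypothetical; on the same four labels it should agree with the Keras result to the printed precision.

```python
import numpy as np

def binary_crossentropy_np(y_true, y_pred, eps=1e-7):
    """Hypothetical helper: mean binary crossentropy computed with NumPy."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip probabilities so log(0) never occurs
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    # Per-sample loss: -(y * log p + (1 - y) * log(1 - p))
    per_sample = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return per_sample.mean()

bce = binary_crossentropy_np([1, 0, 1, 0], [0.9, 0.1, 0.8, 0.2])
print(f"Loss: {bce:.4f}")  # Loss: 0.1643
```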

Method 1: class_weight in model.fit()

The simplest approach is to pass class weights directly to fit():

```python
import numpy as np
from tensorflow import keras

# Assumes y_train is a 1-D NumPy array of 0/1 labels, X_train holds the
# matching features, and model is a Keras model with a sigmoid output.

# Count classes
n_positive = np.sum(y_train == 1)
n_negative = np.sum(y_train == 0)
total = len(y_train)

# Compute weights inversely proportional to class frequency
class_weight = {
    0: float(total / (2 * n_negative)),
    1: float(total / (2 * n_positive)),
}

print(f"Class weights: {class_weight}")
# For a 95/5 split: {0: 0.526..., 1: 10.0}

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, class_weight=class_weight, epochs=10)
```

This multiplies each sample's loss by the weight of its class before the batch loss is reduced, so errors on the minority class produce proportionally larger gradients.
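The inverse-frequency formula above gives each class the same total weight, which is what makes it "balanced". A NumPy-only check on a synthetic 95/5 label array (the helper `balanced_class_weight` is hypothetical, and the labels are made up for illustration):

```python
import numpy as np

def balanced_class_weight(y):
    """Hypothetical helper: inverse-frequency weights, same formula as above."""
    y = np.asarray(y).ravel()
    total = len(y)
    n_positive = np.sum(y == 1)
    n_negative = np.sum(y == 0)
    return {0: total / (2 * n_negative), 1: total / (2 * n_positive)}

# Synthetic 95/5 split, for illustration only
y = np.array([0] * 95 + [1] * 5)
cw = balanced_class_weight(y)
print({k: round(float(v), 3) for k, v in cw.items()})  # {0: 0.526, 1: 10.0}

# Equivalent per-sample weights: each sample gets its class's weight
per_sample = np.where(y == 1, cw[1], cw[0])
# Both classes contribute the same total weight (total / 2 = 50)
print(per_sample[y == 0].sum(), per_sample[y == 1].sum())
```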

Method 2: Custom Weighted Loss Function

For more control, define a custom loss:

```python
import tensorflow as tf

def weighted_binary_crossentropy(pos_weight):
    """
    pos_weight: weight for the positive class (class 1).
    A higher pos_weight penalizes missed positives (false negatives) more.
    """
    def loss(y_true, y_pred):
        # Labels may arrive as integers; match the prediction dtype
        y_true = tf.cast(y_true, y_pred.dtype)
        # Clip probabilities to avoid log(0)
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)

        # Binary crossentropy with the positive term scaled by pos_weight
        per_element = -(pos_weight * y_true * tf.math.log(y_pred) +
                        (1 - y_true) * tf.math.log(1 - y_pred))

        return tf.reduce_mean(per_element)

    return loss

# Usage: negatives outnumber positives roughly 20 to 1.
# Assumes model is a Keras model whose last layer uses a sigmoid.
model.compile(
    optimizer='adam',
    loss=weighted_binary_crossentropy(pos_weight=20.0),
    metrics=['accuracy']
)
```
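Because pos_weight only scales the positive term, this loss should give the same value as unweighted crossentropy multiplied by a per-sample weight of pos_weight on positives and 1.0 on negatives. In other words, Method 2 with pos_weight=w behaves like class_weight={0: 1.0, 1: w}, up to details of how each reduction averages over the batch. A NumPy check on the toy labels from the standard example (both helper names are hypothetical):

```python
import numpy as np

def weighted_bce_np(y_true, y_pred, pos_weight, eps=1e-7):
    """Hypothetical NumPy mirror of weighted_binary_crossentropy."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    per_sample = -(pos_weight * y_true * np.log(y_pred)
                   + (1 - y_true) * np.log(1 - y_pred))
    return per_sample.mean()

def sample_weighted_bce_np(y_true, y_pred, weights, eps=1e-7):
    """Hypothetical helper: unweighted BCE per sample, times a per-sample weight."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    per_sample = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return (np.asarray(weights, dtype=float) * per_sample).mean()

y_true = np.array([1, 0, 1, 0])
y_pred = np.array([0.9, 0.1, 0.8, 0.2])

a = weighted_bce_np(y_true, y_pred, pos_weight=20.0)
b = sample_weighted_bce_np(y_true, y_pred, np.where(y_true == 1, 20.0, 1.0))
print(f"{a:.4f} {b:.4f}")  # 1.7246 1.7246
```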

Method 3: Using sample_weight

Apply per-sample weights for fine-grained control:

```python
import numpy as np

# Assumes y_train, X_train and model as in Method 1.
# Positives get weight 10.0, negatives 1.0
sample_weights = np.where(y_train == 1, 10.0, 1.0)

model.fit(
    X_train, y_train,
    sample_weight=sample_weights,
    epochs=10
)
```

This is useful when different samples within the same class should have different weights (e.g., based on confidence or importance).
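One way to combine a class weight with a per-sample importance score is to multiply them into a single sample_weight array. In the sketch below, the labels and the importance scores (standing in for something like annotator confidence) are hypothetical values chosen for illustration:

```python
import numpy as np

# Hypothetical inputs: labels plus a per-sample importance score in [0, 1]
y_train = np.array([0, 0, 0, 1, 0, 1])
importance = np.array([1.0, 0.5, 1.0, 1.0, 0.8, 0.6])

class_weight = {0: 1.0, 1: 10.0}

# Class weight looked up per sample, then scaled by that sample's importance
class_part = np.where(y_train == 1, class_weight[1], class_weight[0])
sample_weights = class_part * importance
print(sample_weights)

# This single array can then be passed as model.fit(..., sample_weight=sample_weights)
```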

Method 4: tf.nn.weighted_cross_entropy_with_logits

TensorFlow provides a built-in weighted crossentropy that works on logits (pre-sigmoid values):

```python
import tensorflow as tf
from tensorflow import keras

def weighted_loss(pos_weight):
    def loss(y_true, y_pred):
        # y_pred holds logits here; labels must have the same dtype
        y_true = tf.cast(y_true, y_pred.dtype)
        return tf.reduce_mean(
            tf.nn.weighted_cross_entropy_with_logits(
                labels=y_true,
                logits=y_pred,
                pos_weight=pos_weight
            )
        )
    return loss

n_features = 20  # hypothetical: set this to your input dimension

# Important: remove the sigmoid from the last layer when using logits
model = keras.Sequential([
    keras.Input(shape=(n_features,)),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(1)  # No activation: outputs logits, not probabilities
])

model.compile(optimizer='adam', loss=weighted_loss(pos_weight=20.0))
```

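Computing the loss directly from logits is numerically more stable than applying a sigmoid first and then taking the log. Because this model outputs logits, apply tf.sigmoid to its predictions before treating them as probabilities (for example, before thresholding at 0.5).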
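The stability gain comes from rearranging the loss so that no exponential can overflow and no log ever sees 0. The NumPy sketch below implements the rearranged form that the TensorFlow documentation describes for this op, with coef = 1 + (pos_weight - 1) * labels, and compares it with a naive sigmoid-then-log version. Both helper names are hypothetical; the point is that the two agree for moderate logits, while the naive one returns inf for a confident wrong prediction.

```python
import numpy as np

def weighted_ce_with_logits_np(labels, logits, pos_weight):
    """Hypothetical NumPy version of the stable weighted cross entropy on logits."""
    z = np.asarray(labels, dtype=float)
    x = np.asarray(logits, dtype=float)
    coef = 1.0 + (pos_weight - 1.0) * z
    # log(1 + exp(-|x|)) + max(-x, 0) equals log(1 + exp(-x)) without overflow
    return (1 - z) * x + coef * (np.log1p(np.exp(-np.abs(x))) + np.maximum(-x, 0.0))

def naive_weighted_ce(labels, logits, pos_weight):
    """Sigmoid first, then log: breaks down for large |logits|."""
    z = np.asarray(labels, dtype=float)
    p = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))
    with np.errstate(divide="ignore"):
        return -(pos_weight * z * np.log(p) + (1 - z) * np.log(1 - p))

labels = np.array([1.0, 0.0, 1.0])
logits = np.array([-2.0, 0.5, 3.0])
print(weighted_ce_with_logits_np(labels, logits, 20.0))
print(naive_weighted_ce(labels, logits, 20.0))  # same values

# A confident wrong prediction: label 0, logit 40
print(weighted_ce_with_logits_np([0.0], [40.0], 20.0))  # about 40
print(naive_weighted_ce([0.0], [40.0], 20.0))  # inf: the sigmoid rounds to exactly 1.0
```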

Method 5: Focal Loss

Focal loss down-weights well-classified (easy) examples so that training concentrates on hard ones, which can help under extreme imbalance:

```python
import tensorflow as tf

def focal_loss(gamma=2.0, alpha=0.25):
    """
    gamma: focusing parameter (higher = more focus on hard examples)
    alpha: weight for the positive class (negatives get 1 - alpha)
    """
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)

        # Cross entropy per element
        ce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))

        # Focal weight: p_t is the predicted probability of the true class
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        focal_weight = tf.pow(1 - p_t, gamma)

        # Alpha weight
        alpha_weight = y_true * alpha + (1 - y_true) * (1 - alpha)

        return tf.reduce_mean(alpha_weight * focal_weight * ce)

    return loss

# alpha=0.75 weights positives 3x more than negatives (0.75 vs 0.25).
# Assumes model ends in a sigmoid layer.
model.compile(optimizer='adam', loss=focal_loss(gamma=2.0, alpha=0.75))
```
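The down-weighting is easiest to see on two positive examples, one easy (p = 0.95) and one hard (p = 0.30). This NumPy sketch mirrors focal_loss above (the helper `focal_terms` is hypothetical) and returns both the plain crossentropy and the focal value, so their ratio alpha * (1 - p_t)^gamma is visible directly. It also checks that with gamma = 0 and alpha = 0.5 the focal loss is just half the ordinary crossentropy.

```python
import numpy as np

def focal_terms(y_true, y_pred, gamma=2.0, alpha=0.75, eps=1e-7):
    """Hypothetical NumPy mirror of focal_loss, returning per-sample CE and focal values."""
    y = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    ce = -(y * np.log(p) + (1 - y) * np.log(1 - p))
    p_t = y * p + (1 - y) * (1 - p)
    alpha_weight = y * alpha + (1 - y) * (1 - alpha)
    return ce, alpha_weight * (1 - p_t) ** gamma * ce

# Two positives: an easy one and a hard one
ce, fl = focal_terms([1, 1], [0.95, 0.30])
# Ratio is alpha * (1 - p_t)**gamma: about 0.0019 (easy) vs 0.37 (hard)
print(fl / ce)

# gamma = 0, alpha = 0.5: no focusing, equal class weights
ce0, fl0 = focal_terms([1, 0], [0.9, 0.2], gamma=0.0, alpha=0.5)
```

Recent TensorFlow/Keras releases also ship a built-in `tf.keras.losses.BinaryFocalCrossentropy`; check the documentation for your installed version before relying on its arguments.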

Choosing the Right Weight

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Assumes y_train is a 1-D array of 0/1 labels
n_positive = np.sum(y_train == 1)
n_negative = np.sum(y_train == 0)

# Option 1: Inverse frequency (for pos_weight-style losses)
pos_weight = n_negative / n_positive

# Option 2: Sklearn utility (for class_weight in model.fit)
weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=y_train)
class_weight = {0: weights[0], 1: weights[1]}

# Option 3: Manual tuning (start with inverse frequency, adjust)
# If too many false positives: decrease pos_weight
# If too many false negatives: increase pos_weight
```

| Imbalance ratio (positive:negative) | Suggested pos_weight | Notes |
|---|---|---|
| 1:2 | 2.0 | Mild imbalance |
| 1:10 | 10.0 | Moderate imbalance |
| 1:100 | 50-100 | Severe; consider focal loss |
| 1:1000 | 100-500 | Extreme; combine with resampling |
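The two options above agree on the ratio between the classes. scikit-learn documents its 'balanced' mode as n_samples / (n_classes * np.bincount(y)), so weight[1] / weight[0] equals n_negative / n_positive, the inverse-frequency pos_weight; only the overall scale differs. A NumPy-only sketch on synthetic 1:10 labels (made up for illustration):

```python
import numpy as np

# Synthetic labels with a 1:10 positive:negative ratio, for illustration only
y_train = np.array([1] * 50 + [0] * 500)

n_positive = np.sum(y_train == 1)
n_negative = np.sum(y_train == 0)

# Option 1: inverse-frequency pos_weight
pos_weight = n_negative / n_positive

# The 'balanced' formula: n_samples / (n_classes * count of each class)
counts = np.bincount(y_train, minlength=2)
balanced = len(y_train) / (2 * counts)

print(pos_weight)                 # 10.0
print(balanced[1] / balanced[0])  # same ratio as pos_weight, different scale
```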

Evaluation Metrics for Imbalanced Data

Accuracy is misleading with imbalanced data. Use these instead:

```python
from sklearn.metrics import classification_report, roc_auc_score

# Assumes model outputs probabilities (sigmoid last layer). For a logits
# model (Method 4), convert first with tf.sigmoid(logits).numpy().
y_pred_proba = model.predict(X_test).ravel()
y_pred = (y_pred_proba > 0.5).astype(int)

print(classification_report(y_test, y_pred))
print(f"ROC AUC: {roc_auc_score(y_test, y_pred_proba):.4f}")

# Lowering the threshold raises recall at the cost of precision
y_pred_low_threshold = (y_pred_proba > 0.3).astype(int)
```
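To see why accuracy misleads, consider a "model" that never predicts the positive class on a synthetic test set with 1% positives. The labels below are made up for illustration:

```python
import numpy as np

# Synthetic test set: 10 positives among 1,000 samples (1% positive)
y_test = np.array([1] * 10 + [0] * 990)
y_all_negative = np.zeros_like(y_test)  # never predicts positive

accuracy = np.mean(y_all_negative == y_test)
true_positives = np.sum((y_all_negative == 1) & (y_test == 1))
recall = true_positives / np.sum(y_test == 1)

print(f"accuracy={accuracy:.2f}, recall={recall:.2f}")  # accuracy=0.99, recall=0.00
```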

Common Pitfalls

  • Using accuracy as a metric: A model that predicts all-negative achieves 99% accuracy on a 1:100 dataset. Use precision, recall, F1, and AUC instead.
  • Too high pos_weight: Overweighting the positive class causes too many false positives. Start with the inverse class ratio and tune down if precision drops.
  • Sigmoid with logits loss: tf.nn.weighted_cross_entropy_with_logits expects raw logits. If your model has a sigmoid output layer, the loss computation is wrong (double sigmoid). Remove the final sigmoid when using logits-based losses.
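  • class_weight vs sample_weight: class_weight applies the same weight to every sample of a class, while sample_weight allows a separate weight per sample. Whether model.fit() accepts both at once depends on the Keras version (some versions multiply them, others raise an error), so the portable approach is to fold the class weights into a single sample_weight array.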
  • Ignoring threshold tuning: The default 0.5 classification threshold is often suboptimal for imbalanced data. Use precision-recall curves to find the best threshold for your use case.
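Threshold tuning can be as simple as sweeping candidate thresholds and keeping the one with the best F1. The sketch below does this with NumPy on toy scores; the helper `best_f1_threshold` and the data are hypothetical, and `sklearn.metrics.precision_recall_curve` is the more usual tool for getting precision and recall at every distinct score. In practice, tune the threshold on a validation split rather than on the test set.

```python
import numpy as np

def best_f1_threshold(y_true, y_proba, thresholds):
    """Hypothetical helper: return (threshold, F1) with the highest F1."""
    y_true = np.asarray(y_true).ravel()
    y_proba = np.asarray(y_proba).ravel()
    best_t, best_f1 = None, -1.0
    for t in thresholds:
        y_hat = (y_proba > t).astype(int)
        tp = np.sum((y_hat == 1) & (y_true == 1))
        fp = np.sum((y_hat == 1) & (y_true == 0))
        fn = np.sum((y_hat == 0) & (y_true == 1))
        # F1 = 2TP / (2TP + FP + FN); treat an empty denominator as F1 = 0
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom > 0 else 0.0
        if f1 > best_f1:
            best_t, best_f1 = t, f1
    return best_t, best_f1

# Toy scores: positives score higher on average, but mostly below 0.5
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1])
y_proba = np.array([0.02, 0.08, 0.12, 0.18, 0.33, 0.47, 0.28, 0.42, 0.61])

t, f1 = best_f1_threshold(y_true, y_proba, np.linspace(0.05, 0.95, 19))
print(f"best threshold={t:.2f}, F1={f1:.2f}")  # best threshold=0.20, F1=0.75
```

On these toy scores the default 0.5 threshold gives an F1 of only 0.50.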

Summary

  • Use class_weight in model.fit() for the simplest weighted binary crossentropy approach
  • Define a custom loss function for more control over the weighting formula
  • Use tf.nn.weighted_cross_entropy_with_logits for numerically stable computation on logits
  • Consider focal loss for extreme class imbalance
  • Evaluate with precision, recall, F1, and AUC — never accuracy alone
