CrossEntropyLoss in PyTorch

Explaining CrossEntropyLoss() in PyTorch

CrossEntropyLoss() in PyTorch is a crucial component for training classification models. It implements the cross-entropy loss, which is widely used in classification problems, and combines nn.LogSoftmax() and nn.NLLLoss() in a single class. Like PyTorch's other loss functions, it is available in the torch.nn module.
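
The equivalence with nn.LogSoftmax() followed by nn.NLLLoss() can be checked directly. The snippet below is a minimal sketch; the logits and targets are made-up values chosen only for illustration.

```python
import torch
import torch.nn as nn

# Illustrative logits (2 samples, 3 classes) and class-index targets
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 1.5, 0.3]])
targets = torch.tensor([0, 1])

# One-step version
ce = nn.CrossEntropyLoss()(logits, targets)

# Two-step version: log-softmax over the class dimension, then NLL
log_probs = nn.LogSoftmax(dim=1)(logits)
nll = nn.NLLLoss()(log_probs, targets)

print(ce.item(), nll.item())  # the two values should agree
```

Both paths use the default 'mean' reduction, so they should agree up to floating-point tolerance.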

Technical Explanation

Cross-Entropy Loss Definition

The cross-entropy loss measures the dissimilarity between two probability distributions. In a supervised learning context, these distributions are the true labels and the predicted probabilities.

For a single sample, the cross-entropy loss is given by:

latex
L(x, y) = -\sum_{c=1}^{C} y_c \cdot \log(\hat{p}_c)

Where:

  • y_c is the binary indicator (0 or 1) of whether class label c is the correct classification for observation x.
  • \hat{p}_c is the predicted probability that observation x belongs to class c.
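
To make the formula concrete, here is a small sketch that evaluates it by hand for one sample and compares the result with PyTorch's built-in function. The logits and the true class are assumed values for illustration.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.2, 0.3, 4.0, 2.3, 0.5])  # scores for C = 5 classes
true_class = 2                                     # assumed correct label

# Predicted probabilities \hat{p}_c via softmax
p_hat = F.softmax(logits, dim=0)

# One-hot indicator y_c
y = torch.zeros(5)
y[true_class] = 1.0

# L(x, y) = -sum_c y_c * log(p_hat_c)
manual = -(y * torch.log(p_hat)).sum()

# Built-in equivalent (expects a batch dimension)
builtin = F.cross_entropy(logits.unsqueeze(0), torch.tensor([true_class]))

print(manual.item(), builtin.item())
```

Because y is one-hot, the sum collapses to a single term: the negative log of the probability assigned to the true class.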

Implementation in PyTorch

CrossEntropyLoss() expects the input to be raw, unnormalized scores for each class (also known as logits). Internally, it applies log-softmax to these scores to obtain log-probabilities and then computes the negative log-likelihood of the target class. Because of this, the model's final layer should not apply softmax itself.

python
import torch
import torch.nn as nn

# Example: Batch size = 3, number of classes = 5
logits = torch.tensor([[1.2, 0.3, 4.0, 2.3, 0.5],
                       [2.5, 1.1, 3.5, 0.2, 1.5],
                       [0.1, 2.0, 2.6, 1.9, 0.8]])

# True labels
targets = torch.tensor([2, 0, 1])

loss_function = nn.CrossEntropyLoss()
loss = loss_function(logits, targets)

print('Cross-Entropy Loss:', loss.item())
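
Since the normalization happens inside the loss, passing already-normalized probabilities means softmax is effectively applied twice. The sketch below, using one made-up sample, illustrates how this changes the result.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[1.2, 0.3, 4.0, 2.3, 0.5]])
target = torch.tensor([2])
criterion = nn.CrossEntropyLoss()

# Intended usage: raw logits go straight into the loss
loss_from_logits = criterion(logits, target)

# Common mistake: softmax applied before the loss, so it is applied twice
loss_from_probs = criterion(F.softmax(logits, dim=1), target)

print('From logits:', loss_from_logits.item())
print('From probabilities:', loss_from_probs.item())
```

For this confident, correct prediction the second value comes out larger, because the extra softmax flattens the distribution. No error is raised, so the mistake is easy to miss.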

Understanding the Parameters

  1. Input: The raw scores for each class, typically the output from the final layer of a model.
  2. Target: The correct class index for each sample, where each index is in the range [0, C-1], with C being the number of classes.
  3. Weight: Optional parameter representing a manual rescaling weight given to each class. If not provided, all classes are presumed to have equal weight.
  4. Reduction: The method to reduce the loss. Options include:
    • 'none': No reduction, returns a loss per element.
    • 'mean': Returns the mean of the losses (a weighted mean when weight is given).
    • 'sum': Returns the sum of the losses.
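
The three reduction options can be compared side by side. This is a minimal sketch with illustrative logits and targets; only the reduction argument changes between the three calls.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[1.2, 0.3, 4.0, 2.3, 0.5],
                       [2.5, 1.1, 3.5, 0.2, 1.5],
                       [0.1, 2.0, 2.6, 1.9, 0.8]])
targets = torch.tensor([2, 0, 1])

# 'none' keeps one loss value per sample: shape (3,)
per_sample = nn.CrossEntropyLoss(reduction='none')(logits, targets)

# 'mean' and 'sum' collapse those values to a scalar
mean_loss = nn.CrossEntropyLoss(reduction='mean')(logits, targets)
sum_loss = nn.CrossEntropyLoss(reduction='sum')(logits, targets)

print('Per-sample:', per_sample)
print('Mean:', mean_loss.item(), 'Sum:', sum_loss.item())
```

With no class weights, 'mean' and 'sum' are simply the mean and sum of the per-sample values returned by 'none'.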

Key Points Summary

| Parameter | Description | Default |
|-----------|-------------|---------|
| input | Raw scores (logits) from the model. | Required |
| target | Correct class indices. | Required |
| weight | Rescaling weight for each class. | None |
| reduction | Specifies reduction method: 'none', 'mean', 'sum'. | 'mean' |

Why Use Cross-Entropy Loss?

The cross-entropy loss is preferred for several reasons when dealing with multi-class classification problems:

  • Probabilistic Interpretation: It explicitly measures the dissimilarity between the true distribution (one-hot encoded labels) and the predicted distribution (from your model).
  • Log Likelihood: Minimizing it is equivalent to maximizing the log-likelihood of the data under the model.
  • Differentiable: The function is smooth and differentiable, which is essential for gradient-based optimization methods.
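
The differentiability point can be illustrated with autograd. For a single sample, the gradient of the loss with respect to the logits has a well-known closed form: the softmax probabilities minus the one-hot label. The sketch below checks this on made-up values.

```python
import torch
import torch.nn.functional as F

# One sample, 5 classes; requires_grad lets autograd track the logits
logits = torch.tensor([[1.2, 0.3, 4.0, 2.3, 0.5]], requires_grad=True)
target = torch.tensor([2])

loss = F.cross_entropy(logits, target)
loss.backward()

# Closed form for a single sample: softmax(logits) - one_hot(target)
expected_grad = (F.softmax(logits.detach(), dim=1)
                 - F.one_hot(target, num_classes=5).float())

print('Autograd gradient:', logits.grad)
print('Closed-form gradient:', expected_grad)
```

With a batch of N samples and the default 'mean' reduction, each row of the gradient would additionally be divided by N.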

Additional Topics and Scenarios

Weighted Loss

In scenarios with class imbalance, it is often beneficial to specify per-class weights so that minority classes contribute more to the loss. An example is shown below.

python
# Reuses logits and targets from the first example.
# Hypothetical class imbalance: classes 1 and 2 are treated as under-represented
weights = torch.tensor([1.0, 2.0, 2.0, 1.0, 1.0])

loss_function = nn.CrossEntropyLoss(weight=weights)
weighted_loss = loss_function(logits, targets)

print('Weighted Cross-Entropy Loss:', weighted_loss.item())
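
With the default 'mean' reduction, the weighted loss is a weighted average: each sample's loss is scaled by the weight of its true class, and the total is divided by the sum of those weights rather than by the batch size. The sketch below reproduces this by hand; the weight values are hypothetical.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[1.2, 0.3, 4.0, 2.3, 0.5],
                       [2.5, 1.1, 3.5, 0.2, 1.5],
                       [0.1, 2.0, 2.6, 1.9, 0.8]])
targets = torch.tensor([2, 0, 1])
weights = torch.tensor([1.0, 2.0, 3.0, 1.0, 1.0])  # hypothetical class weights

# Unweighted loss for each sample
per_sample = nn.CrossEntropyLoss(reduction='none')(logits, targets)

# Weight of each sample's true class
w = weights[targets]

# Weighted mean: sum(w_i * loss_i) / sum(w_i)
manual = (w * per_sample).sum() / w.sum()
builtin = nn.CrossEntropyLoss(weight=weights)(logits, targets)

print(manual.item(), builtin.item())
```

One consequence is that weights only change the result for classes that actually appear among the targets in the batch.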

Application in a Neural Network

A simple neural network utilizing CrossEntropyLoss() could be implemented as follows:

python
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleNN(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        # Single linear layer that outputs one raw score (logit) per class
        self.fc = nn.Linear(input_size, num_classes)

    def forward(self, x):
        # No softmax here: CrossEntropyLoss expects raw logits
        return self.fc(x)

input_size = 10  # Example input dimension
num_classes = 5  # Example number of classes

model = SimpleNN(input_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Example multi-class input and corresponding target
inputs = torch.randn(3, input_size)
targets = torch.tensor([0, 2, 4])

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, targets)

# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()

print('Training Step Loss:', loss.item())

The above example shows a single training step using a simple fully connected network. CrossEntropyLoss() provides the measure of error, and its gradient is backpropagated through the network to update the weights.
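
In practice the step is repeated many times. The following is a minimal sketch of such a loop, assuming a single nn.Linear layer as a stand-in for the model and a tiny fixed batch of random data; the seed and the step count are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)  # fixed seed so runs are repeatable

model = nn.Linear(10, 5)  # stand-in model: 10 features in, 5 class logits out
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(3, 10)        # small fixed batch of random inputs
targets = torch.tensor([0, 2, 4])  # class index for each sample

losses = []
for step in range(100):
    optimizer.zero_grad()                     # clear gradients from the last step
    loss = criterion(model(inputs), targets)  # forward pass and loss
    loss.backward()                           # backpropagate
    optimizer.step()                          # update weights
    losses.append(loss.item())

# Predicted class is the index of the largest logit; softmax is not needed for argmax
with torch.no_grad():
    predictions = model(inputs).argmax(dim=1)

print('First loss:', losses[0], 'Last loss:', losses[-1])
print('Predictions:', predictions)
```

On a fixed batch with a small learning rate, the recorded loss should fall over the steps. Real training would iterate over mini-batches from a dataset instead of reusing one batch.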

In summary, CrossEntropyLoss() is a versatile and fundamental loss function for classification tasks in PyTorch, known for its simplicity, effectiveness, and straightforward integration into neural network training workflows.

