Source code for mmdet.models.losses.cross_entropy_loss

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.

    Returns:
        torch.Tensor: The calculated loss
    """
    # element-wise losses
    loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def _expand_onehot_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(
        (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1

    if label_weights is None:
        bin_label_weights = None
    else:
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)

    return bin_labels, bin_label_weights


[docs]def binary_cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None, class_weight=None): """Calculate the binary CrossEntropy loss. Args: pred (torch.Tensor): The prediction with shape (N, 1). label (torch.Tensor): The learning label of the prediction. weight (torch.Tensor, optional): Sample-wise loss weight. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. class_weight (list[float], optional): The weight for each class. Returns: torch.Tensor: The calculated loss """ if pred.dim() != label.dim(): label, weight = _expand_onehot_labels(label, weight, pred.size(-1)) # weighted element-wise losses if weight is not None: weight = weight.float() loss = F.binary_cross_entropy_with_logits( pred, label.float(), pos_weight=class_weight, reduction='none') # do the reduction for the weighted loss loss = weight_reduce_loss( loss, weight, reduction=reduction, avg_factor=avg_factor) return loss
def mask_cross_entropy(pred, target, label, reduction='mean', avg_factor=None, class_weight=None): """Calculate the CrossEntropy loss for masks. Args: pred (torch.Tensor): The prediction with shape (N, C, *), C is the number of classes. The trailing * indicates arbitrary shape. target (torch.Tensor): The learning label of the prediction. label (torch.Tensor): ``label`` indicates the class label of the mask corresponding object. This will be used to select the mask in the of the class which the object belongs to when the mask prediction if not class-agnostic. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. class_weight (list[float], optional): The weight for each class. Returns: torch.Tensor: The calculated loss Example: >>> N, C = 3, 11 >>> H, W = 2, 2 >>> pred = torch.randn(N, C, H, W) * 1000 >>> target = torch.rand(N, H, W) >>> label = torch.randint(0, C, size=(N,)) >>> reduction = 'mean' >>> avg_factor = None >>> class_weights = None >>> loss = mask_cross_entropy(pred, target, label, reduction, >>> avg_factor, class_weights) >>> assert loss.shape == (1,) """ # TODO: handle these two reserved arguments assert reduction == 'mean' and avg_factor is None num_rois = pred.size()[0] inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) pred_slice = pred[inds, label].squeeze(1) return F.binary_cross_entropy_with_logits( pred_slice, target, weight=class_weight, reduction='mean')[None] @LOSSES.register_module() class CrossEntropyLoss(nn.Module): def __init__(self, use_sigmoid=False, use_mask=False, reduction='mean', class_weight=None, loss_weight=1.0): """CrossEntropyLoss. Args: use_sigmoid (bool, optional): Whether the prediction uses sigmoid of softmax. Defaults to False. use_mask (bool, optional): Whether to use mask cross entropy loss. Defaults to False. reduction (str, optional): . Defaults to 'mean'. Options are "none", "mean" and "sum". class_weight (list[float], optional): Weight of each class. Defaults to None. loss_weight (float, optional): Weight of the loss. Defaults to 1.0. """ super(CrossEntropyLoss, self).__init__() assert (use_sigmoid is False) or (use_mask is False) self.use_sigmoid = use_sigmoid self.use_mask = use_mask self.reduction = reduction self.loss_weight = loss_weight self.class_weight = class_weight if self.use_sigmoid: self.cls_criterion = binary_cross_entropy elif self.use_mask: self.cls_criterion = mask_cross_entropy else: self.cls_criterion = cross_entropy def forward(self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, **kwargs): """Forward function. Args: cls_score (torch.Tensor): The prediction. label (torch.Tensor): The learning label of the prediction. weight (torch.Tensor, optional): Sample-wise loss weight. avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". Returns: torch.Tensor: The calculated loss """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) if self.class_weight is not None: class_weight = cls_score.new_tensor( self.class_weight, device=cls_score.device) else: class_weight = None loss_cls = self.loss_weight * self.cls_criterion( cls_score, label, weight, class_weight=class_weight, reduction=reduction, avg_factor=avg_factor, **kwargs) return loss_cls