# Shortcuts
#
# Source code for mmdet3d.models.losses.multibin_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn as nn
from torch.nn import functional as F

from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES


@weighted_loss
def multibin_loss(pred_orientations, gt_orientations, num_dir_bins=4):
    """Multi-Bin Loss.

    For every direction bin the loss combines a two-way cross-entropy term
    on the bin classification logits with an L1 term between the
    L2-normalized predicted ``(sin, cos)`` pair and the sine / cosine of the
    ground-truth angle offset. The regression term is only evaluated for
    samples whose ground-truth label for that bin equals 1.

    Keyword arguments that are not part of this signature (for example the
    ``reduction`` passed by ``MultiBinLoss.forward``) are consumed by the
    ``weighted_loss`` decorator, not by this function.

    Args:
        pred_orientations (torch.Tensor): Predicted local vector
            orientation in [axis_cls, head_cls, sin, cos] format,
            shape (N, num_dir_bins * 4). Columns
            ``[0, num_dir_bins * 2)`` hold two classification logits per
            bin; columns ``[num_dir_bins * 2, num_dir_bins * 4)`` hold one
            (sin, cos) pair per bin.
        gt_orientations (torch.Tensor): Corresponding orientation targets,
            shape (N, num_dir_bins * 2). Columns ``[0, num_dir_bins)`` are
            the per-bin class labels (used as class indices for the
            two-way cross-entropy, 1 marks a sample as valid for the bin);
            columns ``[num_dir_bins, num_dir_bins * 2)`` are the angle
            offsets fed to ``sin`` / ``cos``.
        num_dir_bins (int, optional): Number of bins to encode
            direction angle.
            Defaults: 4.

    Returns:
        torch.Tensor: Scalar loss. The classification loss averaged over
        all bins plus the regression loss averaged over every valid
        (sample, bin) pair. The regression term is omitted when no sample
        is valid for any bin.

    Raises:
        ValueError: If ``num_dir_bins`` is not positive.
    """
    if num_dir_bins <= 0:
        raise ValueError(
            f'num_dir_bins must be a positive integer, got {num_dir_bins}')

    cls_losses = 0
    reg_losses = 0
    # Stays the Python int 0 until a bin with at least one valid sample is
    # seen; after that it is a tensor holding the running count.
    reg_cnt = 0
    for i in range(num_dir_bins):
        bin_labels = gt_orientations[:, i]

        # bin cls loss
        cls_losses += F.cross_entropy(
            pred_orientations[:, (i * 2):(i * 2 + 2)],
            bin_labels.long(),
            reduction='mean')

        # regression loss, restricted to samples assigned to this bin
        valid_mask_i = (bin_labels == 1)
        num_valid_i = valid_mask_i.sum()
        if num_valid_i > 0:
            start = num_dir_bins * 2 + i * 2
            end = start + 2
            # Normalize so the predicted pair lies on the unit circle and
            # is directly comparable to (sin, cos) of the target offset.
            pred_offset = F.normalize(pred_orientations[valid_mask_i,
                                                        start:end])
            gt_offset = gt_orientations[valid_mask_i, num_dir_bins + i]
            reg_loss = \
                F.l1_loss(pred_offset[:, 0], torch.sin(gt_offset),
                          reduction='none') + \
                F.l1_loss(pred_offset[:, 1], torch.cos(gt_offset),
                          reduction='none')

            reg_losses += reg_loss.sum()
            reg_cnt += num_valid_i

    # The result is assembled only after every bin has been accumulated;
    # returning from inside the loop would drop all bins but the first.
    loss = cls_losses / num_dir_bins
    # Without any valid sample reg_cnt is still the integer 0 and dividing
    # by it would raise ZeroDivisionError, so skip the regression term.
    if torch.is_tensor(reg_cnt):
        loss = loss + reg_losses / reg_cnt
    return loss


@LOSSES.register_module()
class MultiBinLoss(nn.Module):
    """Multi-Bin Loss for orientation.

    Thin ``nn.Module`` wrapper that scales ``multibin_loss`` by
    ``loss_weight`` and forwards the selected reduction mode to it.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
        loss_weight (float, optional): The weight of loss. Defaults
            to 1.0.
    """

    def __init__(self, reduction='none', loss_weight=1.0):
        super(MultiBinLoss, self).__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, num_dir_bins, reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            num_dir_bins (int): Number of bins to encode direction angle.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Defaults to None.

        Returns:
            The value returned by ``multibin_loss`` multiplied by
            ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the reduction chosen at
        # construction time.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * multibin_loss(
            pred, target, num_dir_bins=num_dir_bins, reduction=reduction)
        return loss