Shortcuts

Source code for mmdet3d.models.losses.multibin_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn as nn
from torch.nn import functional as F

from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES


@weighted_loss
def multibin_loss(pred_orientations, gt_orientations, num_dir_bins=4):
    """Multi-Bin Loss.

    For every direction bin the loss adds a two-way cross-entropy term on
    that bin's classification logits. For the samples whose label for the
    bin equals 1 it also adds an L1 term between the L2-normalized
    predicted ``(sin, cos)`` pair and the ``(sin, cos)`` of the ground-truth
    offset angle.

    Args:
        pred_orientations (torch.Tensor): Predicted local vector
            orientation in [axis_cls, head_cls, sin, cos] format,
            shape (N, num_dir_bins * 4). The first ``num_dir_bins * 2``
            columns are read as per-bin classification logits, the
            remaining columns as per-bin ``(sin, cos)`` pairs.
        gt_orientations (torch.Tensor): Corresponding gt targets,
            shape (N, num_dir_bins * 2). The first ``num_dir_bins``
            columns are the per-bin labels (cast to ``long`` and used as
            class indices), the remaining columns are the per-bin offset
            angles fed to ``sin`` / ``cos``.
        num_dir_bins (int, optional): Number of bins to encode
            direction angle. Defaults to 4.

    Returns:
        torch.Tensor: Scalar loss, i.e. the classification loss averaged
        over bins plus the regression loss averaged over all valid
        (sample, bin) pairs. The regression term is omitted when no pair
        is valid, and a zero is returned for an empty batch.
    """
    if pred_orientations.shape[0] == 0:
        # Nothing to supervise: return a zero that stays attached to the
        # graph instead of running a mean-reduced loss over zero rows.
        return pred_orientations.sum() * 0

    cls_losses = 0
    reg_losses = 0
    reg_cnt = 0
    for i in range(num_dir_bins):
        bin_label = gt_orientations[:, i]

        # bin cls loss
        cls_losses = cls_losses + F.cross_entropy(
            pred_orientations[:, (i * 2):(i * 2 + 2)],
            bin_label.long(),
            reduction='mean')

        # regression loss, only for the samples assigned to this bin
        valid_mask_i = (bin_label == 1)
        num_valid = int(valid_mask_i.sum())
        if num_valid == 0:
            continue

        start = num_dir_bins * 2 + i * 2
        end = start + 2
        # Normalize so the predicted pair lies on the unit circle before it
        # is compared with the (sin, cos) of the target angle.
        pred_offset = F.normalize(pred_orientations[valid_mask_i, start:end])
        gt_angle = gt_orientations[valid_mask_i, num_dir_bins + i]
        reg_loss = \
            F.l1_loss(pred_offset[:, 0], torch.sin(gt_angle),
                      reduction='none') + \
            F.l1_loss(pred_offset[:, 1], torch.cos(gt_angle),
                      reduction='none')

        reg_losses = reg_losses + reg_loss.sum()
        reg_cnt += num_valid

    # Combine only after every bin has been visited.
    loss = cls_losses / num_dir_bins
    if reg_cnt > 0:
        # Guard against 0 / 0 when no bin has a valid regression target.
        loss = loss + reg_losses / reg_cnt
    return loss


@LOSSES.register_module()
class MultiBinLoss(nn.Module):
    """Multi-Bin Loss for orientation.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
        loss_weight (float, optional): The weight of loss. Defaults
            to 1.0.
    """

    def __init__(self, reduction='none', loss_weight=1.0):
        super(MultiBinLoss, self).__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, num_dir_bins, reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            num_dir_bins (int): Number of bins to encode direction angle.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Defaults to None.

        Returns:
            The value returned by ``multibin_loss`` scaled by
            ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the reduction chosen at
        # construction time.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * multibin_loss(
            pred, target, num_dir_bins=num_dir_bins, reduction=reduction)
        return loss
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.