Shortcuts

Source code for mmdet3d.models.losses.rotated_iou_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import diff_iou_rotated_3d
from torch import nn as nn

from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES


@weighted_loss
def rotated_iou_3d_loss(pred, target):
    """Calculate the IoU loss (1-IoU) of two sets of rotated bounding boxes.
    Note that predictions and targets are one-to-one corresponded.

    Args:
        pred (torch.Tensor): Bbox predictions with shape [N, 7]
            (x, y, z, w, l, h, alpha).
        target (torch.Tensor): Bbox targets (gt) with shape [N, 7]
            (x, y, z, w, l, h, alpha).

    Returns:
        torch.Tensor: IoU loss between predictions and targets.
    """
    iou_loss = 1 - diff_iou_rotated_3d(pred.unsqueeze(0),
                                       target.unsqueeze(0))[0]
    return iou_loss


[docs]@LOSSES.register_module() class RotatedIoU3DLoss(nn.Module): """Calculate the IoU loss (1-IoU) of rotated bounding boxes. Args: reduction (str): Method to reduce losses. The valid reduction method are none, sum or mean. loss_weight (float, optional): Weight of loss. Defaults to 1.0. """ def __init__(self, reduction='mean', loss_weight=1.0): super().__init__() self.reduction = reduction self.loss_weight = loss_weight
[docs] def forward(self, pred, target, weight=None, avg_factor=None, reduction_override=None, **kwargs): """Forward function of loss calculation. Args: pred (torch.Tensor): Bbox predictions with shape [..., 7] (x, y, z, w, l, h, alpha). target (torch.Tensor): Bbox targets (gt) with shape [..., 7] (x, y, z, w, l, h, alpha). weight (torch.Tensor | float, optional): Weight of loss. Defaults to None. avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. reduction_override (str, optional): Method to reduce losses. The valid reduction method are 'none', 'sum' or 'mean'. Defaults to None. Returns: torch.Tensor: IoU loss between predictions and targets. """ if weight is not None and not torch.any(weight > 0): return pred.sum() * weight.sum() # 0 assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) if weight is not None and weight.dim() > 1: weight = weight.mean(-1) loss = self.loss_weight * rotated_iou_3d_loss( pred, target, weight, reduction=reduction, avg_factor=avg_factor, **kwargs) return loss
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.