# Source code for mmdet3d.models.losses.axis_aligned_iou_loss

import torch
from torch import nn as nn

from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss
from ...core.bbox import AxisAlignedBboxOverlaps3D


@weighted_loss
def axis_aligned_iou_loss(pred, target):
    """Calculate the IoU loss (1 - IoU) of two sets of axis-aligned boxes.

    Predictions and targets are matched one-to-one (``is_aligned=True``):
    the i-th prediction is only compared with the i-th target.

    The function is wrapped by ``@weighted_loss``; callers such as
    ``AxisAlignedIoULoss.forward`` additionally pass ``weight``,
    ``avg_factor`` and ``reduction`` keyword arguments, which are handled
    by that wrapper rather than by this body.

    Args:
        pred (torch.Tensor): Bbox predictions with shape [..., 6]; the last
            dim presumably holds (x1, y1, z1, x2, y2, z2) corners as
            consumed by ``AxisAlignedBboxOverlaps3D`` -- confirm.
        target (torch.Tensor): Bbox targets (gt) with the same layout as
            ``pred``.

    Returns:
        torch.Tensor: IoU loss (1 - IoU) between each prediction and its
            matching target, before any weighting/reduction applied by
            ``@weighted_loss``.
    """
    # is_aligned=True -> pairwise IoU of matching rows, not an m x n matrix.
    axis_aligned_iou = AxisAlignedBboxOverlaps3D()(
        pred, target, is_aligned=True)
    iou_loss = 1 - axis_aligned_iou
    return iou_loss
@LOSSES.register_module()
class AxisAlignedIoULoss(nn.Module):
    """Calculate the IoU loss (1 - IoU) of axis-aligned bounding boxes.

    Args:
        reduction (str, optional): Method used to reduce the loss. Valid
            values are 'none', 'sum' and 'mean'. Defaults to 'mean'.
        loss_weight (float, optional): Factor the computed loss is
            multiplied by. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(AxisAlignedIoULoss, self).__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            pred (torch.Tensor): Bbox predictions with shape [..., 6].
            target (torch.Tensor): Bbox targets (gt) with shape [..., 6].
            weight (torch.Tensor | float, optional): Weight of the loss,
                presumably one value per box (shape [...]) -- confirm.
                Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): Reduction to use instead of
                ``self.reduction`` when given. Valid values are 'none',
                'sum' and 'mean'. Defaults to None.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            torch.Tensor: IoU loss between predictions and targets, scaled
                by ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # ``torch.as_tensor`` makes the check work for a plain float weight
        # as documented (``float > 0`` is a bool, which ``torch.any``
        # rejects); for a tensor it returns the very same object.
        if (weight is not None) and (
                not torch.any(torch.as_tensor(weight) > 0)) and (
                    reduction != 'none'):
            # No positive weight: skip the IoU computation. The result is 0
            # for all-zero weights and is still computed from ``pred``,
            # presumably to keep ``pred`` in the autograd graph.
            if torch.is_tensor(weight) and pred.dim() == weight.dim() + 1:
                # A per-box weight of shape [...] does not broadcast against
                # ``pred`` of shape [..., 6]; append the box-coordinate dim.
                weight = weight.unsqueeze(-1)
            return (pred * weight).sum()
        return axis_aligned_iou_loss(
            pred,
            target,
            weight=weight,
            avg_factor=avg_factor,
            reduction=reduction) * self.loss_weight