import torch
from torch import nn as nn
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss
from ...core.bbox import AxisAlignedBboxOverlaps3D
@weighted_loss
def axis_aligned_iou_loss(pred, target):
    """Compute the IoU loss (1 - IoU) for paired axis-aligned 3D boxes.

    Predictions and targets are matched one-to-one: the i-th prediction is
    compared only with the i-th target (``is_aligned=True``).

    Args:
        pred (torch.Tensor): Predicted boxes.
        target (torch.Tensor): Ground-truth boxes, same layout as ``pred``.

    Returns:
        torch.Tensor: ``1 - IoU`` for every prediction/target pair.

    Note:
        NOTE(review): the box layout is whatever ``AxisAlignedBboxOverlaps3D``
        expects; presumably ``[..., 6]`` as (x1, y1, z1, x2, y2, z2) -
        confirm against that class, since this file previously documented
        ``[..., 3]``.
    """
    overlap_calculator = AxisAlignedBboxOverlaps3D()
    # Aligned mode pairs boxes element-wise instead of building the full
    # pairwise overlap matrix.
    ious = overlap_calculator(pred, target, is_aligned=True)
    return 1 - ious
@LOSSES.register_module()
class AxisAlignedIoULoss(nn.Module):
    """Calculate the IoU loss (1-IoU) of axis aligned bounding boxes.

    Args:
        reduction (str): Method to reduce losses.
            The valid reduction methods are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            pred (torch.Tensor): Bbox predictions. The last dimension holds
                the box coordinates.
            target (torch.Tensor): Bbox targets (gt), same layout as
                ``pred``.
            weight (torch.Tensor|float, optional): Weight of loss.
                Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction methods are 'none', 'sum' or 'mean'.
                Defaults to None, in which case ``self.reduction`` is used.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            torch.Tensor: IoU loss between predictions and targets, scaled
            by ``self.loss_weight`` (except on the all-zero-weight shortcut
            below, which returns the weighted sum of ``pred`` directly).
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        # Shortcut when no element carries positive weight: return a value
        # that still depends on ``pred`` instead of running the IoU
        # computation. Only tensor weights take this path; a Python scalar
        # would make ``torch.any(weight > 0)`` raise, so scalars fall through
        # to the regular weighted loss below.
        if (isinstance(weight, torch.Tensor) and reduction != 'none'
                and not torch.any(weight > 0)):
            # NOTE(review): assumes a weight with exactly one fewer
            # dimension than ``pred`` holds one value per box. Add a trailing
            # axis so it broadcasts over the coordinate dimension rather
            # than failing (or silently aligning with the wrong axis).
            if weight.dim() == pred.dim() - 1:
                weight = weight.unsqueeze(-1)
            return (pred * weight).sum()

        return axis_aligned_iou_loss(
            pred,
            target,
            weight=weight,
            avg_factor=avg_factor,
            reduction=reduction) * self.loss_weight