Shortcuts

Source code for mmdet3d.models.losses.axis_aligned_iou_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn as nn

from mmdet.models.losses.utils import weighted_loss
from ...core.bbox import AxisAlignedBboxOverlaps3D
from ..builder import LOSSES


@weighted_loss
def axis_aligned_iou_loss(pred, target):
    """Compute the IoU loss ``1 - IoU`` for paired axis-aligned 3D boxes.

    Predictions and targets are matched one-to-one: the i-th prediction is
    only compared with the i-th target (``is_aligned=True``), so no
    all-pairs overlap matrix is built.

    Args:
        pred (torch.Tensor): Predicted boxes with shape [..., 6], laid out
            as (x1, y1, z1, x2, y2, z2).
        target (torch.Tensor): Ground-truth boxes with shape [..., 6], in
            the same (x1, y1, z1, x2, y2, z2) layout as ``pred``.

    Returns:
        torch.Tensor: Per-pair IoU loss. Weighting and reduction are handled
        by the ``weighted_loss`` decorator, not by this body.
    """
    overlap_fn = AxisAlignedBboxOverlaps3D()
    paired_ious = overlap_fn(pred, target, is_aligned=True)
    return 1 - paired_ious
@LOSSES.register_module()
class AxisAlignedIoULoss(nn.Module):
    """Calculate the IoU loss (1-IoU) of axis aligned bounding boxes.

    Args:
        reduction (str): Method to reduce losses. The valid reduction
            methods are 'none', 'sum' or 'mean'. Defaults to 'mean'.
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(AxisAlignedIoULoss, self).__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            pred (torch.Tensor): Bbox predictions with shape [..., 6]
                (x1, y1, z1, x2, y2, z2).
            target (torch.Tensor): Bbox targets (gt) with shape [..., 6]
                (x1, y1, z1, x2, y2, z2).
            weight (torch.Tensor | float, optional): Weight of loss. May be
                per box (shape [...]) or broadcastable to ``pred``.
                Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction methods are 'none', 'sum' or 'mean'.
                Defaults to None.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            torch.Tensor: IoU loss between predictions and targets.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        if weight is not None and reduction != 'none':
            # Wrap the weight so a plain float (allowed by the docstring)
            # also works; `torch.any` rejects Python scalars.
            weight_tensor = torch.as_tensor(weight, device=pred.device)
            if not torch.any(weight_tensor > 0):
                # No positive weight: skip the IoU computation. The result
                # is built from `pred` (presumably to stay connected to the
                # autograd graph) and is 0 when all weights are 0.
                if pred.dim() == weight_tensor.dim() + 1:
                    # Per-box weights of shape [...] must broadcast over the
                    # trailing coordinate dimension of `pred` ([..., 6]),
                    # not against it.
                    weight_tensor = weight_tensor.unsqueeze(-1)
                return (pred * weight_tensor).sum()

        return axis_aligned_iou_loss(
            pred,
            target,
            weight=weight,
            avg_factor=avg_factor,
            reduction=reduction) * self.loss_weight
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.