Source code for mmdet3d.models.losses.chamfer_distance

import torch
from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss

from mmdet.models.builder import LOSSES


def chamfer_distance(src,
                     dst,
                     src_weight=1.0,
                     dst_weight=1.0,
                     criterion_mode='l2',
                     reduction='mean'):
    """Calculate Chamfer Distance of two sets.

    The pairwise distance between every source point and every destination
    point is the element-wise criterion summed over the last (channel)
    dimension. For ``'l2'`` this is therefore the *squared* Euclidean
    distance, since ``mse_loss`` is applied element-wise and not rooted.

    Args:
        src (torch.Tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (torch.Tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (torch.Tensor or float): Weight of source loss. It is
            multiplied with the [B, N] source-to-destination distances
            before reduction.
        dst_weight (torch.Tensor or float): Weight of destination loss. It
            is multiplied with the [B, M] destination-to-source distances
            before reduction.
        criterion_mode (str): Criterion mode to calculate distance.
            The valid modes are smooth_l1, l1 or l2.
        reduction (str): Method to reduce losses.
            The valid reduction method are 'none', 'sum' or 'mean'.

    Returns:
        tuple: Source and Destination loss with the corresponding indices.

            - loss_src (torch.Tensor): The min distance
              from source to destination.
            - loss_dst (torch.Tensor): The min distance
              from destination to source.
            - indices1 (torch.Tensor): Index the min distance point
              for each point in source to destination.
            - indices2 (torch.Tensor): Index the min distance point
              for each point in destination to source.

    Raises:
        NotImplementedError: If ``criterion_mode`` or ``reduction`` is not
            one of the supported values.
    """
    if criterion_mode == 'smooth_l1':
        criterion = smooth_l1_loss
    elif criterion_mode == 'l1':
        criterion = l1_loss
    elif criterion_mode == 'l2':
        criterion = mse_loss
    else:
        raise NotImplementedError(
            f'Unsupported criterion_mode {criterion_mode!r}; expected '
            "'smooth_l1', 'l1' or 'l2'.")

    # Build [B, N, M, C] views of both sets. ``expand`` produces stride-0
    # views instead of the physical copies ``repeat`` would allocate, while
    # still handing the criterion two equally-shaped tensors (so it does not
    # emit its size-mismatch broadcasting warning).
    src_expand = src.unsqueeze(2).expand(-1, -1, dst.shape[1], -1)
    dst_expand = dst.unsqueeze(1).expand(-1, src.shape[1], -1, -1)

    # Pairwise distance matrix of shape [B, N, M].
    distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
    src2dst_distance, indices1 = torch.min(distance, dim=2)  # (B,N)
    dst2src_distance, indices2 = torch.min(distance, dim=1)  # (B,M)

    loss_src = src2dst_distance * src_weight
    loss_dst = dst2src_distance * dst_weight

    if reduction == 'sum':
        loss_src = torch.sum(loss_src)
        loss_dst = torch.sum(loss_dst)
    elif reduction == 'mean':
        loss_src = torch.mean(loss_src)
        loss_dst = torch.mean(loss_dst)
    elif reduction == 'none':
        pass
    else:
        raise NotImplementedError(
            f'Unsupported reduction {reduction!r}; expected '
            "'none', 'sum' or 'mean'.")

    return loss_src, loss_dst, indices1, indices2
@LOSSES.register_module()
class ChamferDistance(nn.Module):
    """Loss module wrapping :func:`chamfer_distance`.

    Args:
        mode (str): Criterion mode to calculate distance.
            The valid modes are smooth_l1, l1 or l2.
        reduction (str): Method to reduce losses.
            The valid reduction method are none, sum or mean.
        loss_src_weight (float): Weight of loss_source.
        loss_dst_weight (float): Weight of loss_target.
    """

    def __init__(self,
                 mode='l2',
                 reduction='mean',
                 loss_src_weight=1.0,
                 loss_dst_weight=1.0):
        super().__init__()

        # Reject unsupported configuration at construction time.
        assert mode in ['smooth_l1', 'l1', 'l2']
        assert reduction in ['none', 'sum', 'mean']

        self.mode = mode
        self.reduction = reduction
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(self,
                source,
                target,
                src_weight=1.0,
                dst_weight=1.0,
                reduction_override=None,
                return_indices=False,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            source (torch.Tensor): Source set with shape [B, N, C] to
                calculate Chamfer Distance.
            target (torch.Tensor): Destination set with shape [B, M, C] to
                calculate Chamfer Distance.
            src_weight (torch.Tensor | float, optional):
                Weight of source loss. Defaults to 1.0.
            dst_weight (torch.Tensor | float, optional):
                Weight of destination loss. Defaults to 1.0.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction method are 'none', 'sum' or 'mean'.
                When given, it replaces the reduction chosen at
                construction for this call only. Defaults to None.
            return_indices (bool, optional): Whether to return indices.
                Defaults to False.
            **kwargs: Accepted and ignored.

        Returns:
            tuple[torch.Tensor]: If ``return_indices=True``, return losses
            of source and target with their corresponding indices in the
            order of ``(loss_source, loss_target, indices1, indices2)``.
            If ``return_indices=False``, return
            ``(loss_source, loss_target)``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # Any non-None override is a non-empty string, so ``or`` selects it.
        active_reduction = reduction_override or self.reduction

        src_loss, dst_loss, src_idx, dst_idx = chamfer_distance(
            source, target, src_weight, dst_weight, self.mode,
            active_reduction)

        # Scale by the module-level weights; in-place is safe here because
        # both tensors were freshly created inside ``chamfer_distance``.
        src_loss *= self.loss_src_weight
        dst_loss *= self.loss_dst_weight

        if not return_indices:
            return src_loss, dst_loss
        return src_loss, dst_loss, src_idx, dst_idx