Shortcuts

Source code for mmdet3d.models.losses.chamfer_distance

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss

from ..builder import LOSSES


def chamfer_distance(src,
                     dst,
                     src_weight=1.0,
                     dst_weight=1.0,
                     criterion_mode='l2',
                     reduction='mean'):
    """Calculate Chamfer Distance of two sets.

    Every source point is paired with every destination point, the chosen
    element-wise criterion is summed over the last (channel) dimension, and
    the minimum over the opposite set is taken in each direction.

    Args:
        src (torch.Tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (torch.Tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (torch.Tensor or float): Weight of source loss.
        dst_weight (torch.Tensor or float): Weight of destination loss.
        criterion_mode (str): Criterion mode to calculate distance.
            The valid modes are smooth_l1, l1 or l2.
        reduction (str): Method to reduce losses.
            The valid reduction method are 'none', 'sum' or 'mean'.

    Returns:
        tuple: Source and Destination loss with the corresponding indices.

            - loss_src (torch.Tensor): The min distance
                from source to destination.
            - loss_dst (torch.Tensor): The min distance
                from destination to source.
            - indices1 (torch.Tensor): Index the min distance point
                for each point in source to destination.
            - indices2 (torch.Tensor): Index the min distance point
                for each point in destination to source.

    Raises:
        NotImplementedError: If ``criterion_mode`` or ``reduction`` is not
            one of the supported values.
    """
    if criterion_mode == 'smooth_l1':
        criterion = smooth_l1_loss
    elif criterion_mode == 'l1':
        criterion = l1_loss
    elif criterion_mode == 'l2':
        criterion = mse_loss
    else:
        raise NotImplementedError(
            'Unsupported criterion_mode {!r}; expected one of '
            "'smooth_l1', 'l1' or 'l2'.".format(criterion_mode))

    # Reject a bad reduction before doing the O(B*N*M*C) pairwise work.
    if reduction not in ('none', 'sum', 'mean'):
        raise NotImplementedError(
            'Unsupported reduction {!r}; expected one of '
            "'none', 'sum' or 'mean'.".format(reduction))

    # ``expand`` yields broadcast views of shape (B, N, M, C) without
    # copying the inputs (``repeat`` would allocate two full-size copies),
    # while still handing the criterion two equally shaped operands.
    num_src, num_dst = src.shape[1], dst.shape[1]
    src_expand = src.unsqueeze(2).expand(-1, -1, num_dst, -1)
    dst_expand = dst.unsqueeze(1).expand(-1, num_src, -1, -1)

    # Per-pair distance: element-wise criterion summed over channels.
    distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
    src2dst_distance, indices1 = torch.min(distance, dim=2)  # (B,N)
    dst2src_distance, indices2 = torch.min(distance, dim=1)  # (B,M)

    loss_src = src2dst_distance * src_weight
    loss_dst = dst2src_distance * dst_weight

    if reduction == 'sum':
        loss_src = torch.sum(loss_src)
        loss_dst = torch.sum(loss_dst)
    elif reduction == 'mean':
        loss_src = torch.mean(loss_src)
        loss_dst = torch.mean(loss_dst)
    # reduction == 'none': keep the per-point losses untouched.

    return loss_src, loss_dst, indices1, indices2
@LOSSES.register_module()
class ChamferDistance(nn.Module):
    """Loss module wrapping :func:`chamfer_distance`.

    The module stores a criterion mode, a default reduction and one scalar
    weight per direction; ``forward`` applies them on top of the functional
    Chamfer Distance.

    Args:
        mode (str): Criterion mode to calculate distance.
            The valid modes are smooth_l1, l1 or l2.
        reduction (str): Method to reduce losses.
            The valid reduction method are none, sum or mean.
        loss_src_weight (float): Weight of loss_source.
        loss_dst_weight (float): Weight of loss_target.
    """

    def __init__(self,
                 mode='l2',
                 reduction='mean',
                 loss_src_weight=1.0,
                 loss_dst_weight=1.0):
        super().__init__()

        # Fail at construction time on unsupported settings.
        assert mode in ['smooth_l1', 'l1', 'l2']
        assert reduction in ['none', 'sum', 'mean']

        self.mode = mode
        self.reduction = reduction
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(self,
                source,
                target,
                src_weight=1.0,
                dst_weight=1.0,
                reduction_override=None,
                return_indices=False,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            source (torch.Tensor): Source set with shape [B, N, C] to
                calculate Chamfer Distance.
            target (torch.Tensor): Destination set with shape [B, M, C] to
                calculate Chamfer Distance.
            src_weight (torch.Tensor | float, optional):
                Weight of source loss. Defaults to 1.0.
            dst_weight (torch.Tensor | float, optional):
                Weight of destination loss. Defaults to 1.0.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction method are 'none', 'sum' or 'mean'.
                When None, the reduction given at construction is used.
                Defaults to None.
            return_indices (bool, optional): Whether to return indices.
                Defaults to False.

        Returns:
            tuple[torch.Tensor]: If ``return_indices=True``, return losses of
                source and target with their corresponding indices in the
                order of ``(loss_source, loss_target, indices1, indices2)``.
                If ``return_indices=False``, return
                ``(loss_source, loss_target)``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A per-call override wins over the module-level default.
        reduction = reduction_override or self.reduction

        loss_source, loss_target, src_indices, dst_indices = chamfer_distance(
            source,
            target,
            src_weight=src_weight,
            dst_weight=dst_weight,
            criterion_mode=self.mode,
            reduction=reduction)

        # Scale each direction by its module-level weight.
        loss_source *= self.loss_src_weight
        loss_target *= self.loss_dst_weight

        if not return_indices:
            return loss_source, loss_target
        return loss_source, loss_target, src_indices, dst_indices
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.