# Source code for mmdet3d.models.losses.paconv_regularization_loss

import torch
from torch import nn as nn

from mmdet3d.ops import PAConv, PAConvCUDA
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss


def weight_correlation(conv):
    """Measure how correlated the kernels in a conv's weight bank are.

    Every kernel of the weight bank is flattened into a vector and the
    cosine similarity between each pair of distinct kernels is computed.
    The squared similarities are summed into one scalar which serves as a
    regularization loss.

    Args:
        conv (nn.Module): The conv module to be regularized.
            Only `PAConv` and `PAConvCUDA` are supported.

    Returns:
        torch.Tensor: Scalar sum of squared pairwise cosine similarities
            between the kernels of the weight bank.

    Raises:
        AssertionError: If ``conv`` is neither `PAConv` nor `PAConvCUDA`.
    """
    assert isinstance(conv, (PAConv, PAConvCUDA)), (
        f'unsupported module type {type(conv)}')

    bank_size = conv.num_kernels
    # the bank is stored as [C_in, num_kernels * C_out]; expose the kernel
    # axis explicitly -> [C_in, num_kernels, C_out]
    bank = conv.weight_bank.view(conv.in_channels, bank_size,
                                 conv.out_channels)
    # one row per kernel -> [num_kernels, C_in * C_out]
    per_kernel = bank.permute(1, 0, 2).reshape(bank_size, -1)

    # pairwise dot products -> [num_kernels, num_kernels]
    gram = per_kernel @ per_kernel.T
    # L2 norm of each kernel -> [num_kernels, 1]
    lengths = per_kernel.pow(2).sum(dim=-1, keepdim=True)**0.5
    # dividing by the outer product of the norms yields cosine similarity
    cosine = gram / (lengths @ lengths.T)

    # Only the strict upper triangle is kept so that each pair of distinct
    # kernels is counted exactly once. Squaring keeps the loss positive,
    # refer to:
    # https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/tool/train.py#L208
    return cosine.triu(diagonal=1).pow(2).sum()


def paconv_regularization_loss(modules, reduction):
    """Compute the kernel correlation loss over all PAConv modules.

    Modules that are not `PAConv` or `PAConvCUDA` are skipped. The
    per-module correlation values are stacked into a single tensor and
    handed to ``weight_reduce_loss`` for reduction.

    Args:
        modules (List[nn.Module] | :obj:`generator`):
            A list or a python generator of torch.nn.Modules.
        reduction (str): How to reduce the losses among PAConv modules.
            The valid reduction methods are none, sum or mean.

    Returns:
        torch.Tensor: Correlation loss of kernel weights.
    """
    # modules may be a generator, so it is traversed exactly once here
    per_module = torch.stack([
        weight_correlation(candidate) for candidate in modules
        if isinstance(candidate, (PAConv, PAConvCUDA))
    ])

    return weight_reduce_loss(per_module, reduction=reduction)


@LOSSES.register_module()
class PAConvRegularizationLoss(nn.Module):
    """Calculate correlation loss of kernel weights in PAConv's weight bank.

    This is used as a regularization term in PAConv model training.

    Args:
        reduction (str): Method to reduce losses. The reduction is performed
            among all PAConv modules instead of prediction tensors.
            The valid reduction methods are none, sum or mean.
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(PAConvRegularizationLoss, self).__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, modules, reduction_override=None, **kwargs):
        """Forward function of loss calculation.

        Args:
            modules (List[nn.Module] | :obj:`generator`):
                A list or a python generator of torch.nn.Modules.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction methods are 'none', 'sum' or 'mean'.
                When None, the reduction given at construction is used.
                Defaults to None.
            **kwargs: Accepted and ignored.

        Returns:
            torch.Tensor: Correlation loss of kernel weights, scaled by
                ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # a per-call override takes precedence over the configured reduction
        reduction = (
            reduction_override if reduction_override else self.reduction)

        return self.loss_weight * paconv_regularization_loss(
            modules, reduction=reduction)