Source code for mmdet3d.models.fusion_layers.point_fusion

import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn as nn
from torch.nn import functional as F

from ..builder import FUSION_LAYERS
from . import apply_3d_transformation


def point_sample(
    img_meta,
    img_features,
    points,
    lidar2img_rt,
    img_scale_factor,
    img_crop_offset,
    img_flip,
    img_pad_shape,
    img_shape,
    aligned=True,
    padding_mode='zeros',
    align_corners=True,
):
    """Obtain image features using points.

    Args:
        img_meta (dict): Meta info.
        img_features (torch.Tensor): 1 x C x H x W image features.
        points (torch.Tensor): Nx3 point cloud in LiDAR coordinates.
        lidar2img_rt (torch.Tensor): 4x4 transformation matrix.
        img_scale_factor (torch.Tensor): Scale factor with shape of \
            (w_scale, h_scale).
        img_crop_offset (torch.Tensor): Crop offset used to crop \
            image during data augmentation with shape of (w_offset, h_offset).
        img_flip (bool): Whether the image is flipped.
        img_pad_shape (tuple[int]): Image height and width after padding;
            this is necessary to obtain features in the feature map.
        img_shape (tuple[int]): Image height and width after scaling but
            before padding; this is necessary for flipping coordinates.
        aligned (bool, optional): Whether to use bilinear interpolation when
            sampling image features for each point. Defaults to True.
        padding_mode (str, optional): Padding mode when padding values for
            features of out-of-image points. Defaults to 'zeros'.
        align_corners (bool, optional): Whether to align corners when
            sampling image features for each point. Defaults to True.

    Returns:
        torch.Tensor: NxC image features sampled by point coordinates.
    """

    # undo the 3D augmentations recorded in img_meta so the points are
    # back in the original LiDAR frame before projecting to the image
    points = apply_3d_transformation(points, 'LIDAR', img_meta, reverse=True)

    # project points from velo coordinate to camera coordinate
    num_points = points.shape[0]
    pts_4d = torch.cat([points, points.new_ones(size=(num_points, 1))], dim=-1)
    pts_2d = pts_4d @ lidar2img_rt.t()

    # pts_2d is an Nx4 tensor whose last column is 1
    # transform camera coordinates to image coordinates
    pts_2d[:, 2] = torch.clamp(pts_2d[:, 2], min=1e-5)
    pts_2d[:, 0] /= pts_2d[:, 2]
    pts_2d[:, 1] /= pts_2d[:, 2]

    # img transformation: scale -> crop -> flip
    # the image is resized by img_scale_factor
    img_coors = pts_2d[:, 0:2] * img_scale_factor  # Nx2
    img_coors -= img_crop_offset

    # grid sample, the valid grid range should be in [-1,1]
    coor_x, coor_y = torch.split(img_coors, 1, dim=1)  # each is Nx1

    if img_flip:
        # by default we take it as horizontal flip
        # use img_shape before padding for flip
        orig_h, orig_w = img_shape
        coor_x = orig_w - coor_x

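    # normalize pixel coordinates to [-1, 1] w.r.t. the padded image size,
    # since the feature map is computed from the padded image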
    h, w = img_pad_shape
    coor_y = coor_y / h * 2 - 1
    coor_x = coor_x / w * 2 - 1
    grid = torch.cat([coor_x, coor_y],
                     dim=1).unsqueeze(0).unsqueeze(0)  # Nx2 -> 1x1xNx2

    # align_corners=True provides higher performance
    mode = 'bilinear' if aligned else 'nearest'
    point_features = F.grid_sample(
        img_features,
        grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners)  # 1xCx1xN feats

    return point_features.squeeze().t()

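# A minimal usage sketch of ``point_sample`` (not part of the original
# module). All shapes, the empty meta dict and the identity lidar2img matrix
# below are assumptions for illustration; real values come from the dataset
# pipeline.
def _example_point_sample():
    img_meta = dict()  # no augmentation recorded
    img_features = torch.rand(1, 64, 48, 160)  # 1 x C x H x W feature map
    points = torch.rand(100, 3) * 10  # N x 3 points in LiDAR coordinates
    lidar2img_rt = torch.eye(4)  # dummy projection matrix
    feats = point_sample(
        img_meta,
        img_features,
        points,
        lidar2img_rt,
        img_scale_factor=points.new_tensor([1.0, 1.0]),
        img_crop_offset=points.new_tensor([0.0, 0.0]),
        img_flip=False,
        img_pad_shape=(192, 640),
        img_shape=(190, 632))
    return feats  # N x C image features, one row per point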

@FUSION_LAYERS.register_module()
class PointFusion(BaseModule):
    """Fuse image features from multi-scale features.

    Args:
        img_channels (list[int] | int): Channels of image features.
            It could be a list if the input is multi-scale image features.
        pts_channels (int): Channels of point features.
        mid_channels (int): Channels of middle layers.
        out_channels (int): Channels of output fused features.
        img_levels (int | list[int], optional): Number of image levels.
            Defaults to 3.
        conv_cfg (dict, optional): Dict config of conv layers of middle
            layers. Defaults to None.
        norm_cfg (dict, optional): Dict config of norm layers of middle
            layers. Defaults to None.
        act_cfg (dict, optional): Dict config of activation layers.
            Defaults to None.
        activate_out (bool, optional): Whether to apply relu activation
            to output features. Defaults to True.
        fuse_out (bool, optional): Whether to apply a conv layer to the
            fused features. Defaults to False.
        dropout_ratio (int | float, optional): Dropout ratio of image
            features to prevent overfitting. Defaults to 0.
        aligned (bool, optional): Whether to apply aligned feature fusion.
            Defaults to True.
        align_corners (bool, optional): Whether to align corners when
            sampling features according to points. Defaults to True.
        padding_mode (str, optional): Mode used to pad the features of
            points that do not have corresponding image features.
            Defaults to 'zeros'.
        lateral_conv (bool, optional): Whether to apply lateral convs
            to image features. Defaults to True.
    """

    def __init__(self,
                 img_channels,
                 pts_channels,
                 mid_channels,
                 out_channels,
                 img_levels=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 init_cfg=None,
                 activate_out=True,
                 fuse_out=False,
                 dropout_ratio=0,
                 aligned=True,
                 align_corners=True,
                 padding_mode='zeros',
                 lateral_conv=True):
        super(PointFusion, self).__init__(init_cfg=init_cfg)
        if isinstance(img_levels, int):
            img_levels = [img_levels]
        if isinstance(img_channels, int):
            img_channels = [img_channels] * len(img_levels)
        assert isinstance(img_levels, list)
        assert isinstance(img_channels, list)
        assert len(img_channels) == len(img_levels)

        self.img_levels = img_levels
        self.act_cfg = act_cfg
        self.activate_out = activate_out
        self.fuse_out = fuse_out
        self.dropout_ratio = dropout_ratio
        self.img_channels = img_channels
        self.aligned = aligned
        self.align_corners = align_corners
        self.padding_mode = padding_mode

        self.lateral_convs = None
        if lateral_conv:
            self.lateral_convs = nn.ModuleList()
            for i in range(len(img_channels)):
                l_conv = ConvModule(
                    img_channels[i],
                    mid_channels,
                    3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=self.act_cfg,
                    inplace=False)
                self.lateral_convs.append(l_conv)
            self.img_transform = nn.Sequential(
                nn.Linear(mid_channels * len(img_channels), out_channels),
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
            )
        else:
            self.img_transform = nn.Sequential(
                nn.Linear(sum(img_channels), out_channels),
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
            )
        self.pts_transform = nn.Sequential(
            nn.Linear(pts_channels, out_channels),
            nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
        )

        if self.fuse_out:
            # note: the fused tensor has ``out_channels`` features, so this
            # branch assumes mid_channels == out_channels
            self.fuse_conv = nn.Sequential(
                nn.Linear(mid_channels, out_channels),
                # For pts the BN is initialized differently by default
                # TODO: check whether this is necessary
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
                nn.ReLU(inplace=False))

        if init_cfg is None:
            self.init_cfg = [
                dict(type='Xavier', layer='Conv2d', distribution='uniform'),
                dict(type='Xavier', layer='Linear', distribution='uniform')
            ]

    def forward(self, img_feats, pts, pts_feats, img_metas):
        """Forward function.

        Args:
            img_feats (list[torch.Tensor]): Image features.
            pts (list[torch.Tensor]): A batch of points with shape N x 3.
            pts_feats (torch.Tensor): A tensor consisting of point features
                of the total batch.
            img_metas (list[dict]): Meta information of images.

        Returns:
            torch.Tensor: Fused features of each point.
        """
        img_pts = self.obtain_mlvl_feats(img_feats, pts, img_metas)
        img_pre_fuse = self.img_transform(img_pts)
        if self.training and self.dropout_ratio > 0:
            img_pre_fuse = F.dropout(img_pre_fuse, self.dropout_ratio)
        pts_pre_fuse = self.pts_transform(pts_feats)

        fuse_out = img_pre_fuse + pts_pre_fuse
        if self.activate_out:
            fuse_out = F.relu(fuse_out)
        if self.fuse_out:
            fuse_out = self.fuse_conv(fuse_out)

        return fuse_out

    def obtain_mlvl_feats(self, img_feats, pts, img_metas):
        """Obtain multi-level features for each point.

        Args:
            img_feats (list[torch.Tensor]): Multi-scale image features
                produced by image backbone in shape (N, C, H, W).
            pts (list[torch.Tensor]): Points of each sample.
            img_metas (list[dict]): Meta information for each sample.

        Returns:
            torch.Tensor: Corresponding image features of each point.
        """
        if self.lateral_convs is not None:
            img_ins = [
                lateral_conv(img_feats[i])
                for i, lateral_conv in zip(self.img_levels, self.lateral_convs)
            ]
        else:
            img_ins = img_feats
        img_feats_per_point = []
        # Sample multi-level features
        for i in range(len(img_metas)):
            mlvl_img_feats = []
            for level in range(len(self.img_levels)):
                mlvl_img_feats.append(
                    self.sample_single(img_ins[level][i:i + 1], pts[i][:, :3],
                                       img_metas[i]))
            mlvl_img_feats = torch.cat(mlvl_img_feats, dim=-1)
            img_feats_per_point.append(mlvl_img_feats)

        img_pts = torch.cat(img_feats_per_point, dim=0)
        return img_pts

    def sample_single(self, img_feats, pts, img_meta):
        """Sample features from single level image feature map.

        Args:
            img_feats (torch.Tensor): Image feature map in shape
                (1, C, H, W).
            pts (torch.Tensor): Points of a single sample.
            img_meta (dict): Meta information of the single sample.

        Returns:
            torch.Tensor: Single level image features of each point.
        """
        # TODO: image transformation also extracted
        img_scale_factor = (
            pts.new_tensor(img_meta['scale_factor'][:2])
            if 'scale_factor' in img_meta.keys() else 1)
        img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False
        img_crop_offset = (
            pts.new_tensor(img_meta['img_crop_offset'])
            if 'img_crop_offset' in img_meta.keys() else 0)
        img_pts = point_sample(
            img_meta,
            img_feats,
            pts,
            pts.new_tensor(img_meta['lidar2img']),
            img_scale_factor,
            img_crop_offset,
            img_flip=img_flip,
            img_pad_shape=img_meta['input_shape'][:2],
            img_shape=img_meta['img_shape'][:2],
            aligned=self.aligned,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
        return img_pts
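# A minimal usage sketch of ``PointFusion`` (illustration only; the channel
# sizes, point counts and dummy meta values below are assumptions rather
# than settings from any released config).
def _example_point_fusion():
    fusion = PointFusion(
        img_channels=256,
        pts_channels=64,
        mid_channels=128,
        out_channels=128,
        img_levels=[0],
        lateral_conv=True)
    fusion.eval()  # use running stats in the BatchNorm1d layers

    img_feats = [torch.rand(1, 256, 48, 160)]  # single-level image features
    pts = [torch.rand(200, 4)]  # one sample with 200 points (x, y, z, ...)
    pts_feats = torch.rand(200, 64)  # per-point features from the point branch
    img_metas = [
        dict(
            lidar2img=torch.eye(4),  # dummy projection matrix
            input_shape=(384, 1280),  # padded image shape (h, w)
            img_shape=(375, 1242))  # image shape before padding (h, w)
    ]
    fused = fusion(img_feats, pts, pts_feats, img_metas)  # (200, 128)
    return fused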