Source code for mmdet3d.models.backbones.pointnet2_sa_msg

import torch
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import build_sa_module
from mmdet.models import BACKBONES
from .base_pointnet import BasePointNet


@BACKBONES.register_module()
class PointNet2SAMSG(BasePointNet):
    """PointNet2 with Multi-scale grouping.

    Args:
        in_channels (int): Input channels of point cloud.
        num_points (tuple[int]): The number of points which each SA
            module samples.
        radii (tuple[float]): Sampling radii of each SA module.
        num_samples (tuple[int]): The number of samples for ball query
            in each SA module.
        sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
        aggregation_channels (tuple[int]): Out channels of aggregation
            multi-scale grouping features.
        fps_mods (tuple[str]): Mod of FPS for each SA module.
        fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
            points which each SA module samples.
        dilated_group (tuple[bool]): Whether to use dilated ball query for
            each SA module.
        out_indices (Sequence[int]): Output from which stages.
        norm_cfg (dict): Config of normalization layer.
        sa_cfg (dict): Config of set abstraction module, which may contain
            the following keys and values:

            - pool_mod (str): Pool method ('max' or 'avg') for SA modules.
            - use_xyz (bool): Whether to use xyz as a part of features.
            - normalize_xyz (bool): Whether to normalize xyz with radii in
              each SA module.
    """

    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
                 num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
                 sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
                              ((64, 64, 128), (64, 64, 128), (64, 96, 128)),
                              ((128, 128, 256), (128, 192, 256),
                               (128, 256, 256))),
                 aggregation_channels=(64, 128, 256),
                 fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
                 fps_sample_range_lists=((-1), (-1), (512, -1)),
                 dilated_group=(True, True, True),
                 out_indices=(2, ),
                 norm_cfg=dict(type='BN2d'),
                 sa_cfg=dict(
                     type='PointSAModuleMSG',
                     pool_mod='max',
                     use_xyz=True,
                     normalize_xyz=False),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_sa = len(sa_channels)
        self.out_indices = out_indices
        assert max(out_indices) < self.num_sa
        assert len(num_points) == len(radii) == len(num_samples) == len(
            sa_channels) == len(aggregation_channels)

        self.SA_modules = nn.ModuleList()
        self.aggregation_mlps = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz
        skip_channel_list = [sa_in_channel]

        for sa_index in range(self.num_sa):
            cur_sa_mlps = list(sa_channels[sa_index])
            sa_out_channel = 0
            for radius_index in range(len(radii[sa_index])):
                cur_sa_mlps[radius_index] = [sa_in_channel] + list(
                    cur_sa_mlps[radius_index])
                sa_out_channel += cur_sa_mlps[radius_index][-1]

            if isinstance(fps_mods[sa_index], tuple):
                cur_fps_mod = list(fps_mods[sa_index])
            else:
                cur_fps_mod = list([fps_mods[sa_index]])

            if isinstance(fps_sample_range_lists[sa_index], tuple):
                cur_fps_sample_range_list = list(
                    fps_sample_range_lists[sa_index])
            else:
                cur_fps_sample_range_list = list(
                    [fps_sample_range_lists[sa_index]])

            self.SA_modules.append(
                build_sa_module(
                    num_point=num_points[sa_index],
                    radii=radii[sa_index],
                    sample_nums=num_samples[sa_index],
                    mlp_channels=cur_sa_mlps,
                    fps_mod=cur_fps_mod,
                    fps_sample_range_list=cur_fps_sample_range_list,
                    dilated_group=dilated_group[sa_index],
                    norm_cfg=norm_cfg,
                    cfg=sa_cfg,
                    bias=True))
            skip_channel_list.append(sa_out_channel)

            cur_aggregation_channel = aggregation_channels[sa_index]
            if cur_aggregation_channel is None:
                self.aggregation_mlps.append(None)
                sa_in_channel = sa_out_channel
            else:
                self.aggregation_mlps.append(
                    ConvModule(
                        sa_out_channel,
                        cur_aggregation_channel,
                        conv_cfg=dict(type='Conv1d'),
                        norm_cfg=dict(type='BN1d'),
                        kernel_size=1,
                        bias=True))
                sa_in_channel = cur_aggregation_channel
    @auto_fp16(apply_to=('points', ))
    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs of the selected SA modules.

                - sa_xyz (torch.Tensor): The coordinates of sa features.
                - sa_features (torch.Tensor): The features from the
                  last Set Aggregation Layers.
                - sa_indices (torch.Tensor): Indices of the input points.
        """
        xyz, features = self._split_point_feats(points)

        batch, num_points = xyz.shape[:2]
        indices = xyz.new_tensor(range(num_points)).unsqueeze(0).repeat(
            batch, 1).long()

        sa_xyz = [xyz]
        sa_features = [features]
        sa_indices = [indices]

        out_sa_xyz = [xyz]
        out_sa_features = [features]
        out_sa_indices = [indices]

        for i in range(self.num_sa):
            cur_xyz, cur_features, cur_indices = self.SA_modules[i](
                sa_xyz[i], sa_features[i])
            if self.aggregation_mlps[i] is not None:
                cur_features = self.aggregation_mlps[i](cur_features)

            sa_xyz.append(cur_xyz)
            sa_features.append(cur_features)
            sa_indices.append(
                torch.gather(sa_indices[-1], 1, cur_indices.long()))
            if i in self.out_indices:
                out_sa_xyz.append(sa_xyz[-1])
                out_sa_features.append(sa_features[-1])
                out_sa_indices.append(sa_indices[-1])

        return dict(
            sa_xyz=out_sa_xyz,
            sa_features=out_sa_features,
            sa_indices=out_sa_indices)
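
A minimal usage sketch (not part of the upstream file), assuming a CUDA build of mmdet3d, since the set abstraction modules rely on compiled furthest-point-sampling and ball-query ops. The backbone is constructed directly here for illustration; in a detector it would normally be built from a config dict through the BACKBONES registry. The num_points value below is an assumed override: the default four-entry tuple does not match the three-stage defaults of the other arguments, and the last stage runs two FPS modes under the default fps_mods, so its entry is written as a pair.

if __name__ == '__main__' and torch.cuda.is_available():
    # Illustrative settings (assumptions): three SA stages, xyz plus one
    # extra feature channel (e.g. intensity), so in_channels=4.
    backbone = PointNet2SAMSG(
        in_channels=4,
        num_points=(4096, 512, (256, 256))).cuda().eval()

    # Dummy point cloud of shape (B, N, 3 + input_feature_dim).
    points = torch.rand(2, 16384, 4).cuda()
    with torch.no_grad():
        out = backbone(points)

    # Each value is a list: the raw input followed by the stages selected by
    # out_indices (only the last stage by default).
    print([tuple(x.shape) for x in out['sa_xyz']])  # per-stage coordinates
    print(tuple(out['sa_features'][-1].shape))      # (B, C, npoint) features
    print(tuple(out['sa_indices'][-1].shape))       # indices into the input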