Shortcuts

Source code for mmdet3d.models.backbones.pointnet2_sa_msg

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import build_sa_module
from ..builder import BACKBONES
from .base_pointnet import BasePointNet


@BACKBONES.register_module()
class PointNet2SAMSG(BasePointNet):
    """PointNet2 with Multi-scale grouping.

    Args:
        in_channels (int): Input channels of point cloud, including the
            three xyz channels.
        num_points (tuple[int]): The number of points which each SA
            module samples.
        radii (tuple[tuple[float]]): Sampling radii of each SA module, one
            radius per grouping scale.
        num_samples (tuple[tuple[int]]): The number of samples for ball
            query in each SA module.
        sa_channels (tuple[tuple[tuple[int]]]): Out channels of each mlp in
            SA module.
        aggregation_channels (tuple[int | None] | None): Out channels of
            aggregation multi-scale grouping features. ``None`` (for the
            whole argument or for a single stage) skips the aggregation
            layer of the affected stage(s).
        fps_mods (tuple[str | tuple[str]]): Mod of FPS for each SA module.
            A bare string is wrapped into a one-element list.
        fps_sample_range_lists (tuple[int | tuple[int]]): The number of
            sampling points which each SA module samples. A bare int is
            wrapped into a one-element list.
        dilated_group (tuple[bool]): Whether to use dilated ball query for
            each SA module.
        out_indices (Sequence[int]): Output from which stages.
        norm_cfg (dict): Config of normalization layer.
        sa_cfg (dict): Config of set abstraction module, which may contain
            the following keys and values:

            - pool_mod (str): Pool method ('max' or 'avg') for SA modules.
            - use_xyz (bool): Whether to use xyz as a part of features.
            - normalize_xyz (bool): Whether to normalize xyz with radii in
              each SA module.
        init_cfg (dict, optional): Forwarded unchanged to the base class
            constructor.

    Note:
        The default ``num_points`` has four entries whereas the defaults of
        ``radii``, ``num_samples`` and ``sa_channels`` have three, so the
        length assertion below fails unless ``num_points`` (or the other
        arguments) are passed with matching lengths.
        NOTE(review): confirm which default is the intended one.
    """

    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
                 num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
                 sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
                              ((64, 64, 128), (64, 64, 128), (64, 96, 128)),
                              ((128, 128, 256), (128, 192, 256),
                               (128, 256, 256))),
                 aggregation_channels=(64, 128, 256),
                 fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
                 fps_sample_range_lists=((-1), (-1), (512, -1)),
                 dilated_group=(True, True, True),
                 out_indices=(2, ),
                 norm_cfg=dict(type='BN2d'),
                 sa_cfg=dict(
                     type='PointSAModuleMSG',
                     pool_mod='max',
                     use_xyz=True,
                     normalize_xyz=False),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_sa = len(sa_channels)
        self.out_indices = out_indices
        assert max(out_indices) < self.num_sa
        assert len(num_points) == len(radii) == len(num_samples) == len(
            sa_channels)
        if aggregation_channels is not None:
            assert len(sa_channels) == len(aggregation_channels)
        else:
            # No aggregation layer for any stage.
            aggregation_channels = [None] * len(sa_channels)

        self.SA_modules = nn.ModuleList()
        self.aggregation_mlps = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz

        for sa_index in range(self.num_sa):
            # Prepend the current input width to every per-scale MLP spec
            # and sum the last widths: that sum is the channel count of the
            # multi-scale output of this stage.
            cur_sa_mlps = list(sa_channels[sa_index])
            sa_out_channel = 0
            for radius_index in range(len(radii[sa_index])):
                cur_sa_mlps[radius_index] = [sa_in_channel] + list(
                    cur_sa_mlps[radius_index])
                sa_out_channel += cur_sa_mlps[radius_index][-1]

            # ``('D-FPS')`` in the defaults is a plain string, not a tuple,
            # so scalars are accepted and wrapped into a list here.
            stage_fps_mod = fps_mods[sa_index]
            if isinstance(stage_fps_mod, tuple):
                cur_fps_mod = list(stage_fps_mod)
            else:
                cur_fps_mod = [stage_fps_mod]

            stage_fps_range = fps_sample_range_lists[sa_index]
            if isinstance(stage_fps_range, tuple):
                cur_fps_sample_range_list = list(stage_fps_range)
            else:
                cur_fps_sample_range_list = [stage_fps_range]

            self.SA_modules.append(
                build_sa_module(
                    num_point=num_points[sa_index],
                    radii=radii[sa_index],
                    sample_nums=num_samples[sa_index],
                    mlp_channels=cur_sa_mlps,
                    fps_mod=cur_fps_mod,
                    fps_sample_range_list=cur_fps_sample_range_list,
                    dilated_group=dilated_group[sa_index],
                    norm_cfg=norm_cfg,
                    cfg=sa_cfg,
                    bias=True))

            cur_aggregation_channel = aggregation_channels[sa_index]
            if cur_aggregation_channel is None:
                # Keep a placeholder so that ``aggregation_mlps`` stays
                # index-aligned with ``SA_modules``.
                self.aggregation_mlps.append(None)
                sa_in_channel = sa_out_channel
            else:
                self.aggregation_mlps.append(
                    ConvModule(
                        sa_out_channel,
                        cur_aggregation_channel,
                        conv_cfg=dict(type='Conv1d'),
                        norm_cfg=dict(type='BN1d'),
                        kernel_size=1,
                        bias=True))
                sa_in_channel = cur_aggregation_channel

    @auto_fp16(apply_to=('points', ))
    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, list[torch.Tensor]]: Each list starts with the entry
                for the input points, followed by one entry per SA stage
                listed in ``out_indices`` (in stage order).

                - sa_xyz (list[torch.Tensor]): The coordinates of sa
                  features.
                - sa_features (list[torch.Tensor]): The features produced
                  by the SA modules (after the aggregation layer when the
                  stage has one).
                - sa_indices (list[torch.Tensor]): Indices of the selected
                  points, expressed relative to the input points.
        """
        xyz, features = self._split_point_feats(points)

        batch, num_points = xyz.shape[:2]
        # Build the indices directly as int64. Going through
        # ``xyz.new_tensor(range(...))`` would first create them in xyz's
        # floating dtype, which cannot represent large indices exactly
        # (e.g. integers above 2048 in float16).
        indices = torch.arange(
            num_points, dtype=torch.long,
            device=xyz.device).unsqueeze(0).repeat(batch, 1)

        sa_xyz = [xyz]
        sa_features = [features]
        sa_indices = [indices]

        out_sa_xyz = [xyz]
        out_sa_features = [features]
        out_sa_indices = [indices]

        for i in range(self.num_sa):
            cur_xyz, cur_features, cur_indices = self.SA_modules[i](
                sa_xyz[i], sa_features[i])
            if self.aggregation_mlps[i] is not None:
                cur_features = self.aggregation_mlps[i](cur_features)
            sa_xyz.append(cur_xyz)
            sa_features.append(cur_features)
            # ``cur_indices`` index into the previous stage's points; map
            # them back to indices of the original input points.
            sa_indices.append(
                torch.gather(sa_indices[-1], 1, cur_indices.long()))
            if i in self.out_indices:
                out_sa_xyz.append(sa_xyz[-1])
                out_sa_features.append(sa_features[-1])
                out_sa_indices.append(sa_indices[-1])

        return dict(
            sa_xyz=out_sa_xyz,
            sa_features=out_sa_features,
            sa_indices=out_sa_indices)
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.