Shortcuts

Source code for mmdet3d.models.backbones.pointnet2_sa_ssg

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import PointFPModule, build_sa_module
from ..builder import BACKBONES
from .base_pointnet import BasePointNet


@BACKBONES.register_module()
class PointNet2SASSG(BasePointNet):
    """PointNet2 with Single-scale grouping.

    The backbone stacks set abstraction (SA) modules followed by feature
    propagation (FP) modules. Each FP module consumes the output of the
    previous stage together with the skip features of the matching SA level.

    Args:
        in_channels (int): Input channels of point cloud.
        num_points (tuple[int]): The number of points which each SA
            module samples.
        radius (tuple[float]): Sampling radii of each SA module.
        num_samples (tuple[int]): The number of samples for ball
            query in each SA module.
        sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
        fp_channels (tuple[tuple[int]]): Out channels of each mlp in FP module.
        norm_cfg (dict): Config of normalization layer.
        sa_cfg (dict): Config of set abstraction module, which may contain
            the following keys and values:

            - pool_mod (str): Pool method ('max' or 'avg') for SA modules.
            - use_xyz (bool): Whether to use xyz as a part of features.
            - normalize_xyz (bool): Whether to normalize xyz with radii in
              each SA module.
        init_cfg (dict, optional): Initialization config passed through to
            the base class constructor. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radius=(0.2, 0.4, 0.8, 1.2),
                 num_samples=(64, 32, 16, 16),
                 sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
                              (128, 128, 256)),
                 fp_channels=((256, 256), (256, 256)),
                 norm_cfg=dict(type='BN2d'),
                 sa_cfg=dict(
                     type='PointSAModule',
                     pool_mod='max',
                     use_xyz=True,
                     normalize_xyz=True),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_sa = len(sa_channels)
        self.num_fp = len(fp_channels)

        assert len(num_points) == len(radius) == len(num_samples) == len(
            sa_channels), (
                'num_points, radius, num_samples and sa_channels must all '
                'have the same length')
        # Every FP module needs one SA level to take skip features from.
        assert len(sa_channels) >= len(fp_channels), (
            'fp_channels must not be longer than sa_channels')

        self.SA_modules = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz
        # Feature width at every level (input first); consumed from the back
        # when wiring up the FP modules below.
        skip_channel_list = [sa_in_channel]

        for num_point, cur_radius, num_sample, channels in zip(
                num_points, radius, num_samples, sa_channels):
            cur_sa_mlps = [sa_in_channel, *channels]
            sa_out_channel = cur_sa_mlps[-1]

            self.SA_modules.append(
                build_sa_module(
                    num_point=num_point,
                    radius=cur_radius,
                    num_sample=num_sample,
                    mlp_channels=cur_sa_mlps,
                    norm_cfg=norm_cfg,
                    cfg=sa_cfg))
            skip_channel_list.append(sa_out_channel)
            sa_in_channel = sa_out_channel

        self.FP_modules = nn.ModuleList()

        # The first FP module fuses the deepest SA output (source) with the
        # SA level right before it (target).
        fp_source_channel = skip_channel_list.pop()
        fp_target_channel = skip_channel_list.pop()
        for fp_index, channels in enumerate(fp_channels):
            cur_fp_mlps = [fp_source_channel + fp_target_channel, *channels]
            self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))
            # Only pop another skip width when a further FP module follows,
            # so the list is never popped past its last element.
            if fp_index != self.num_fp - 1:
                fp_source_channel = cur_fp_mlps[-1]
                fp_target_channel = skip_channel_list.pop()

    @auto_fp16(apply_to=('points', ))
    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs after SA and FP modules.

                - fp_xyz (list[torch.Tensor]): The coordinates of
                  each fp features.
                - fp_features (list[torch.Tensor]): The features
                  from each Feature Propagate Layers.
                - fp_indices (list[torch.Tensor]): Indices of the
                  input points.
                - sa_xyz (list[torch.Tensor]): The coordinates at each
                  SA level, starting with the input coordinates.
                - sa_features (list[torch.Tensor]): The features at each
                  SA level, starting with the input features.
                - sa_indices (list[torch.Tensor]): Indices of the input
                  points kept at each SA level, starting with all points.
        """
        xyz, features = self._split_point_feats(points)

        batch, num_points = xyz.shape[:2]
        # Build the identity indices directly as int64 on the input device.
        # Going through the dtype of ``xyz`` first would round large indices
        # when the input is half precision.
        indices = torch.arange(
            num_points, dtype=torch.long,
            device=xyz.device).unsqueeze(0).repeat(batch, 1)

        sa_xyz = [xyz]
        sa_features = [features]
        sa_indices = [indices]

        for sa_module in self.SA_modules:
            cur_xyz, cur_features, cur_indices = sa_module(
                sa_xyz[-1], sa_features[-1])
            sa_xyz.append(cur_xyz)
            sa_features.append(cur_features)
            # cur_indices index into the previous level; gather maps them
            # back to indices of the original input points.
            sa_indices.append(
                torch.gather(sa_indices[-1], 1, cur_indices.long()))

        fp_xyz = [sa_xyz[-1]]
        fp_features = [sa_features[-1]]
        fp_indices = [sa_indices[-1]]

        for i, fp_module in enumerate(self.FP_modules):
            # SA level whose points receive the propagated features.
            target = self.num_sa - i - 1
            fp_features.append(
                fp_module(sa_xyz[target], sa_xyz[target + 1],
                          sa_features[target], fp_features[-1]))
            fp_xyz.append(sa_xyz[target])
            fp_indices.append(sa_indices[target])

        ret = dict(
            fp_xyz=fp_xyz,
            fp_features=fp_features,
            fp_indices=fp_indices,
            sa_xyz=sa_xyz,
            sa_features=sa_features,
            sa_indices=sa_indices)
        return ret
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.