Source code for mmdet3d.models.necks.pointnet2_fp_neck
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule
from torch import nn as nn
from mmdet3d.ops import PointFPModule
from ..builder import NECKS
[docs]@NECKS.register_module()
class PointNetFPNeck(BaseModule):
r"""PointNet FP Module used in PointRCNN.
Refer to the `official code <https://github.com/charlesq34/pointnet2>`_.
.. code-block:: none
sa_n ----------------------------------------
|
... --------------------------------- |
| |
sa_1 ------------- | |
| | |
sa_0 -> fp_0 -> fp_module ->fp_1 -> ... -> fp_module -> fp_n
sa_n including sa_xyz (torch.Tensor) and sa_features (torch.Tensor)
fp_n including fp_xyz (torch.Tensor) and fp_features (torch.Tensor)
Args:
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self, fp_channels, init_cfg=None):
super(PointNetFPNeck, self).__init__(init_cfg=init_cfg)
self.num_fp = len(fp_channels)
self.FP_modules = nn.ModuleList()
for cur_fp_mlps in fp_channels:
self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))
def _extract_input(self, feat_dict):
"""Extract inputs from features dictionary.
Args:
feat_dict (dict): Feature dict from backbone, which may contain
the following keys and values:
- sa_xyz (list[torch.Tensor]): Points of each sa module
in shape (N, 3).
- sa_features (list[torch.Tensor]): Output features of
each sa module in shape (N, M).
Returns:
list[torch.Tensor]: Coordinates of multiple levels of points.
list[torch.Tensor]: Features of multiple levels of points.
"""
sa_xyz = feat_dict['sa_xyz']
sa_features = feat_dict['sa_features']
assert len(sa_xyz) == len(sa_features)
return sa_xyz, sa_features
[docs] def forward(self, feat_dict):
"""Forward pass.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
dict[str, torch.Tensor]: Outputs of the Neck.
- fp_xyz (torch.Tensor): The coordinates of fp features.
- fp_features (torch.Tensor): The features from the last
feature propagation layers.
"""
sa_xyz, sa_features = self._extract_input(feat_dict)
fp_feature = sa_features[-1]
fp_xyz = sa_xyz[-1]
for i in range(self.num_fp):
# consume the points in a bottom-up manner
fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)],
sa_features[-(i + 2)], fp_feature)
fp_xyz = sa_xyz[-(i + 2)]
ret = dict(fp_xyz=fp_xyz, fp_features=fp_feature)
return ret