Shortcuts

Source code for mmdet3d.models.necks.pointnet2_fp_neck

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule
from torch import nn as nn

from mmdet3d.ops import PointFPModule
from ..builder import NECKS


[docs]@NECKS.register_module() class PointNetFPNeck(BaseModule): r"""PointNet FP Module used in PointRCNN. Refer to the `official code <https://github.com/charlesq34/pointnet2>`_. .. code-block:: none sa_n ---------------------------------------- | ... --------------------------------- | | | sa_1 ------------- | | | | | sa_0 -> fp_0 -> fp_module ->fp_1 -> ... -> fp_module -> fp_n sa_n including sa_xyz (torch.Tensor) and sa_features (torch.Tensor) fp_n including fp_xyz (torch.Tensor) and fp_features (torch.Tensor) Args: fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None """ def __init__(self, fp_channels, init_cfg=None): super(PointNetFPNeck, self).__init__(init_cfg=init_cfg) self.num_fp = len(fp_channels) self.FP_modules = nn.ModuleList() for cur_fp_mlps in fp_channels: self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps)) def _extract_input(self, feat_dict): """Extract inputs from features dictionary. Args: feat_dict (dict): Feature dict from backbone, which may contain the following keys and values: - sa_xyz (list[torch.Tensor]): Points of each sa module in shape (N, 3). - sa_features (list[torch.Tensor]): Output features of each sa module in shape (N, M). Returns: list[torch.Tensor]: Coordinates of multiple levels of points. list[torch.Tensor]: Features of multiple levels of points. """ sa_xyz = feat_dict['sa_xyz'] sa_features = feat_dict['sa_features'] assert len(sa_xyz) == len(sa_features) return sa_xyz, sa_features
[docs] def forward(self, feat_dict): """Forward pass. Args: feat_dict (dict): Feature dict from backbone. Returns: dict[str, torch.Tensor]: Outputs of the Neck. - fp_xyz (torch.Tensor): The coordinates of fp features. - fp_features (torch.Tensor): The features from the last feature propagation layers. """ sa_xyz, sa_features = self._extract_input(feat_dict) fp_feature = sa_features[-1] fp_xyz = sa_xyz[-1] for i in range(self.num_fp): # consume the points in a bottom-up manner fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)], sa_features[-(i + 2)], fp_feature) fp_xyz = sa_xyz[-(i + 2)] ret = dict(fp_xyz=fp_xyz, fp_features=fp_feature) return ret
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.