Source code for mmdet3d.models.roi_heads.roi_extractors.single_roiaware_extractor
import torch
from mmcv.runner import BaseModule
from mmdet3d import ops
from mmdet.models.builder import ROI_EXTRACTORS
@ROI_EXTRACTORS.register_module()
class Single3DRoIAwareExtractor(BaseModule):
    """RoI-aware extractor for point-wise features.

    Wraps a single RoI layer looked up by name in ``mmdet3d.ops`` and applies
    it sample by sample to pool point features inside each RoI.

    Args:
        roi_layer (dict): Config of the RoI layer. Its ``'type'`` entry names
            an attribute of ``mmdet3d.ops``; every other entry is forwarded to
            that attribute as a keyword argument. Although the default is
            ``None``, a dict is required: the config is copied on
            construction.
        init_cfg (dict, optional): Forwarded unchanged to ``BaseModule``.
    """

    def __init__(self, roi_layer=None, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.roi_layer = self.build_roi_layers(roi_layer)

    def build_roi_layers(self, layer_cfg):
        """Instantiate the RoI layer described by ``layer_cfg``.

        Args:
            layer_cfg (dict): Layer config holding a ``'type'`` key plus the
                keyword arguments of that layer.

        Returns:
            The object produced by calling ``getattr(ops, type)`` with the
            remaining config entries.
        """
        # Work on a copy so that popping 'type' leaves the caller's dict
        # untouched.
        layer_kwargs = layer_cfg.copy()
        op_name = layer_kwargs.pop('type')
        assert hasattr(ops, op_name)
        return getattr(ops, op_name)(**layer_kwargs)

    def forward(self, feats, coordinate, batch_inds, rois):
        """Pool point-wise features inside each RoI, one sample at a time.

        Args:
            feats (torch.FloatTensor): Point-wise features to pool. They are
                indexed with the same per-point mask as ``coordinate``, so
                the leading dimension has to line up with ``batch_inds``
                (the original docs give the shape as
                (batch, npoints, channels) -- TODO confirm against callers).
            coordinate (torch.FloatTensor): Coordinate of each point.
            batch_inds (torch.LongTensor): Sample index of each point.
            rois (torch.FloatTensor): RoI boxes whose first entry along the
                last dimension is the sample index; the remaining entries
                are handed to the RoI layer.

        Returns:
            torch.FloatTensor: Per-sample pooled features concatenated along
            dim 0, in increasing sample-index order.
        """
        # The number of samples is derived from the largest point batch index.
        num_samples = int(batch_inds.max()) + 1

        per_sample_feats = []
        for sample_id in range(num_samples):
            roi_mask = rois[..., 0].int() == sample_id
            point_mask = batch_inds.int() == sample_id

            # Drop the leading batch-index column before pooling.
            sample_rois = rois[..., 1:][roi_mask]
            sample_coords = coordinate[point_mask]
            sample_feats = feats[point_mask]

            per_sample_feats.append(
                self.roi_layer(sample_rois, sample_coords, sample_feats))

        return torch.cat(per_sample_feats, 0)