• Docs >
  • Module code >
  • mmdet3d.models.roi_heads.roi_extractors.single_roiaware_extractor
Shortcuts

Source code for mmdet3d.models.roi_heads.roi_extractors.single_roiaware_extractor

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv import ops
from mmcv.runner import BaseModule

from mmdet3d.models.builder import ROI_EXTRACTORS


@ROI_EXTRACTORS.register_module()
class Single3DRoIAwareExtractor(BaseModule):
    """Point-wise RoI-aware extractor.

    Extracts point-wise RoI features by running an RoI layer, built from
    ``mmcv.ops``, separately on the RoIs and points of every batch sample.

    Args:
        roi_layer (dict): Config of the RoI layer. The ``type`` key names a
            class in ``mmcv.ops``; every other key is passed to that class's
            constructor as a keyword argument.
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Defaults to None.
    """

    def __init__(self, roi_layer=None, init_cfg=None):
        super(Single3DRoIAwareExtractor, self).__init__(init_cfg=init_cfg)
        self.roi_layer = self.build_roi_layers(roi_layer)

    def build_roi_layers(self, layer_cfg):
        """Build the RoI layer described by ``layer_cfg``.

        Args:
            layer_cfg (dict): Layer config. ``type`` selects the class in
                ``mmcv.ops``; the remaining keys are constructor kwargs.

        Returns:
            The instantiated RoI layer object.

        Raises:
            TypeError: If ``layer_cfg`` is None (e.g. ``roi_layer`` was left
                at its default).
        """
        if layer_cfg is None:
            raise TypeError(
                'roi_layer config must be a dict with a "type" key, got None')
        # Work on a copy so popping 'type' does not mutate the caller's dict.
        cfg = layer_cfg.copy()
        layer_type = cfg.pop('type')
        assert hasattr(ops, layer_type), \
            f'{layer_type} is not available in mmcv.ops'
        layer_cls = getattr(ops, layer_type)
        return layer_cls(**cfg)

    def forward(self, feats, coordinate, batch_inds, rois):
        """Extract point-wise RoI features.

        Args:
            feats (torch.FloatTensor): Point-wise features for pooling. The
                first dimension is indexed with the same per-point mask as
                ``coordinate``, so it must have one entry per point (i.e.
                match ``batch_inds``).
            coordinate (torch.FloatTensor): Coordinate of each point.
            batch_inds (torch.LongTensor): Batch index of each point.
            rois (torch.FloatTensor): RoI boxes with batch indices;
                ``rois[..., 0]`` is the batch index and ``rois[..., 1:]``
                is passed to the RoI layer as the box parameters.

        Returns:
            torch.FloatTensor: Pooled features of every batch sample,
                concatenated along dim 0 in ascending batch order.
        """
        # Casts and slices that do not depend on the batch index are
        # computed once, outside the per-batch loop.
        roi_batch_ids = rois[..., 0].int()
        roi_boxes = rois[..., 1:]
        point_batch_ids = batch_inds.int()

        pooled_roi_feats = []
        # NOTE(review): the batch range comes from the point indices only, so
        # RoIs whose batch index exceeds batch_inds.max() are not pooled.
        for batch_idx in range(int(batch_inds.max()) + 1):
            roi_mask = roi_batch_ids == batch_idx
            point_mask = point_batch_ids == batch_idx
            pooled_roi_feats.append(
                self.roi_layer(roi_boxes[roi_mask], coordinate[point_mask],
                               feats[point_mask]))
        return torch.cat(pooled_roi_feats, 0)
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.