
Source code for mmdet3d.models.detectors.mvx_two_stage

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from os import path as osp

import mmcv
import torch
from mmcv.ops import Voxelization
from mmcv.parallel import DataContainer as DC
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
                          merge_aug_bboxes_3d, show_result)
from mmdet.core import multi_apply
from .. import builder
from ..builder import DETECTORS
from .base import Base3DDetector

@DETECTORS.register_module()
class MVXTwoStageDetector(Base3DDetector):
    """Base class of Multi-modality VoxelNet.

    The detector is assembled from optional point-cloud and image
    components. Each component is only created (and only becomes an
    attribute of the instance) when its config is supplied; the
    ``with_*`` properties report which components are present.

    Args:
        pts_voxel_layer (dict, optional): Keyword arguments for
            ``Voxelization``. Defaults to None.
        pts_voxel_encoder (dict, optional): Config passed to
            ``builder.build_voxel_encoder``. Defaults to None.
        pts_middle_encoder (dict, optional): Config passed to
            ``builder.build_middle_encoder``. Defaults to None.
        pts_fusion_layer (dict, optional): Config passed to
            ``builder.build_fusion_layer``. Defaults to None.
        img_backbone (dict, optional): Config passed to
            ``builder.build_backbone``. Defaults to None.
        pts_backbone (dict, optional): Config passed to
            ``builder.build_backbone``. Defaults to None.
        img_neck (dict, optional): Config passed to
            ``builder.build_neck``. Defaults to None.
        pts_neck (dict, optional): Config passed to
            ``builder.build_neck``. Defaults to None.
        pts_bbox_head (dict, optional): Config passed to
            ``builder.build_head``. It is updated in place with
            ``train_cfg.pts`` and ``test_cfg.pts``. Defaults to None.
        img_roi_head (dict, optional): Config passed to
            ``builder.build_head``. Defaults to None.
        img_rpn_head (dict, optional): Config passed to
            ``builder.build_head``. Defaults to None.
        train_cfg (dict, optional): Training config; ``.pts`` and
            ``.img_rpn`` are read from it. Defaults to None.
        test_cfg (dict, optional): Testing config; ``.pts`` and
            ``.img_rpn`` are read from it. Defaults to None.
        pretrained (dict, optional): Deprecated. Checkpoints under the
            optional keys ``'img'`` and ``'pts'``. Defaults to None.
        init_cfg (dict, optional): Forwarded to the parent constructor.
            Defaults to None.

    Raises:
        ValueError: If ``pretrained`` is neither None nor a dict.
    """

    def __init__(self,
                 pts_voxel_layer=None,
                 pts_voxel_encoder=None,
                 pts_middle_encoder=None,
                 pts_fusion_layer=None,
                 img_backbone=None,
                 pts_backbone=None,
                 img_neck=None,
                 pts_neck=None,
                 pts_bbox_head=None,
                 img_roi_head=None,
                 img_rpn_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Point-cloud branch. The construction order is kept as in the
        # original so that sub-module registration order is unchanged.
        if pts_voxel_layer:
            self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
        if pts_voxel_encoder:
            self.pts_voxel_encoder = builder.build_voxel_encoder(
                pts_voxel_encoder)
        if pts_middle_encoder:
            self.pts_middle_encoder = builder.build_middle_encoder(
                pts_middle_encoder)
        if pts_backbone:
            self.pts_backbone = builder.build_backbone(pts_backbone)
        if pts_fusion_layer:
            self.pts_fusion_layer = builder.build_fusion_layer(
                pts_fusion_layer)
        if pts_neck is not None:
            self.pts_neck = builder.build_neck(pts_neck)
        if pts_bbox_head:
            # The head receives the point-branch part of the train/test
            # configs (or None when the config is absent).
            pts_train_cfg = train_cfg.pts if train_cfg else None
            pts_bbox_head.update(train_cfg=pts_train_cfg)
            pts_test_cfg = test_cfg.pts if test_cfg else None
            pts_bbox_head.update(test_cfg=pts_test_cfg)
            self.pts_bbox_head = builder.build_head(pts_bbox_head)

        # Image branch.
        if img_backbone:
            self.img_backbone = builder.build_backbone(img_backbone)
        if img_neck is not None:
            self.img_neck = builder.build_neck(img_neck)
        if img_rpn_head is not None:
            self.img_rpn_head = builder.build_head(img_rpn_head)
        if img_roi_head is not None:
            self.img_roi_head = builder.build_head(img_roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # Translate the deprecated ``pretrained`` argument into init_cfg
        # entries on the corresponding sub-modules.
        if pretrained is None:
            img_pretrained = None
            pts_pretrained = None
        elif isinstance(pretrained, dict):
            img_pretrained = pretrained.get('img', None)
            pts_pretrained = pretrained.get('pts', None)
        else:
            raise ValueError(
                f'pretrained should be a dict, got {type(pretrained)}')

        if self.with_img_backbone and img_pretrained is not None:
            warnings.warn('DeprecationWarning: pretrained is a deprecated '
                          'key, please consider using init_cfg.')
            self.img_backbone.init_cfg = dict(
                type='Pretrained', checkpoint=img_pretrained)
        if self.with_img_roi_head and img_pretrained is not None:
            warnings.warn('DeprecationWarning: pretrained is a deprecated '
                          'key, please consider using init_cfg.')
            self.img_roi_head.init_cfg = dict(
                type='Pretrained', checkpoint=img_pretrained)
        if self.with_pts_backbone and pts_pretrained is not None:
            warnings.warn('DeprecationWarning: pretrained is a deprecated '
                          'key, please consider using init_cfg')
            self.pts_backbone.init_cfg = dict(
                type='Pretrained', checkpoint=pts_pretrained)

    def _has_component(self, name):
        """Return True if attribute ``name`` exists and is not None."""
        return getattr(self, name, None) is not None

    @property
    def with_img_shared_head(self):
        """bool: Whether the detector has a shared head in image branch."""
        return self._has_component('img_shared_head')

    @property
    def with_pts_bbox(self):
        """bool: Whether the detector has a 3D box head."""
        return self._has_component('pts_bbox_head')

    @property
    def with_img_bbox(self):
        """bool: Whether the detector has a 2D image box head."""
        return self._has_component('img_bbox_head')

    @property
    def with_img_backbone(self):
        """bool: Whether the detector has a 2D image backbone."""
        return self._has_component('img_backbone')

    @property
    def with_pts_backbone(self):
        """bool: Whether the detector has a 3D backbone."""
        return self._has_component('pts_backbone')

    @property
    def with_fusion(self):
        """bool: Whether the detector has a fusion layer."""
        # The layer is stored as ``pts_fusion_layer``; the previous code
        # read ``self.fusion_layer`` and raised AttributeError.
        return self._has_component('pts_fusion_layer')

    @property
    def with_img_neck(self):
        """bool: Whether the detector has a neck in image branch."""
        return self._has_component('img_neck')

    @property
    def with_pts_neck(self):
        """bool: Whether the detector has a neck in 3D detector branch."""
        return self._has_component('pts_neck')

    @property
    def with_img_rpn(self):
        """bool: Whether the detector has a 2D RPN in image detector
        branch."""
        return self._has_component('img_rpn_head')

    @property
    def with_img_roi_head(self):
        """bool: Whether the detector has a RoI Head in image branch."""
        return self._has_component('img_roi_head')

    @property
    def with_voxel_encoder(self):
        """bool: Whether the detector has a voxel encoder."""
        # ``__init__`` stores the encoder as ``pts_voxel_encoder``; the
        # un-prefixed name is still honoured for backward compatibility.
        return (self._has_component('pts_voxel_encoder')
                or self._has_component('voxel_encoder'))

    @property
    def with_middle_encoder(self):
        """bool: Whether the detector has a middle encoder."""
        # ``__init__`` stores the encoder as ``pts_middle_encoder``; the
        # un-prefixed name is still honoured for backward compatibility.
        return (self._has_component('pts_middle_encoder')
                or self._has_component('middle_encoder'))

    def extract_img_feat(self, img, img_metas):
        """Extract features of images.

        Args:
            img (torch.Tensor | None): Input images, either 4-D or 5-D.
                A 5-D input is unpacked as (B, N, C, H, W) and flattened
                to (B * N, C, H, W) before the backbone.
            img_metas (list[dict]): Meta information of each sample. Every
                dict is updated in place with ``input_shape``, the last
                two dimensions of ``img``.

        Returns:
            Output of the image backbone (passed through the image neck
            when one is configured), or None when there is no image
            backbone or ``img`` is None.
        """
        if not self.with_img_backbone or img is None:
            return None

        input_shape = img.shape[-2:]
        # update real input shape of each single img
        for img_meta in img_metas:
            img_meta.update(input_shape=input_shape)

        if img.dim() == 5:
            # Merge the first two axes explicitly. The previous in-place
            # ``squeeze_()`` also dropped any other singleton dimension
            # and modified the caller's tensor.
            B, N, C, H, W = img.size()
            img = img.reshape(B * N, C, H, W)

        img_feats = self.img_backbone(img)
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)
        return img_feats

    def extract_pts_feat(self, pts, img_feats, img_metas):
        """Extract features of points.

        Args:
            pts (list[torch.Tensor]): Points of each sample.
            img_feats: Image features forwarded to the voxel encoder
                (may be None).
            img_metas (list[dict]): Meta information of each sample.

        Returns:
            Output of the point backbone (passed through the point neck
            when one is configured), or None when the detector has no
            3D box head.
        """
        if not self.with_pts_bbox:
            return None
        voxels, num_points, coors = self.voxelize(pts)
        voxel_features = self.pts_voxel_encoder(voxels, num_points, coors,
                                                img_feats, img_metas)
        # Column 0 of ``coors`` is the sample index added in ``voxelize``;
        # the last row belongs to the last sample of the batch.
        batch_size = coors[-1, 0] + 1
        x = self.pts_middle_encoder(voxel_features, coors, batch_size)
        x = self.pts_backbone(x)
        if self.with_pts_neck:
            x = self.pts_neck(x)
        return x

    def extract_feat(self, points, img, img_metas):
        """Extract features from images and points.

        Returns:
            tuple: ``(img_feats, pts_feats)`` as returned by
            :meth:`extract_img_feat` and :meth:`extract_pts_feat`.
        """
        img_feats = self.extract_img_feat(img, img_metas)
        pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
        return (img_feats, pts_feats)

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points):
        """Apply voxelization to points.

        Each sample is voxelized separately by ``self.pts_voxel_layer``
        and the per-sample results are concatenated along dim 0.

        Args:
            points (list[torch.Tensor]): Points of each sample.

        Returns:
            tuple[torch.Tensor]: Concatenated voxels, number of points
            per voxel, and voxel coordinates. The coordinates are
            left-padded with one extra column holding the index of the
            sample the voxel came from.
        """
        voxels, coors, num_points = [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.pts_voxel_layer(res)
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        # Prepend the sample index so that voxels of different samples
        # stay distinguishable after concatenation.
        coors_batch = [
            F.pad(coor, (1, 0), mode='constant', value=i)
            for i, coor in enumerate(coors)
        ]
        coors_batch = torch.cat(coors_batch, dim=0)
        return voxels, num_points, coors_batch

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img=None,
                      proposals=None,
                      gt_bboxes_ignore=None):
        """Forward training function.

        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each
                sample. Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth
                labels of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D
                boxes in images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.

        Returns:
            dict: Losses of different branches.
        """
        img_feats, pts_feats = self.extract_feat(
            points, img=img, img_metas=img_metas)
        losses = dict()
        if pts_feats:
            losses_pts = self.forward_pts_train(pts_feats, gt_bboxes_3d,
                                                gt_labels_3d, img_metas,
                                                gt_bboxes_ignore)
            losses.update(losses_pts)
        if img_feats:
            losses_img = self.forward_img_train(
                img_feats,
                img_metas=img_metas,
                gt_bboxes=gt_bboxes,
                gt_labels=gt_labels,
                gt_bboxes_ignore=gt_bboxes_ignore,
                proposals=proposals)
            losses.update(losses_img)
        return losses

    def forward_pts_train(self,
                          pts_feats,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          img_metas,
                          gt_bboxes_ignore=None):
        """Forward function for point cloud branch.

        Args:
            pts_feats (list[torch.Tensor]): Features of point cloud branch
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
            img_metas (list[dict]): Meta information of samples.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        outs = self.pts_bbox_head(pts_feats)
        loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
        losses = self.pts_bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def forward_img_train(self,
                          x,
                          img_metas,
                          gt_bboxes,
                          gt_labels,
                          gt_bboxes_ignore=None,
                          proposals=None,
                          **kwargs):
        """Forward function for image branch.

        This function works similar to the forward function of Faster R-CNN.

        Args:
            x (list[torch.Tensor]): Image features of shape (B, C, H, W)
                of multiple levels.
            img_metas (list[dict]): Meta information of images.
            gt_bboxes (list[torch.Tensor]): Ground truth boxes of each image
                sample.
            gt_labels (list[torch.Tensor]): Ground truth labels of boxes.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.
            proposals (list[torch.Tensor], optional): Proposals of each
                sample. Only used when the detector has no image RPN.
                Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        losses = dict()
        # RPN forward and loss
        if self.with_img_rpn:
            rpn_outs = self.img_rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas,
                                          self.train_cfg.img_rpn)
            rpn_losses = self.img_rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)

            # Fall back to the test-time RPN config when no dedicated
            # training proposal config is given.
            proposal_cfg = self.train_cfg.get('img_rpn_proposal',
                                              self.test_cfg.img_rpn)
            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
            proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals

        # bbox head forward and loss
        # NOTE(review): the guard probes ``img_bbox_head`` while the call
        # below uses ``img_roi_head``; ``__init__`` never sets
        # ``img_bbox_head``. Confirm whether ``with_img_roi_head`` was
        # intended before changing this.
        if self.with_img_bbox:
            img_roi_losses = self.img_roi_head.forward_train(
                x, img_metas, proposal_list, gt_bboxes, gt_labels,
                gt_bboxes_ignore, **kwargs)
            losses.update(img_roi_losses)

        return losses

    def simple_test_img(self, x, img_metas, proposals=None, rescale=False):
        """Test the image branch without augmentation.

        Proposals are generated by the image RPN with
        ``self.test_cfg.img_rpn`` unless ``proposals`` is given.
        """
        if proposals is None:
            proposal_list = self.simple_test_rpn(x, img_metas,
                                                 self.test_cfg.img_rpn)
        else:
            proposal_list = proposals

        return self.img_roi_head.simple_test(
            x, proposal_list, img_metas, rescale=rescale)

    def simple_test_rpn(self, x, img_metas, rpn_test_cfg):
        """RPN test function.

        Runs the image RPN head on ``x`` and returns the proposals
        produced by its ``get_bboxes``.
        """
        rpn_outs = self.img_rpn_head(x)
        proposal_inputs = rpn_outs + (img_metas, rpn_test_cfg)
        proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
        return proposal_list

    def simple_test_pts(self, x, img_metas, rescale=False):
        """Test function of point cloud branch.

        Returns:
            list: One ``bbox3d2result(bboxes, scores, labels)`` entry per
            item returned by the head's ``get_bboxes``.
        """
        outs = self.pts_bbox_head(x)
        bbox_list = self.pts_bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results

    def simple_test(self, points, img_metas, img=None, rescale=False):
        """Test function without augmentation.

        Returns:
            list[dict]: One dict per entry of ``img_metas``. The key
            ``'pts_bbox'`` and/or ``'img_bbox'`` is filled in when the
            corresponding branch produced features and has a box head.
        """
        img_feats, pts_feats = self.extract_feat(
            points, img=img, img_metas=img_metas)

        bbox_list = [dict() for _ in range(len(img_metas))]
        if pts_feats and self.with_pts_bbox:
            bbox_pts = self.simple_test_pts(
                pts_feats, img_metas, rescale=rescale)
            for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
                result_dict['pts_bbox'] = pts_bbox
        if img_feats and self.with_img_bbox:
            bbox_img = self.simple_test_img(
                img_feats, img_metas, rescale=rescale)
            for result_dict, img_bbox in zip(bbox_list, bbox_img):
                result_dict['img_bbox'] = img_bbox
        return bbox_list

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function with augmentation.

        Returns:
            list[dict]: A single-element list whose dict holds the merged
            point-branch result under ``'pts_bbox'`` (when available).
        """
        img_feats, pts_feats = self.extract_feats(points, img_metas, imgs)

        bbox_list = dict()
        if pts_feats and self.with_pts_bbox:
            bbox_pts = self.aug_test_pts(pts_feats, img_metas, rescale)
            bbox_list.update(pts_bbox=bbox_pts)
        return [bbox_list]

    def extract_feats(self, points, img_metas, imgs=None):
        """Extract point and image features of multiple samples.

        Applies :meth:`extract_feat` to each (points, img, img_metas)
        triple; when ``imgs`` is None every sample gets ``img=None``.
        """
        if imgs is None:
            imgs = [None] * len(img_metas)
        img_feats, pts_feats = multi_apply(self.extract_feat, points, imgs,
                                           img_metas)
        return img_feats, pts_feats

    def aug_test_pts(self, feats, img_metas, rescale=False):
        """Test function of point cloud branch with augmentation.

        Runs the 3D box head on every augmented view, keeps the first
        result of each view and merges them with ``merge_aug_bboxes_3d``.
        """
        # only support aug_test for one sample
        aug_bboxes = []
        for x, img_meta in zip(feats, img_metas):
            outs = self.pts_bbox_head(x)
            bbox_list = self.pts_bbox_head.get_bboxes(
                *outs, img_meta, rescale=rescale)
            bbox_list = [
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
                for bboxes, scores, labels in bbox_list
            ]
            aug_bboxes.append(bbox_list[0])

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.pts_bbox_head.test_cfg)
        return merged_bboxes

    def show_results(self, data, result, out_dir):
        """Results visualization.

        Only predictions with ``scores_3d`` above 0.1 are shown.

        Args:
            data (dict): Input points and the information of the sample.
            result (dict): Prediction results.
            out_dir (str): Output directory of visualization result.

        Raises:
            ValueError: If ``data['points'][0]`` or ``data['img_metas'][0]``
                has an unsupported type, or if ``box_mode_3d`` cannot be
                converted to depth mode.
        """
        for batch_id in range(len(result)):
            raw_points = data['points'][0]
            if isinstance(raw_points, DC):
                points = raw_points._data[0][batch_id].numpy()
            elif mmcv.is_list_of(raw_points, torch.Tensor):
                points = raw_points[batch_id]
            else:
                # Previously the exception was built but never raised,
                # which led to a NameError on ``points`` further down.
                raise ValueError(
                    f"Unsupported data type {type(data['points'][0])} "
                    f'for visualization!')

            raw_metas = data['img_metas'][0]
            if isinstance(raw_metas, DC):
                sample_meta = raw_metas._data[0][batch_id]
            elif mmcv.is_list_of(raw_metas, dict):
                sample_meta = raw_metas[batch_id]
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['img_metas'][0])} "
                    f'for visualization!')
            pts_filename = sample_meta['pts_filename']
            box_mode_3d = sample_meta['box_mode_3d']

            # Base name of the point file up to its first dot.
            file_name = osp.split(pts_filename)[-1].split('.')[0]

            assert out_dir is not None, 'Expect out_dir, got none.'
            inds = result[batch_id]['pts_bbox']['scores_3d'] > 0.1
            pred_bboxes = result[batch_id]['pts_bbox']['boxes_3d'][inds]

            # for now we convert points and bbox into depth mode
            if (box_mode_3d == Box3DMode.CAM) or (box_mode_3d
                                                  == Box3DMode.LIDAR):
                points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR,
                                                   Coord3DMode.DEPTH)
                pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
                                                Box3DMode.DEPTH)
            elif box_mode_3d != Box3DMode.DEPTH:
                raise ValueError(
                    f'Unsupported box_mode_3d {box_mode_3d} for conversion!')

            pred_bboxes = pred_bboxes.tensor.cpu().numpy()
            show_result(points, None, pred_bboxes, out_dir, file_name)
Read the Docs v: dev
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.