Source code for mmdet3d.models.detectors.mvx_two_stage

import mmcv
import torch
import warnings
from mmcv.parallel import DataContainer as DC
from mmcv.runner import force_fp32
from os import path as osp
from torch.nn import functional as F

from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
                          merge_aug_bboxes_3d, show_result)
from mmdet3d.ops import Voxelization
from mmdet.core import multi_apply
from mmdet.models import DETECTORS
from .. import builder
from .base import Base3DDetector


@DETECTORS.register_module()
class MVXTwoStageDetector(Base3DDetector):
    """Base class of Multi-modality VoxelNet."""

    def __init__(self,
                 pts_voxel_layer=None,
                 pts_voxel_encoder=None,
                 pts_middle_encoder=None,
                 pts_fusion_layer=None,
                 img_backbone=None,
                 pts_backbone=None,
                 img_neck=None,
                 pts_neck=None,
                 pts_bbox_head=None,
                 img_roi_head=None,
                 img_rpn_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(MVXTwoStageDetector, self).__init__(init_cfg=init_cfg)

        if pts_voxel_layer:
            self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
        if pts_voxel_encoder:
            self.pts_voxel_encoder = builder.build_voxel_encoder(
                pts_voxel_encoder)
        if pts_middle_encoder:
            self.pts_middle_encoder = builder.build_middle_encoder(
                pts_middle_encoder)
        if pts_backbone:
            self.pts_backbone = builder.build_backbone(pts_backbone)
        if pts_fusion_layer:
            self.pts_fusion_layer = builder.build_fusion_layer(
                pts_fusion_layer)
        if pts_neck is not None:
            self.pts_neck = builder.build_neck(pts_neck)
        if pts_bbox_head:
            pts_train_cfg = train_cfg.pts if train_cfg else None
            pts_bbox_head.update(train_cfg=pts_train_cfg)
            pts_test_cfg = test_cfg.pts if test_cfg else None
            pts_bbox_head.update(test_cfg=pts_test_cfg)
            self.pts_bbox_head = builder.build_head(pts_bbox_head)

        if img_backbone:
            self.img_backbone = builder.build_backbone(img_backbone)
        if img_neck is not None:
            self.img_neck = builder.build_neck(img_neck)
        if img_rpn_head is not None:
            self.img_rpn_head = builder.build_head(img_rpn_head)
        if img_roi_head is not None:
            self.img_roi_head = builder.build_head(img_roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if pretrained is None:
            img_pretrained = None
            pts_pretrained = None
        elif isinstance(pretrained, dict):
            img_pretrained = pretrained.get('img', None)
            pts_pretrained = pretrained.get('pts', None)
        else:
            raise ValueError(
                f'pretrained should be a dict, got {type(pretrained)}')

        if self.with_img_backbone:
            if img_pretrained is not None:
                warnings.warn('DeprecationWarning: pretrained is a deprecated '
                              'key, please consider using init_cfg.')
                self.img_backbone.init_cfg = dict(
                    type='Pretrained', checkpoint=img_pretrained)
        if self.with_img_roi_head:
            if img_pretrained is not None:
                warnings.warn('DeprecationWarning: pretrained is a deprecated '
                              'key, please consider using init_cfg.')
                self.img_roi_head.init_cfg = dict(
                    type='Pretrained', checkpoint=img_pretrained)
        if self.with_pts_backbone:
            if pts_pretrained is not None:
                warnings.warn('DeprecationWarning: pretrained is a deprecated '
                              'key, please consider using init_cfg.')
                self.pts_backbone.init_cfg = dict(
                    type='Pretrained', checkpoint=pts_pretrained)

    @property
    def with_img_shared_head(self):
        """bool: Whether the detector has a shared head in image branch."""
        return hasattr(self,
                       'img_shared_head') and self.img_shared_head is not None

    @property
    def with_pts_bbox(self):
        """bool: Whether the detector has a 3D box head."""
        return hasattr(self,
                       'pts_bbox_head') and self.pts_bbox_head is not None

    @property
    def with_img_bbox(self):
        """bool: Whether the detector has a 2D image box head."""
        return hasattr(self,
                       'img_bbox_head') and self.img_bbox_head is not None

    @property
    def with_img_backbone(self):
        """bool: Whether the detector has a 2D image backbone."""
        return hasattr(self, 'img_backbone') and self.img_backbone is not None

    @property
    def with_pts_backbone(self):
        """bool: Whether the detector has a 3D backbone."""
        return hasattr(self, 'pts_backbone') and self.pts_backbone is not None

    @property
    def with_fusion(self):
        """bool: Whether the detector has a fusion layer."""
        return hasattr(
            self, 'pts_fusion_layer') and self.pts_fusion_layer is not None

    @property
    def with_img_neck(self):
        """bool: Whether the detector has a neck in image branch."""
        return hasattr(self, 'img_neck') and self.img_neck is not None

    @property
    def with_pts_neck(self):
        """bool: Whether the detector has a neck in 3D detector branch."""
        return hasattr(self, 'pts_neck') and self.pts_neck is not None

    @property
    def with_img_rpn(self):
        """bool: Whether the detector has a 2D RPN in image branch."""
        return hasattr(self, 'img_rpn_head') and self.img_rpn_head is not None

    @property
    def with_img_roi_head(self):
        """bool: Whether the detector has a RoI Head in image branch."""
        return hasattr(self, 'img_roi_head') and self.img_roi_head is not None

    @property
    def with_voxel_encoder(self):
        """bool: Whether the detector has a voxel encoder."""
        return hasattr(self,
                       'voxel_encoder') and self.voxel_encoder is not None

    @property
    def with_middle_encoder(self):
        """bool: Whether the detector has a middle encoder."""
        return hasattr(self,
                       'middle_encoder') and self.middle_encoder is not None
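    # Illustrative sketch (added in editing, not part of the original
    # source): every component above is optional and is built from a config
    # dict through the mmdet/mmdet3d registries, so a LiDAR-only variant
    # simply leaves the image branch unset. The values below are hypothetical
    # placeholders, not a tested config:
    #
    #   model_cfg = dict(
    #       type='MVXTwoStageDetector',
    #       pts_voxel_layer=dict(
    #           max_num_points=32,
    #           point_cloud_range=[0, -40, -3, 70.4, 40, 1],
    #           voxel_size=[0.05, 0.05, 0.1],
    #           max_voxels=(16000, 40000)),
    #       pts_voxel_encoder=dict(type='HardSimpleVFE'))
    #   # img_backbone, img_neck, img_rpn_head, img_roi_head stay None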
    def extract_img_feat(self, img, img_metas):
        """Extract features of images."""
        if self.with_img_backbone and img is not None:
            input_shape = img.shape[-2:]
            # update real input shape of each single img
            for img_meta in img_metas:
                img_meta.update(input_shape=input_shape)

            if img.dim() == 5 and img.size(0) == 1:
                img.squeeze_()
            elif img.dim() == 5 and img.size(0) > 1:
                B, N, C, H, W = img.size()
                img = img.view(B * N, C, H, W)
            img_feats = self.img_backbone(img)
        else:
            return None
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)
        return img_feats
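    # Illustrative sketch (added in editing): multi-view inputs of shape
    # (B, N, C, H, W) are flattened so a standard 2D backbone sees a single
    # batch of B * N images. Pure-PyTorch example with made-up shapes:
    #
    #   >>> import torch
    #   >>> img = torch.rand(2, 6, 3, 224, 224)  # 2 samples, 6 cameras each
    #   >>> B, N, C, H, W = img.size()
    #   >>> img.view(B * N, C, H, W).shape
    #   torch.Size([12, 3, 224, 224])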
    def extract_pts_feat(self, pts, img_feats, img_metas):
        """Extract features of points."""
        if not self.with_pts_bbox:
            return None
        voxels, num_points, coors = self.voxelize(pts)
        voxel_features = self.pts_voxel_encoder(voxels, num_points, coors,
                                                img_feats, img_metas)
        batch_size = coors[-1, 0] + 1
        x = self.pts_middle_encoder(voxel_features, coors, batch_size)
        x = self.pts_backbone(x)
        if self.with_pts_neck:
            x = self.pts_neck(x)
        return x
    def extract_feat(self, points, img, img_metas):
        """Extract features from images and points."""
        img_feats = self.extract_img_feat(img, img_metas)
        pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
        return (img_feats, pts_feats)
    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points):
        """Apply hard voxelization to points.

        Args:
            points (list[torch.Tensor]): Points of each sample.

        Returns:
            tuple[torch.Tensor]: Concatenated points, number of points
                per voxel, and coordinates.
        """
        voxels, coors, num_points = [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.pts_voxel_layer(res)
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        coors_batch = []
        for i, coor in enumerate(coors):
            # prepend the sample index so voxels from different samples
            # can be concatenated into one batched coordinate tensor
            coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
            coors_batch.append(coor_pad)
        coors_batch = torch.cat(coors_batch, dim=0)
        return voxels, num_points, coors_batch
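    # Illustrative sketch (added in editing): `F.pad(coor, (1, 0), value=i)`
    # prepends the sample index as an extra leading column, turning per-sample
    # (z, y, x) voxel coordinates into batched (batch_idx, z, y, x) rows.
    # Pure-PyTorch example with made-up coordinates:
    #
    #   >>> import torch
    #   >>> from torch.nn import functional as F
    #   >>> coor = torch.tensor([[0, 4, 7], [1, 2, 3]])  # (z, y, x) rows
    #   >>> F.pad(coor, (1, 0), mode='constant', value=1)
    #   tensor([[1, 0, 4, 7],
    #           [1, 1, 2, 3]])
    #
    # Because coordinates are concatenated in sample order, the expression
    # `coors[-1, 0] + 1` in `extract_pts_feat` recovers the batch size.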
    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img=None,
                      proposals=None,
                      gt_bboxes_ignore=None):
        """Forward training function.

        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each
                sample. Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes
                in images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.

        Returns:
            dict: Losses of different branches.
        """
        img_feats, pts_feats = self.extract_feat(
            points, img=img, img_metas=img_metas)
        losses = dict()
        if pts_feats:
            losses_pts = self.forward_pts_train(pts_feats, gt_bboxes_3d,
                                                gt_labels_3d, img_metas,
                                                gt_bboxes_ignore)
            losses.update(losses_pts)
        if img_feats:
            losses_img = self.forward_img_train(
                img_feats,
                img_metas=img_metas,
                gt_bboxes=gt_bboxes,
                gt_labels=gt_labels,
                gt_bboxes_ignore=gt_bboxes_ignore,
                proposals=proposals)
            losses.update(losses_img)
        return losses
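    # Illustrative sketch (added in editing): the training loop calls this
    # method with batched inputs roughly as follows; all variable names are
    # hypothetical placeholders.
    #
    #   >>> losses = model.forward_train(
    #   ...     points=points,              # list[Tensor], (P_i, C) per sample
    #   ...     img_metas=img_metas,        # list[dict]
    #   ...     gt_bboxes_3d=gt_bboxes_3d,  # list[BaseInstance3DBoxes]
    #   ...     gt_labels_3d=gt_labels_3d)  # list[Tensor]
    #   >>> sorted(losses.keys())  # named loss terms reduced by the runner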
    def forward_pts_train(self,
                          pts_feats,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          img_metas,
                          gt_bboxes_ignore=None):
        """Forward function for point cloud branch.

        Args:
            pts_feats (list[torch.Tensor]): Features of point cloud branch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
            img_metas (list[dict]): Meta information of samples.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        outs = self.pts_bbox_head(pts_feats)
        loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
        losses = self.pts_bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses
    def forward_img_train(self,
                          x,
                          img_metas,
                          gt_bboxes,
                          gt_labels,
                          gt_bboxes_ignore=None,
                          proposals=None,
                          **kwargs):
        """Forward function for image branch.

        This function works similarly to the forward function of
        Faster R-CNN.

        Args:
            x (list[torch.Tensor]): Image features of shape (B, C, H, W)
                of multiple levels.
            img_metas (list[dict]): Meta information of images.
            gt_bboxes (list[torch.Tensor]): Ground truth boxes of each
                image sample.
            gt_labels (list[torch.Tensor]): Ground truth labels of boxes.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.
            proposals (list[torch.Tensor], optional): Proposals of each
                sample. Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        losses = dict()
        # RPN forward and loss
        if self.with_img_rpn:
            rpn_outs = self.img_rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas,
                                          self.train_cfg.img_rpn)
            rpn_losses = self.img_rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)

            proposal_cfg = self.train_cfg.get('img_rpn_proposal',
                                              self.test_cfg.img_rpn)
            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
            proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals

        # bbox head forward and loss
        if self.with_img_bbox:
            img_roi_losses = self.img_roi_head.forward_train(
                x, img_metas, proposal_list, gt_bboxes, gt_labels,
                gt_bboxes_ignore, **kwargs)
            losses.update(img_roi_losses)
        return losses
    def simple_test_img(self, x, img_metas, proposals=None, rescale=False):
        """Test without augmentation."""
        if proposals is None:
            proposal_list = self.simple_test_rpn(x, img_metas,
                                                 self.test_cfg.img_rpn)
        else:
            proposal_list = proposals

        return self.img_roi_head.simple_test(
            x, proposal_list, img_metas, rescale=rescale)
    def simple_test_rpn(self, x, img_metas, rpn_test_cfg):
        """RPN test function."""
        rpn_outs = self.img_rpn_head(x)
        proposal_inputs = rpn_outs + (img_metas, rpn_test_cfg)
        proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
        return proposal_list
    def simple_test_pts(self, x, img_metas, rescale=False):
        """Test function of point cloud branch."""
        outs = self.pts_bbox_head(x)
        bbox_list = self.pts_bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results
    def simple_test(self, points, img_metas, img=None, rescale=False):
        """Test function without augmentation."""
        img_feats, pts_feats = self.extract_feat(
            points, img=img, img_metas=img_metas)

        bbox_list = [dict() for i in range(len(img_metas))]
        if pts_feats and self.with_pts_bbox:
            bbox_pts = self.simple_test_pts(
                pts_feats, img_metas, rescale=rescale)
            for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
                result_dict['pts_bbox'] = pts_bbox
        if img_feats and self.with_img_bbox:
            bbox_img = self.simple_test_img(
                img_feats, img_metas, rescale=rescale)
            for result_dict, img_bbox in zip(bbox_list, bbox_img):
                result_dict['img_bbox'] = img_bbox
        return bbox_list
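    # Illustrative sketch (added in editing): each entry of the returned
    # list corresponds to one sample and holds per-branch results, e.g.
    #
    #   [
    #       {
    #           'pts_bbox': dict(boxes_3d=..., scores_3d=..., labels_3d=...),
    #           'img_bbox': ...,  # only present when the image branch runs
    #       },
    #       ...
    #   ]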
    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function with augmentation."""
        img_feats, pts_feats = self.extract_feats(points, img_metas, imgs)

        bbox_list = dict()
        if pts_feats and self.with_pts_bbox:
            bbox_pts = self.aug_test_pts(pts_feats, img_metas, rescale)
            bbox_list.update(pts_bbox=bbox_pts)
        return [bbox_list]
    def extract_feats(self, points, img_metas, imgs=None):
        """Extract point and image features of multiple samples."""
        if imgs is None:
            imgs = [None] * len(img_metas)
        img_feats, pts_feats = multi_apply(self.extract_feat, points, imgs,
                                           img_metas)
        return img_feats, pts_feats
    def aug_test_pts(self, feats, img_metas, rescale=False):
        """Test function of point cloud branch with augmentation."""
        # only support aug_test for one sample
        aug_bboxes = []
        for x, img_meta in zip(feats, img_metas):
            outs = self.pts_bbox_head(x)
            bbox_list = self.pts_bbox_head.get_bboxes(
                *outs, img_meta, rescale=rescale)
            bbox_list = [
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
                for bboxes, scores, labels in bbox_list
            ]
            aug_bboxes.append(bbox_list[0])

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.pts_bbox_head.test_cfg)
        return merged_bboxes
    def show_results(self, data, result, out_dir):
        """Results visualization.

        Args:
            data (dict): Input points and the information of the sample.
            result (dict): Prediction results.
            out_dir (str): Output directory of visualization result.
        """
        for batch_id in range(len(result)):
            if isinstance(data['points'][0], DC):
                points = data['points'][0]._data[0][batch_id].numpy()
            elif mmcv.is_list_of(data['points'][0], torch.Tensor):
                points = data['points'][0][batch_id]
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['points'][0])} "
                    f'for visualization!')
            if isinstance(data['img_metas'][0], DC):
                pts_filename = data['img_metas'][0]._data[0][batch_id][
                    'pts_filename']
                box_mode_3d = data['img_metas'][0]._data[0][batch_id][
                    'box_mode_3d']
            elif mmcv.is_list_of(data['img_metas'][0], dict):
                pts_filename = data['img_metas'][0][batch_id]['pts_filename']
                box_mode_3d = data['img_metas'][0][batch_id]['box_mode_3d']
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['img_metas'][0])} "
                    f'for visualization!')
            file_name = osp.split(pts_filename)[-1].split('.')[0]

            assert out_dir is not None, 'Expect out_dir, got none.'
            inds = result[batch_id]['pts_bbox']['scores_3d'] > 0.1
            pred_bboxes = result[batch_id]['pts_bbox']['boxes_3d'][inds]

            # for now we convert points and bbox into depth mode
            if (box_mode_3d == Box3DMode.CAM) or (box_mode_3d
                                                  == Box3DMode.LIDAR):
                points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR,
                                                   Coord3DMode.DEPTH)
                pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
                                                Box3DMode.DEPTH)
            elif box_mode_3d != Box3DMode.DEPTH:
                raise ValueError(
                    f'Unsupported box_mode_3d {box_mode_3d} for conversion!')

            pred_bboxes = pred_bboxes.tensor.cpu().numpy()
            show_result(points, None, pred_bboxes, out_dir, file_name)
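
# Illustrative sketch (added in editing): the confidence filter used in
# `show_results` above is plain boolean tensor indexing. Hypothetical scores:
#
#   >>> import torch
#   >>> scores_3d = torch.tensor([0.05, 0.3, 0.9])
#   >>> scores_3d > 0.1
#   tensor([False,  True,  True])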