# Source code for mmdet3d.models.detectors.h3dnet
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet3d.core import merge_aug_bboxes_3d
from ..builder import DETECTORS
from .two_stage import TwoStage3DDetector
@DETECTORS.register_module()
class H3DNet(TwoStage3DDetector):
    r"""H3DNet model.

    Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(H3DNet, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)

    @staticmethod
    def _set_fp_feats(feats_dict):
        """Expose selected extracted features under the generic ``fp_*`` keys.

        ``fp_xyz`` and ``fp_indices`` take the last element of
        ``fp_xyz_net0`` / ``fp_indices_net0``, and ``fp_features`` takes
        ``hd_feature``; each is wrapped in a one-element list. The dict is
        modified in place and also returned.

        Args:
            feats_dict (dict): Output of ``extract_feat``.

        Returns:
            dict: The same ``feats_dict`` object.
        """
        feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
        feats_dict['fp_features'] = [feats_dict['hd_feature']]
        feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
        return feats_dict

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      pts_semantic_mask=None,
                      pts_instance_mask=None,
                      gt_bboxes_ignore=None):
        """Forward of training.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            img_metas (list): Image metas.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
            pts_semantic_mask (list[torch.Tensor]): point-wise semantic
                label of each batch.
            pts_instance_mask (list[torch.Tensor]): point-wise instance
                label of each batch.
            gt_bboxes_ignore (list[torch.Tensor]): Specify which bounding
                boxes to ignore.

        Returns:
            dict: Losses from the RPN head and the RoI head.

        Raises:
            NotImplementedError: If the detector has no RPN head.
        """
        points_cat = torch.stack(points)
        feats_dict = self._set_fp_feats(self.extract_feat(points_cat))

        losses = dict()
        if not self.with_rpn:
            raise NotImplementedError

        rpn_outs = self.rpn_head(feats_dict, self.train_cfg.rpn.sample_mod)
        feats_dict.update(rpn_outs)
        rpn_losses = self.rpn_head.loss(
            rpn_outs,
            points,
            gt_bboxes_3d,
            gt_labels_3d,
            pts_semantic_mask,
            pts_instance_mask,
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore,
            ret_target=True)
        # The targets are requested via ``ret_target=True`` and handed to the
        # RoI head through ``feats_dict``; they are not a loss term.
        feats_dict['targets'] = rpn_losses.pop('targets')
        losses.update(rpn_losses)

        # Generate rpn proposals; fall back to the test-time rpn config when
        # no dedicated ``rpn_proposal`` train config is given.
        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        feats_dict['proposal_list'] = self.rpn_head.get_bboxes(
            points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)

        roi_losses = self.roi_head.forward_train(feats_dict, img_metas, points,
                                                 gt_bboxes_3d, gt_labels_3d,
                                                 pts_semantic_mask,
                                                 pts_instance_mask,
                                                 gt_bboxes_ignore)
        losses.update(roi_losses)
        return losses

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Forward of testing.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            img_metas (list): Image metas.
            imgs: Unused; kept for interface compatibility.
            rescale (bool): Whether to rescale results.

        Returns:
            list: Predicted 3d boxes.

        Raises:
            NotImplementedError: If the detector has no RPN head.
        """
        points_cat = torch.stack(points)
        feats_dict = self._set_fp_feats(self.extract_feat(points_cat))

        if not self.with_rpn:
            raise NotImplementedError

        proposal_cfg = self.test_cfg.rpn
        rpn_outs = self.rpn_head(feats_dict, proposal_cfg.sample_mod)
        feats_dict.update(rpn_outs)
        # Generate rpn proposals
        feats_dict['proposal_list'] = self.rpn_head.get_bboxes(
            points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)

        return self.roi_head.simple_test(
            feats_dict, img_metas, points_cat, rescale=rescale)

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test with augmentation.

        Each augmented view goes through the same pipeline as
        :meth:`simple_test`. The per-view detections are then merged with
        ``merge_aug_bboxes_3d``. Only one sample per view is supported.

        Args:
            points (list[list[torch.Tensor]]): Points, one inner list per
                augmented view.
            img_metas (list[list]): Image metas, one inner list per
                augmented view.
            imgs: Unused; kept for interface compatibility.
            rescale (bool): Whether to rescale results.

        Returns:
            list: A one-element list holding the merged detections.

        Raises:
            NotImplementedError: If the detector has no RPN head.
        """
        points_cat = [torch.stack(pts) for pts in points]
        feats_dict = self.extract_feats(points_cat, img_metas)
        for feat_dict in feats_dict:
            self._set_fp_feats(feat_dict)

        # only support aug_test for one sample
        aug_bboxes = []
        for feat_dict, pts, pts_cat, img_meta in zip(feats_dict, points,
                                                     points_cat, img_metas):
            if not self.with_rpn:
                raise NotImplementedError
            proposal_cfg = self.test_cfg.rpn
            rpn_outs = self.rpn_head(feat_dict, proposal_cfg.sample_mod)
            feat_dict.update(rpn_outs)
            # Generate rpn proposals from this view's own points and metas,
            # mirroring the per-sample call in ``simple_test``.
            feat_dict['proposal_list'] = self.rpn_head.get_bboxes(
                pts, rpn_outs, img_meta, use_nms=proposal_cfg.use_nms)

            # Same argument order as the RoI call in ``simple_test``:
            # (feats_dict, img_metas, points, rescale).
            bbox_results = self.roi_head.simple_test(
                feat_dict, img_meta, pts_cat, rescale=rescale)
            # The RoI head returns one result per sample; only a single
            # sample is supported here, so keep its sole entry.
            aug_bboxes.append(bbox_results[0])

        # after merging, bboxes will be rescaled to the original image size.
        # This detector is built from rpn_head/roi_head only, so the merge uses
        # the RoI-stage test config.
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.test_cfg.rcnn)
        return [merged_bboxes]

    def extract_feats(self, points, img_metas):
        """Extract features of multiple samples.

        Args:
            points (list[torch.Tensor]): Stacked points, one entry per
                augmented view.
            img_metas (list): Image metas, one entry per augmented view. They
                only pair up with ``points`` here, since ``extract_feat`` is
                called with points alone (as in ``forward_train`` and
                ``simple_test``).

        Returns:
            list[dict]: One feature dict per entry of ``points``.
        """
        return [self.extract_feat(pts) for pts, _ in zip(points, img_metas)]