Source code for mmdet3d.models.detectors.votenet

import torch

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models import DETECTORS
from .single_stage import SingleStage3DDetector


@DETECTORS.register_module()
class VoteNet(SingleStage3DDetector):
    r"""`VoteNet <https://arxiv.org/pdf/1904.09664.pdf>`_ for 3D detection.

    Every constructor argument is forwarded unchanged to
    :class:`SingleStage3DDetector`; this class adds no state of its own.

    Args:
        backbone: Backbone specification, forwarded to the parent class.
        bbox_head: Detection head specification, forwarded to the parent
            class. Defaults to None.
        train_cfg: Training configuration. ``forward_train`` reads its
            ``sample_mod`` attribute. Defaults to None.
        test_cfg: Testing configuration. ``simple_test`` and ``aug_test``
            read its ``sample_mod`` attribute. Defaults to None.
        init_cfg: Initialization configuration, forwarded to the parent
            class. Defaults to None.
        pretrained: Forwarded to the parent class. Defaults to None.
    """

    def __init__(self,
                 backbone,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 pretrained=None):
        # Forward the caller's ``init_cfg`` rather than a hard-coded None,
        # otherwise a user-supplied initialization config is silently lost.
        super(VoteNet, self).__init__(
            backbone=backbone,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            pretrained=pretrained)

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      pts_semantic_mask=None,
                      pts_instance_mask=None,
                      gt_bboxes_ignore=None):
        """Forward of training.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            img_metas (list): Image metas.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
            pts_semantic_mask (None | list[torch.Tensor]): point-wise semantic
                label of each batch.
            pts_instance_mask (None | list[torch.Tensor]): point-wise instance
                label of each batch.
            gt_bboxes_ignore (None | list[torch.Tensor]): Passed through to
                the head's ``loss`` as ``gt_bboxes_ignore``; presumably
                marks the bounding boxes to ignore.

        Returns:
            dict: Losses.
        """
        # ``torch.stack`` requires every per-sample tensor to share a shape.
        points_cat = torch.stack(points)

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
        # The loss receives the original (un-stacked) list of point tensors.
        loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
                       pts_instance_mask, img_metas)
        losses = self.bbox_head.loss(
            bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Forward of testing.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            img_metas (list): Image metas.
            imgs: Not used by this method. Defaults to None.
            rescale (bool): Whether to rescale results.

        Returns:
            list: Predicted 3d boxes.
        """
        # ``torch.stack`` requires every per-sample tensor to share a shape.
        points_cat = torch.stack(points)

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
        bbox_list = self.bbox_head.get_bboxes(
            points_cat, bbox_preds, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test with augmentation.

        Args:
            points (list[list[torch.Tensor]]): One list of point tensors per
                augmentation; each inner list is stacked into one tensor.
            img_metas (list): Image metas, one entry per augmentation
                (iterated in lockstep with the extracted features).
            imgs: Not used by this method. Defaults to None.
            rescale (bool): Whether to rescale results.

        Returns:
            list: A single-element list holding the merged predictions.
        """
        points_cat = [torch.stack(pts) for pts in points]
        feats = self.extract_feats(points_cat, img_metas)

        # only support aug_test for one sample
        aug_bboxes = []
        for x, pts_cat, img_meta in zip(feats, points_cat, img_metas):
            bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
            bbox_list = self.bbox_head.get_bboxes(
                pts_cat, bbox_preds, img_meta, rescale=rescale)
            bbox_list = [
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
                for bboxes, scores, labels in bbox_list
            ]
            # Keep only the first sample's result for this augmentation.
            aug_bboxes.append(bbox_list[0])

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.bbox_head.test_cfg)

        return [merged_bboxes]