Shortcuts

Source code for mmdet3d.models.detectors.groupfree3dnet

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from ..builder import DETECTORS
from .single_stage import SingleStage3DDetector


@DETECTORS.register_module()
class GroupFree3DNet(SingleStage3DDetector):
    """`Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.

    Single-stage 3D detector whose ``bbox_head`` is called with the
    ``sample_mod`` value taken from the active train/test config.
    """

    def __init__(self,
                 backbone,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Forward all construction arguments unchanged to the parent class.

        Args:
            backbone: Backbone config, passed to the parent constructor.
            bbox_head: Bbox head config, passed to the parent constructor.
            train_cfg: Training config; ``forward_train`` reads
                ``train_cfg.sample_mod`` from the resulting attribute.
            test_cfg: Testing config; ``simple_test`` and ``aug_test`` read
                ``test_cfg.sample_mod`` from the resulting attribute.
            pretrained: Passed to the parent constructor as-is.
        """
        super(GroupFree3DNet, self).__init__(
            backbone=backbone,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      pts_semantic_mask=None,
                      pts_instance_mask=None,
                      gt_bboxes_ignore=None):
        """Forward of training.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            img_metas (list): Image metas.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
            pts_semantic_mask (list[torch.Tensor]): point-wise semantic
                label of each batch.
            pts_instance_mask (list[torch.Tensor]): point-wise instance
                label of each batch.
            gt_bboxes_ignore (list[torch.Tensor]): Bounding boxes to ignore;
                passed through unchanged to ``bbox_head.loss``.

        Returns:
            dict[str: torch.Tensor]: Losses.
        """
        # TODO: refactor votenet series to reduce redundant codes.
        # torch.stack adds a leading dimension, so every tensor in ``points``
        # must have the same shape.
        points_cat = torch.stack(points)

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
        # The loss receives the original per-sample list ``points``, not the
        # stacked tensor.
        loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
                       pts_instance_mask, img_metas)
        losses = self.bbox_head.loss(
            bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Forward of testing.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            img_metas (list): Image metas.
            imgs: Not used by this method.
            rescale (bool): Whether to rescale results; forwarded to
                ``bbox_head.get_bboxes``.

        Returns:
            list: Predicted 3d boxes, one ``bbox3d2result`` output per
                ``(bboxes, scores, labels)`` triple returned by the head.
        """
        points_cat = torch.stack(points)

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
        bbox_list = self.bbox_head.get_bboxes(
            points_cat, bbox_preds, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test with augmentation.

        Args:
            points (list): One entry per augmentation; each entry is a
                sequence of tensors that is stacked with ``torch.stack``.
            img_metas (list): Image metas, iterated in lockstep with the
                augmentations and also passed to ``merge_aug_bboxes_3d``.
            imgs: Not used by this method.
            rescale (bool): Whether to rescale results; forwarded to
                ``bbox_head.get_bboxes``.

        Returns:
            list: Single-element list holding the merged prediction returned
                by ``merge_aug_bboxes_3d``.
        """
        points_cat = [torch.stack(pts) for pts in points]
        feats = self.extract_feats(points_cat, img_metas)

        # only support aug_test for one sample
        aug_bboxes = []
        for x, pts_cat, img_meta in zip(feats, points_cat, img_metas):
            bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
            bbox_list = self.bbox_head.get_bboxes(
                pts_cat, bbox_preds, img_meta, rescale=rescale)
            bbox_list = [
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
                for bboxes, scores, labels in bbox_list
            ]
            # Only the first sample's prediction is kept per augmentation.
            aug_bboxes.append(bbox_list[0])

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.bbox_head.test_cfg)

        return [merged_bboxes]
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.