Shortcuts

Source code for mmdet3d.models.roi_heads.h3d_roi_head

# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core.bbox import bbox3d2result
from ..builder import HEADS, build_head
from .base_3droi_head import Base3DRoIHead


@HEADS.register_module()
class H3DRoIHead(Base3DRoIHead):
    """H3D roi head for H3DNet.

    Args:
        primitive_list (List): Configs of primitive heads. Exactly three
            configs are required; they are built, in order, into
            ``primitive_z``, ``primitive_xy`` and ``primitive_line``.
        bbox_head (ConfigDict, optional): Config of bbox_head.
        train_cfg (ConfigDict, optional): Training config. Its
            ``sample_mod`` field is read in :meth:`forward_train`.
        test_cfg (ConfigDict, optional): Testing config. Its
            ``sample_mod`` field is read in :meth:`simple_test`.
        pretrained (str, optional): Forwarded unchanged to the base class.
        init_cfg (dict, optional): Forwarded unchanged to the base class.
    """

    # Sampling modes accepted for ``train_cfg.sample_mod`` /
    # ``test_cfg.sample_mod``.
    _VALID_SAMPLE_MODS = ('vote', 'seed', 'random')

    def __init__(self,
                 primitive_list,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        # Primitive module: one head per primitive type (z, xy, line).
        assert len(primitive_list) == 3, (
            'H3DRoIHead expects exactly 3 primitive head configs, '
            f'got {len(primitive_list)}')
        self.primitive_z = build_head(primitive_list[0])
        self.primitive_xy = build_head(primitive_list[1])
        self.primitive_line = build_head(primitive_list[2])

    def init_mask_head(self):
        """Initialize mask head, skip since ``H3DRoIHead`` does not have one."""
        pass

    def init_bbox_head(self, bbox_head):
        """Initialize box head.

        The ``train_cfg`` / ``test_cfg`` of this RoI head are written into
        the given ``bbox_head`` config (mutating it in place) before the
        head is built.

        Args:
            bbox_head (ConfigDict): Config of the box head.
        """
        bbox_head['train_cfg'] = self.train_cfg
        bbox_head['test_cfg'] = self.test_cfg
        self.bbox_head = build_head(bbox_head)

    def init_assigner_sampler(self):
        """Initialize assigner and sampler.

        This head builds neither an assigner nor a sampler, so this is a
        no-op.
        """
        pass

    def _check_sample_mod(self, sample_mod):
        """Assert that ``sample_mod`` is one of the supported modes."""
        assert sample_mod in self._VALID_SAMPLE_MODS, (
            f'sample_mod must be one of {self._VALID_SAMPLE_MODS}, '
            f'got {sample_mod!r}')

    def _forward_primitives(self, feats_dict, sample_mod):
        """Run the z, xy and line primitive heads in sequence.

        The output of each head is merged into ``feats_dict`` in place
        before the next head runs, so later heads see earlier results.

        Args:
            feats_dict (dict): Features from the first stage; updated in
                place.
            sample_mod (str): Sampling mode passed to every primitive head.
        """
        for primitive_head in (self.primitive_z, self.primitive_xy,
                               self.primitive_line):
            feats_dict.update(primitive_head(feats_dict, sample_mod))

    def forward_train(self,
                      feats_dict,
                      img_metas,
                      points,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      pts_semantic_mask,
                      pts_instance_mask,
                      gt_bboxes_ignore=None):
        """Training forward function of H3DRoIHead.

        Args:
            feats_dict (dict): Contains features from the first stage. It is
                updated in place with the outputs of the primitive heads and
                the bbox head, and its ``'targets'`` entry is popped.
            img_metas (list[dict]): Contain pcd and img's meta info.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
                mask.
            pts_instance_mask (list[torch.Tensor]): Point-wise instance
                mask.
            gt_bboxes_ignore (list[torch.Tensor], optional): Specify which
                bounding boxes to ignore.

        Returns:
            dict: Losses from each head (primitive heads and bbox head),
                merged into one dict.
        """
        losses = dict()

        sample_mod = self.train_cfg.sample_mod
        self._check_sample_mod(sample_mod)
        self._forward_primitives(feats_dict, sample_mod)

        primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d,
                                 gt_labels_3d, pts_semantic_mask,
                                 pts_instance_mask, img_metas,
                                 gt_bboxes_ignore)
        # Same order as the forward pass: z, xy, line.
        for primitive_head in (self.primitive_z, self.primitive_xy,
                               self.primitive_line):
            losses.update(primitive_head.loss(*primitive_loss_inputs))

        # 'targets' is removed from feats_dict before the bbox head forward
        # and handed to the bbox loss explicitly. NOTE(review): presumably it
        # is written by the primitive heads' ``loss`` calls above; that is
        # not visible here.
        targets = feats_dict.pop('targets')

        bbox_results = self.bbox_head(feats_dict, sample_mod)
        feats_dict.update(bbox_results)
        bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d,
                                        gt_labels_3d, pts_semantic_mask,
                                        pts_instance_mask, img_metas, targets,
                                        gt_bboxes_ignore)
        losses.update(bbox_loss)

        return losses

    def simple_test(self, feats_dict, img_metas, points, rescale=False):
        """Simple testing forward function of H3DRoIHead.

        Note:
            This function assumes that the batch size is 1.

        Args:
            feats_dict (dict): Contains features from the first stage. It is
                updated in place with the outputs of the primitive heads and
                the bbox head.
            img_metas (list[dict]): Contain pcd and img's meta info.
            points (torch.Tensor): Input points.
            rescale (bool): Whether to rescale results.

        Returns:
            list[dict]: Bbox results, one ``bbox3d2result`` dict per entry
                returned by ``bbox_head.get_bboxes``.
        """
        sample_mod = self.test_cfg.sample_mod
        self._check_sample_mod(sample_mod)
        self._forward_primitives(feats_dict, sample_mod)

        bbox_preds = self.bbox_head(feats_dict, sample_mod)
        feats_dict.update(bbox_preds)
        # The '_optimized' suffix selects which prediction keys get_bboxes
        # decodes from feats_dict.
        bbox_list = self.bbox_head.get_bboxes(
            points,
            feats_dict,
            img_metas,
            rescale=rescale,
            suffix='_optimized')
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.