Shortcuts

Source code for mmdet3d.models.detectors.imvoxelnet

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet3d.core import bbox3d2result, build_prior_generator
from mmdet3d.models.fusion_layers.point_fusion import point_sample
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from mmdet.models.detectors import BaseDetector


@DETECTORS.register_module()
class ImVoxelNet(BaseDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    Pipeline implemented by this class: a 2D backbone and 2D neck extract
    image features; the first neck output is sampled at the centers of a
    fixed 3D voxel grid (produced by ``anchor_generator``) using each
    image's ``lidar2img`` projection; the resulting per-image volumes are
    stacked and passed through a 3D neck and a 3D bbox head.
    """

    def __init__(self,
                 backbone,
                 neck,
                 neck_3d,
                 bbox_head,
                 n_voxels,
                 anchor_generator,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Build all sub-modules from their configs.

        Args:
            backbone (dict): Config of the 2D image backbone.
            neck (dict): Config of the 2D neck; only its first output level
                is used in :meth:`extract_feat`.
            neck_3d (dict): Config of the 3D neck applied to the volume.
            bbox_head (dict): Config of the 3D bbox head. It is updated in
                place with ``train_cfg`` and ``test_cfg`` before building.
            n_voxels (list[int] | tuple[int]): Number of voxels along each
                axis of the grid. Presumably ordered (x, y, z) — confirm
                against the configs.
            anchor_generator (dict): Config of the prior generator whose
                grid anchors supply the voxel center coordinates.
            train_cfg (dict, optional): Training config. Defaults to None.
            test_cfg (dict, optional): Testing config. Defaults to None.
            pretrained (str, optional): Accepted for interface
                compatibility. NOTE(review): it is not used anywhere in this
                constructor; presumably ``init_cfg`` should be used instead.
            init_cfg (dict, optional): Initialization config forwarded to
                :class:`BaseDetector`. Defaults to None.
        """
        super().__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        self.neck_3d = build_neck(neck_3d)
        # The head needs the train/test settings at construction time.
        bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)
        self.n_voxels = n_voxels
        self.anchor_generator = build_prior_generator(anchor_generator)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img, img_metas):
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list[dict]): Image metas, one per image. Must contain
                ``'lidar2img'`` and ``'img_shape'``; ``'scale_factor'``,
                ``'flip'`` and ``'img_crop_offset'`` are optional.

        Returns:
            torch.Tensor: Output of ``neck_3d`` applied to the stacked
                volumes. Before ``neck_3d`` the volumes have shape
                (N, C, n_voxels[0], n_voxels[1], n_voxels[2]), i.e.
                (N, C_out, N_x, N_y, N_z) if ``n_voxels`` is ordered
                (x, y, z).
        """
        x = self.backbone(img)
        # Only the first (highest-resolution) neck level is projected.
        x = self.neck(x)[0]
        # Voxel counts are passed to the generator in reversed order; keep
        # only the xyz center of each anchor.
        points = self.anchor_generator.grid_anchors(
            [self.n_voxels[::-1]], device=img.device)[0][:, :3]
        # Reshape target matching the generator's reversed order, followed
        # by the channel axis. Built via ``list`` so that ``n_voxels`` may
        # be a list or a tuple.
        volume_shape = list(self.n_voxels[::-1]) + [-1]
        volumes = []
        for feature, img_meta in zip(x, img_metas):
            img_scale_factor = (
                points.new_tensor(img_meta['scale_factor'][:2])
                if 'scale_factor' in img_meta else 1)
            img_flip = img_meta.get('flip', False)
            img_crop_offset = (
                points.new_tensor(img_meta['img_crop_offset'])
                if 'img_crop_offset' in img_meta else 0)
            volume = point_sample(
                img_meta,
                img_features=feature[None, ...],
                points=points,
                proj_mat=points.new_tensor(img_meta['lidar2img']),
                coord_type='LIDAR',
                img_scale_factor=img_scale_factor,
                img_crop_offset=img_crop_offset,
                img_flip=img_flip,
                img_pad_shape=img.shape[-2:],
                img_shape=img_meta['img_shape'][:2],
                aligned=False)
            # (n2, n1, n0, C) -> (C, n0, n1, n2): channel first, voxel axes
            # back in the original ``n_voxels`` order.
            volumes.append(
                volume.reshape(volume_shape).permute(3, 2, 1, 0))
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x

    def forward_train(self, img, img_metas, gt_bboxes_3d, gt_labels_3d,
                      **kwargs):
        """Forward of training.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): gt bboxes of
                each sample in the batch (assumed one entry per image —
                confirm against the head's ``loss``).
            gt_labels_3d (list[torch.Tensor]): gt class labels of each
                sample in the batch.
            **kwargs: Ignored.

        Returns:
            dict[str, torch.Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img, img_metas)
        x = self.bbox_head(x)
        losses = self.bbox_head.loss(*x, gt_bboxes_3d, gt_labels_3d,
                                     img_metas)
        return losses

    def forward_test(self, img, img_metas, **kwargs):
        """Forward of testing.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.
            **kwargs: Ignored.

        Returns:
            list[dict]: Predicted 3d boxes.
        """
        # not supporting aug_test for now
        return self.simple_test(img, img_metas)

    def simple_test(self, img, img_metas):
        """Test without augmentations.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.

        Returns:
            list[dict]: Predicted 3d boxes, one result per
                ``(bboxes, scores, labels)`` triple from the head.
        """
        x = self.extract_feat(img, img_metas)
        x = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(*x, img_metas)
        bbox_results = [
            bbox3d2result(det_bboxes, det_scores, det_labels)
            for det_bboxes, det_scores, det_labels in bbox_list
        ]
        return bbox_results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentations (not implemented).

        Args:
            imgs (list[torch.Tensor]): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
Read the Docs v: v0.17.3
Versions
latest
stable
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.