Source code for mmdet3d.models.detectors.imvoxelnet

import torch

from mmdet3d.core import bbox3d2result, build_anchor_generator
from mmdet3d.models.fusion_layers.point_fusion import point_sample
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from mmdet.models.detectors import BaseDetector


@DETECTORS.register_module()
class ImVoxelNet(BaseDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    Image features are extracted by a 2D backbone and neck, sampled onto a
    fixed grid of 3D points, refined by a 3D neck and passed to a 3D bbox
    head.

    Args:
        backbone (dict): Config passed to ``build_backbone``.
        neck (dict): Config passed to ``build_neck`` for the 2D neck.
        neck_3d (dict): Config passed to ``build_neck`` for the 3D neck.
        bbox_head (dict): Config passed to ``build_head``. ``train_cfg`` and
            ``test_cfg`` are injected into a copy of it before building.
        n_voxels (list[int] | tuple[int]): Number of voxels along each axis
            of the volume; the docstrings here assume (N_x, N_y, N_z) order.
        anchor_generator (dict): Config passed to ``build_anchor_generator``;
            the generator provides the 3D points at which image features
            are sampled.
        train_cfg (dict, optional): Training config forwarded to the head.
            Defaults to None.
        test_cfg (dict, optional): Testing config forwarded to the head.
            Defaults to None.
        pretrained (str, optional): Not used by this constructor; accepted
            for signature compatibility. Defaults to None.
        init_cfg (dict, optional): Forwarded to ``BaseDetector.__init__``.
            Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck,
                 neck_3d,
                 bbox_head,
                 n_voxels,
                 anchor_generator,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        self.neck_3d = build_neck(neck_3d)
        # Update a shallow copy so the caller's config dict is not mutated.
        bbox_head = bbox_head.copy()
        bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)
        self.n_voxels = n_voxels
        self.anchor_generator = build_anchor_generator(anchor_generator)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img, img_metas):
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.

        Returns:
            torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
        """
        x = self.backbone(img)
        # Only the first output of the 2D neck is used.
        x = self.neck(x)[0]

        # Reversed voxel counts, normalised to a list so that both list and
        # tuple ``n_voxels`` work (``tuple + list`` would raise TypeError).
        voxel_shape = list(self.n_voxels[::-1])
        # Keep only the first three columns of the generated anchors; they
        # are used as the 3D sampling points.
        points = self.anchor_generator.grid_anchors(
            [voxel_shape], device=img.device)[0][:, :3]

        volumes = []
        for feature, img_meta in zip(x, img_metas):
            # Missing augmentation entries fall back to neutral values.
            if 'scale_factor' in img_meta:
                img_scale_factor = points.new_tensor(
                    img_meta['scale_factor'][:2])
            else:
                img_scale_factor = 1
            img_flip = img_meta.get('flip', False)
            if 'img_crop_offset' in img_meta:
                img_crop_offset = points.new_tensor(
                    img_meta['img_crop_offset'])
            else:
                img_crop_offset = 0

            volume = point_sample(
                img_meta,
                img_features=feature[None, ...],
                points=points,
                lidar2img_rt=points.new_tensor(img_meta['lidar2img']),
                img_scale_factor=img_scale_factor,
                img_crop_offset=img_crop_offset,
                img_flip=img_flip,
                img_pad_shape=img.shape[-2:],
                img_shape=img_meta['img_shape'][:2],
                aligned=False)
            # Reshape to (*reversed n_voxels, C), then reverse the axis order
            # so channels come first, followed by the n_voxels axes.
            volumes.append(
                volume.reshape(voxel_shape + [-1]).permute(3, 2, 1, 0))

        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x

    def forward_train(self, img, img_metas, gt_bboxes_3d, gt_labels_3d,
                      **kwargs):
        """Forward of training.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.

        Returns:
            dict[str, torch.Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img, img_metas)
        x = self.bbox_head(x)
        losses = self.bbox_head.loss(*x, gt_bboxes_3d, gt_labels_3d, img_metas)
        return losses

    def forward_test(self, img, img_metas, **kwargs):
        """Forward of testing.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.

        Returns:
            list[dict]: Predicted 3d boxes.
        """
        # not supporting aug_test for now
        return self.simple_test(img, img_metas)

    def simple_test(self, img, img_metas):
        """Test without augmentations.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.

        Returns:
            list[dict]: Predicted 3d boxes.
        """
        x = self.extract_feat(img, img_metas)
        x = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(*x, img_metas)
        # Each entry of bbox_list is a (bboxes, scores, labels) triple.
        bbox_results = [
            bbox3d2result(det_bboxes, det_scores, det_labels)
            for det_bboxes, det_scores, det_labels in bbox_list
        ]
        return bbox_results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentations.

        Args:
            imgs (list[torch.Tensor]): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.

        Returns:
            list[dict]: Predicted 3d boxes.

        Raises:
            NotImplementedError: Always; augmented testing is not supported.
        """
        raise NotImplementedError