Source code for mmdet3d.models.detectors.imvoxelnet
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet3d.core import bbox3d2result, build_prior_generator
from mmdet3d.models.fusion_layers.point_fusion import point_sample
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from mmdet.models.detectors import BaseDetector
@DETECTORS.register_module()
class ImVoxelNet(BaseDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    The detector runs a 2D backbone and neck on the input images, samples
    the resulting image features at the 3D points produced by the anchor
    generator (projected with each image's ``lidar2img`` matrix), feeds the
    stacked feature volumes through a 3D neck and finally through the bbox
    head.

    Args:
        backbone (dict): Config passed to ``build_backbone``.
        neck (dict): Config passed to ``build_neck`` for the 2D neck.
        neck_3d (dict): Config passed to ``build_neck`` for the 3D neck.
        bbox_head (dict): Config passed to ``build_head``. ``train_cfg``
            and ``test_cfg`` are injected into a copy of it, the dict
            given by the caller is left untouched.
        n_voxels (list[int] | tuple[int]): Size of the voxel grid. It is
            used in reversed order both as the feature map size given to
            the anchor generator and to reshape the sampled features.
        anchor_generator (dict): Config passed to
            ``build_prior_generator``.
        train_cfg (dict, optional): Training config, forwarded to the
            bbox head. Defaults to None.
        test_cfg (dict, optional): Testing config, forwarded to the bbox
            head. Defaults to None.
        pretrained (str, optional): Accepted but not read by this class.
            Defaults to None.
        init_cfg (dict, optional): Initialization config passed to the
            base class. Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck,
                 neck_3d,
                 bbox_head,
                 n_voxels,
                 anchor_generator,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        self.neck_3d = build_neck(neck_3d)
        # Work on a shallow copy so that the config object owned by the
        # caller does not silently gain ``train_cfg`` / ``test_cfg`` keys.
        bbox_head = bbox_head.copy()
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)
        self.n_voxels = n_voxels
        self.anchor_generator = build_prior_generator(anchor_generator)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img, img_metas):
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas. Every entry must provide
                ``lidar2img`` and ``img_shape``; ``scale_factor``,
                ``flip`` and ``img_crop_offset`` are optional.

        Returns:
            torch.Tensor: Output of ``neck_3d``, of shape
                (N, C_out, N_x, N_y, N_z).
        """
        x = self.backbone(img)
        # Only the first output level of the 2D neck is used.
        x = self.neck(x)[0]

        # Reversed grid size, normalized to a list so that ``n_voxels``
        # may be configured either as a list or as a tuple.
        grid_size = list(self.n_voxels[::-1])
        # Keep only the first three columns of every generated anchor,
        # they are the 3D points at which image features are sampled.
        points = self.anchor_generator.grid_anchors(
            [grid_size], device=img.device)[0][:, :3]
        # The padded input size is identical for every image of the batch.
        pad_shape = img.shape[-2:]

        volumes = []
        for feature, img_meta in zip(x, img_metas):
            # Fall back to neutral values when an augmentation was not
            # recorded in the meta information.
            img_scale_factor = (
                points.new_tensor(img_meta['scale_factor'][:2])
                if 'scale_factor' in img_meta else 1)
            img_flip = img_meta.get('flip', False)
            img_crop_offset = (
                points.new_tensor(img_meta['img_crop_offset'])
                if 'img_crop_offset' in img_meta else 0)
            volume = point_sample(
                img_meta,
                img_features=feature[None, ...],
                points=points,
                proj_mat=points.new_tensor(img_meta['lidar2img']),
                coord_type='LIDAR',
                img_scale_factor=img_scale_factor,
                img_crop_offset=img_crop_offset,
                img_flip=img_flip,
                img_pad_shape=pad_shape,
                img_shape=img_meta['img_shape'][:2],
                aligned=False)
            # (num_points, C) -> (*reversed n_voxels, C) -> (C, *n_voxels)
            volumes.append(
                volume.reshape(grid_size + [-1]).permute(3, 2, 1, 0))
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x

    def forward_train(self, img, img_metas, gt_bboxes_3d, gt_labels_3d,
                      **kwargs):
        """Forward of training.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each
                batch.
            gt_labels_3d (list[torch.Tensor]): gt class labels of each
                batch.
            **kwargs: Extra keyword arguments, ignored here.

        Returns:
            dict[str, torch.Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img, img_metas)
        x = self.bbox_head(x)
        losses = self.bbox_head.loss(*x, gt_bboxes_3d, gt_labels_3d,
                                     img_metas)
        return losses

    def forward_test(self, img, img_metas, **kwargs):
        """Forward of testing.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.
            **kwargs: Extra keyword arguments, ignored here.

        Returns:
            list[dict]: Predicted 3d boxes.
        """
        # not supporting aug_test for now
        return self.simple_test(img, img_metas)

    def simple_test(self, img, img_metas):
        """Test without augmentations.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.

        Returns:
            list[dict]: Predicted 3d boxes.
        """
        x = self.extract_feat(img, img_metas)
        x = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(*x, img_metas)
        # Convert every (bboxes, scores, labels) triple to the result
        # format produced by ``bbox3d2result``.
        bbox_results = [
            bbox3d2result(det_bboxes, det_scores, det_labels)
            for det_bboxes, det_scores, det_labels in bbox_list
        ]
        return bbox_results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentations.

        Args:
            imgs (list[torch.Tensor]): Input images of shape
                (N, C_in, H, W).
            img_metas (list): Image metas.

        Raises:
            NotImplementedError: Always, test-time augmentation is not
                implemented.
        """
        raise NotImplementedError