import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
from .. import builder
from .single_stage import SingleStage3DDetector
[docs]@DETECTORS.register_module()
class VoxelNet(SingleStage3DDetector):
r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection."""
def __init__(self,
voxel_layer,
voxel_encoder,
middle_encoder,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
init_cfg=None,
pretrained=None):
super(VoxelNet, self).__init__(
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
pretrained=pretrained)
self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder)
[docs] @torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply hard voxelization to points."""
voxels, coors, num_points = [], [], []
for res in points:
res_voxels, res_coors, res_num_points = self.voxel_layer(res)
voxels.append(res_voxels)
coors.append(res_coors)
num_points.append(res_num_points)
voxels = torch.cat(voxels, dim=0)
num_points = torch.cat(num_points, dim=0)
coors_batch = []
for i, coor in enumerate(coors):
coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
coors_batch.append(coor_pad)
coors_batch = torch.cat(coors_batch, dim=0)
return voxels, num_points, coors_batch
[docs] def forward_train(self,
points,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
gt_bboxes_ignore=None):
"""Training forward function.
Args:
points (list[torch.Tensor]): Point cloud of each sample.
img_metas (list[dict]): Meta information of each sample
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes for each sample.
gt_labels_3d (list[torch.Tensor]): Ground truth labels for
boxes of each sampole
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
boxes to be ignored. Defaults to None.
Returns:
dict: Losses of each branch.
"""
x = self.extract_feat(points, img_metas)
outs = self.bbox_head(x)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
losses = self.bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
[docs] def simple_test(self, points, img_metas, imgs=None, rescale=False):
"""Test function without augmentaiton."""
x = self.extract_feat(points, img_metas)
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
[docs] def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test function with augmentaiton."""
feats = self.extract_feats(points, img_metas)
# only support aug_test for one sample
aug_bboxes = []
for x, img_meta in zip(feats, img_metas):
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return [merged_bboxes]