Source code for mmdet3d.core.post_processing.merge_augs
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from ..bbox import bbox3d2result, bbox3d_mapping_back, xywhr2xyxyr
[docs]def merge_aug_bboxes_3d(aug_results, img_metas, test_cfg):
"""Merge augmented detection 3D bboxes and scores.
Args:
aug_results (list[dict]): The dict of detection results.
The dict contains the following keys
- boxes_3d (:obj:`BaseInstance3DBoxes`): Detection bbox.
- scores_3d (torch.Tensor): Detection scores.
- labels_3d (torch.Tensor): Predicted box labels.
img_metas (list[dict]): Meta information of each sample.
test_cfg (dict): Test config.
Returns:
dict: Bounding boxes results in cpu mode, containing merged results.
- boxes_3d (:obj:`BaseInstance3DBoxes`): Merged detection bbox.
- scores_3d (torch.Tensor): Merged detection scores.
- labels_3d (torch.Tensor): Merged predicted box labels.
"""
assert len(aug_results) == len(img_metas), \
'"aug_results" should have the same length as "img_metas", got len(' \
f'aug_results)={len(aug_results)} and len(img_metas)={len(img_metas)}'
recovered_bboxes = []
recovered_scores = []
recovered_labels = []
for bboxes, img_info in zip(aug_results, img_metas):
scale_factor = img_info[0]['pcd_scale_factor']
pcd_horizontal_flip = img_info[0]['pcd_horizontal_flip']
pcd_vertical_flip = img_info[0]['pcd_vertical_flip']
recovered_scores.append(bboxes['scores_3d'])
recovered_labels.append(bboxes['labels_3d'])
bboxes = bbox3d_mapping_back(bboxes['boxes_3d'], scale_factor,
pcd_horizontal_flip, pcd_vertical_flip)
recovered_bboxes.append(bboxes)
aug_bboxes = recovered_bboxes[0].cat(recovered_bboxes)
aug_bboxes_for_nms = xywhr2xyxyr(aug_bboxes.bev)
aug_scores = torch.cat(recovered_scores, dim=0)
aug_labels = torch.cat(recovered_labels, dim=0)
# TODO: use a more elegent way to deal with nms
if test_cfg.use_rotate_nms:
nms_func = nms_bev
else:
nms_func = nms_normal_bev
merged_bboxes = []
merged_scores = []
merged_labels = []
# Apply multi-class nms when merge bboxes
if len(aug_labels) == 0:
return bbox3d2result(aug_bboxes, aug_scores, aug_labels)
for class_id in range(torch.max(aug_labels).item() + 1):
class_inds = (aug_labels == class_id)
bboxes_i = aug_bboxes[class_inds]
bboxes_nms_i = aug_bboxes_for_nms[class_inds, :]
scores_i = aug_scores[class_inds]
labels_i = aug_labels[class_inds]
if len(bboxes_nms_i) == 0:
continue
selected = nms_func(bboxes_nms_i, scores_i, test_cfg.nms_thr)
merged_bboxes.append(bboxes_i[selected, :])
merged_scores.append(scores_i[selected])
merged_labels.append(labels_i[selected])
merged_bboxes = merged_bboxes[0].cat(merged_bboxes)
merged_scores = torch.cat(merged_scores, dim=0)
merged_labels = torch.cat(merged_labels, dim=0)
_, order = merged_scores.sort(0, descending=True)
num = min(test_cfg.max_num, len(aug_bboxes))
order = order[:num]
merged_bboxes = merged_bboxes[order]
merged_scores = merged_scores[order]
merged_labels = merged_labels[order]
return bbox3d2result(merged_bboxes, merged_scores, merged_labels)