import numpy as np
import torch
from mmcv.runner import BaseModule, force_fp32
from torch.nn import functional as F
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module, furthest_point_sample
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
from .base_conv_bbox_head import BaseConvBboxHead
@HEADS.register_module()
class VoteHead(BaseModule):
r"""Bbox head of `Votenet <https://arxiv.org/abs/1904.09664>`_.
Args:
num_classes (int): The number of class.
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
decoding boxes.
train_cfg (dict): Config for training.
test_cfg (dict): Config for testing.
vote_module_cfg (dict): Config of VoteModule for point-wise votes.
vote_aggregation_cfg (dict): Config of vote aggregation layer.
        pred_layer_cfg (dict): Config of classification and regression
prediction layers.
conv_cfg (dict): Config of convolution in prediction layer.
norm_cfg (dict): Config of BN in prediction layer.
objectness_loss (dict): Config of objectness loss.
center_loss (dict): Config of center loss.
dir_class_loss (dict): Config of direction classification loss.
dir_res_loss (dict): Config of direction residual regression loss.
size_class_loss (dict): Config of size classification loss.
size_res_loss (dict): Config of size residual regression loss.
semantic_loss (dict): Config of point-wise semantic segmentation loss.
"""
def __init__(self,
num_classes,
bbox_coder,
train_cfg=None,
test_cfg=None,
vote_module_cfg=None,
vote_aggregation_cfg=None,
pred_layer_cfg=None,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=None,
center_loss=None,
dir_class_loss=None,
dir_res_loss=None,
size_class_loss=None,
size_res_loss=None,
semantic_loss=None,
iou_loss=None,
init_cfg=None):
super(VoteHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.gt_per_seed = vote_module_cfg['gt_per_seed']
self.num_proposal = vote_aggregation_cfg['num_point']
self.objectness_loss = build_loss(objectness_loss)
self.center_loss = build_loss(center_loss)
self.dir_res_loss = build_loss(dir_res_loss)
self.dir_class_loss = build_loss(dir_class_loss)
self.size_res_loss = build_loss(size_res_loss)
if size_class_loss is not None:
self.size_class_loss = build_loss(size_class_loss)
if semantic_loss is not None:
self.semantic_loss = build_loss(semantic_loss)
if iou_loss is not None:
self.iou_loss = build_loss(iou_loss)
else:
self.iou_loss = None
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_sizes = self.bbox_coder.num_sizes
self.num_dir_bins = self.bbox_coder.num_dir_bins
self.vote_module = VoteModule(**vote_module_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
self.fp16_enabled = False
# Bbox classification and regression
self.conv_pred = BaseConvBboxHead(
**pred_layer_cfg,
num_cls_out_channels=self._get_cls_out_channels(),
num_reg_out_channels=self._get_reg_out_channels())
def _get_cls_out_channels(self):
"""Return the channel number of classification outputs."""
# Class numbers (k) + objectness (2)
return self.num_classes + 2
    def _get_reg_out_channels(self):
        """Return the channel number of regression outputs."""
        # center residual (3),
        # heading class+residual (num_dir_bins*2),
        # size class+residual (num_sizes*4)
        # NOTE: the objectness scores (2) are part of the classification
        # outputs (see `_get_cls_out_channels`), not the regression outputs.
        return 3 + self.num_dir_bins * 2 + self.num_sizes * 4
def _extract_input(self, feat_dict):
"""Extract inputs from features dictionary.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
torch.Tensor: Coordinates of input points.
torch.Tensor: Features of input points.
torch.Tensor: Indices of input points.
"""
# for imvotenet
if 'seed_points' in feat_dict and \
'seed_features' in feat_dict and \
'seed_indices' in feat_dict:
seed_points = feat_dict['seed_points']
seed_features = feat_dict['seed_features']
seed_indices = feat_dict['seed_indices']
# for votenet
else:
seed_points = feat_dict['fp_xyz'][-1]
seed_features = feat_dict['fp_features'][-1]
seed_indices = feat_dict['fp_indices'][-1]
return seed_points, seed_features, seed_indices
[docs] def forward(self, feat_dict, sample_mod):
"""Forward pass.
Note:
The forward of VoteHead is devided into 4 steps:
1. Generate vote_points from seed_points.
2. Aggregate vote_points.
3. Predict bbox and score.
4. Decode predictions.
Args:
feat_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for vote aggregation layer.
valid modes are "vote", "seed", "random" and "spec".
Returns:
dict: Predictions of vote head.
"""
assert sample_mod in ['vote', 'seed', 'random', 'spec']
seed_points, seed_features, seed_indices = self._extract_input(
feat_dict)
# 1. generate vote_points from seed_points
vote_points, vote_features, vote_offset = self.vote_module(
seed_points, seed_features)
results = dict(
seed_points=seed_points,
seed_indices=seed_indices,
vote_points=vote_points,
vote_features=vote_features,
vote_offset=vote_offset)
# 2. aggregate vote_points
if sample_mod == 'vote':
# use fps in vote_aggregation
aggregation_inputs = dict(
points_xyz=vote_points, features=vote_features)
elif sample_mod == 'seed':
# FPS on seed and choose the votes corresponding to the seeds
sample_indices = furthest_point_sample(seed_points,
self.num_proposal)
aggregation_inputs = dict(
points_xyz=vote_points,
features=vote_features,
indices=sample_indices)
elif sample_mod == 'random':
# Random sampling from the votes
batch_size, num_seed = seed_points.shape[:2]
sample_indices = seed_points.new_tensor(
torch.randint(0, num_seed, (batch_size, self.num_proposal)),
dtype=torch.int32)
aggregation_inputs = dict(
points_xyz=vote_points,
features=vote_features,
indices=sample_indices)
elif sample_mod == 'spec':
# Specify the new center in vote_aggregation
aggregation_inputs = dict(
points_xyz=seed_points,
features=seed_features,
target_xyz=vote_points)
else:
raise NotImplementedError(
f'Sample mode {sample_mod} is not supported!')
vote_aggregation_ret = self.vote_aggregation(**aggregation_inputs)
aggregated_points, features, aggregated_indices = vote_aggregation_ret
results['aggregated_points'] = aggregated_points
results['aggregated_features'] = features
results['aggregated_indices'] = aggregated_indices
# 3. predict bbox and score
cls_predictions, reg_predictions = self.conv_pred(features)
# 4. decode predictions
decode_res = self.bbox_coder.split_pred(cls_predictions,
reg_predictions,
aggregated_points)
results.update(decode_res)
return results
    @force_fp32(apply_to=('bbox_preds', ))
    def loss(self,
             bbox_preds,
             points,
             gt_bboxes_3d,
             gt_labels_3d,
             pts_semantic_mask=None,
             pts_instance_mask=None,
             img_metas=None,
             gt_bboxes_ignore=None,
             ret_target=False):
        """Compute loss.

        Args:
            bbox_preds (dict): Predictions from forward of vote head.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise
                instance mask.
            img_metas (list[dict]): Contain pcd and img's meta info.
            gt_bboxes_ignore (None | list[torch.Tensor]): Specify
                which bounding boxes to ignore (unused here).
            ret_target (bool): Return targets or not.

        Returns:
            dict: Losses of Votenet.
        """
        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
                                   pts_semantic_mask, pts_instance_mask,
                                   bbox_preds)
        (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
         dir_class_targets, dir_res_targets, center_targets,
         assigned_center_targets, mask_targets, valid_gt_masks,
         objectness_targets, objectness_weights, box_loss_weights,
         valid_gt_weights) = targets

        # calculate vote loss
        vote_loss = self.vote_module.get_loss(bbox_preds['seed_points'],
                                              bbox_preds['vote_points'],
                                              bbox_preds['seed_indices'],
                                              vote_target_masks, vote_targets)

        # calculate objectness loss
        objectness_loss = self.objectness_loss(
            bbox_preds['obj_scores'].transpose(2, 1),
            objectness_targets,
            weight=objectness_weights)

        # calculate center loss: bidirectional set distance between
        # predicted centers and (padded) ground-truth centers
        source2target_loss, target2source_loss = self.center_loss(
            bbox_preds['center'],
            center_targets,
            src_weight=box_loss_weights,
            dst_weight=valid_gt_weights)
        center_loss = source2target_loss + target2source_loss

        # calculate direction class loss
        dir_class_loss = self.dir_class_loss(
            bbox_preds['dir_class'].transpose(2, 1),
            dir_class_targets,
            weight=box_loss_weights)

        # calculate direction residual loss: supervise only the residual
        # predicted for the ground-truth direction bin (one-hot masking)
        batch_size, proposal_num = size_class_targets.shape[:2]
        heading_label_one_hot = vote_targets.new_zeros(
            (batch_size, proposal_num, self.num_dir_bins))
        heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
        dir_res_norm = torch.sum(
            bbox_preds['dir_res_norm'] * heading_label_one_hot, -1)
        dir_res_loss = self.dir_res_loss(
            dir_res_norm, dir_res_targets, weight=box_loss_weights)

        # calculate size class loss
        size_class_loss = self.size_class_loss(
            bbox_preds['size_class'].transpose(2, 1),
            size_class_targets,
            weight=box_loss_weights)

        # calculate size residual loss: pick the residual predicted for
        # the ground-truth size class of each proposal
        one_hot_size_targets = vote_targets.new_zeros(
            (batch_size, proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(2, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets_expand = one_hot_size_targets.unsqueeze(
            -1).repeat(1, 1, 1, 3).contiguous()
        size_residual_norm = torch.sum(
            bbox_preds['size_res_norm'] * one_hot_size_targets_expand, 2)
        box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
            1, 1, 3)
        size_res_loss = self.size_res_loss(
            size_residual_norm,
            size_res_targets,
            weight=box_loss_weights_expand)

        # calculate semantic loss
        semantic_loss = self.semantic_loss(
            bbox_preds['sem_scores'].transpose(2, 1),
            mask_targets,
            weight=box_loss_weights)

        losses = dict(
            vote_loss=vote_loss,
            objectness_loss=objectness_loss,
            semantic_loss=semantic_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_class_loss=size_class_loss,
            size_res_loss=size_res_loss)

        if self.iou_loss:
            # optional corner-based IoU-style loss on decoded boxes
            corners_pred = self.bbox_coder.decode_corners(
                bbox_preds['center'], size_residual_norm,
                one_hot_size_targets_expand)
            corners_target = self.bbox_coder.decode_corners(
                assigned_center_targets, size_res_targets,
                one_hot_size_targets_expand)
            iou_loss = self.iou_loss(
                corners_pred, corners_target, weight=box_loss_weights)
            losses['iou_loss'] = iou_loss

        if ret_target:
            losses['targets'] = targets

        return losses
    def get_targets(self,
                    points,
                    gt_bboxes_3d,
                    gt_labels_3d,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    bbox_preds=None):
        """Generate targets of vote head.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise instance
                label of each batch.
            bbox_preds (dict): Bounding box predictions of vote head.

        Returns:
            tuple[torch.Tensor]: Targets of vote head.
        """
        # find empty example: replace samples without any GT by one fake
        # box/label so batched computation stays well-formed; fake GTs are
        # masked out afterwards through ``valid_gt_masks``
        valid_gt_masks = list()
        gt_num = list()
        for index in range(len(gt_labels_3d)):
            if len(gt_labels_3d[index]) == 0:
                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
                    1, gt_bboxes_3d[index].tensor.shape[-1])
                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
                valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
                gt_num.append(1)
            else:
                valid_gt_masks.append(gt_labels_3d[index].new_ones(
                    gt_labels_3d[index].shape))
                gt_num.append(gt_labels_3d[index].shape[0])
        max_gt_num = max(gt_num)

        if pts_semantic_mask is None:
            pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
            pts_instance_mask = [None for i in range(len(gt_labels_3d))]

        aggregated_points = [
            bbox_preds['aggregated_points'][i]
            for i in range(len(gt_labels_3d))
        ]

        (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
         dir_class_targets, dir_res_targets, center_targets,
         assigned_center_targets, mask_targets, objectness_targets,
         objectness_masks) = multi_apply(self.get_targets_single, points,
                                         gt_bboxes_3d, gt_labels_3d,
                                         pts_semantic_mask, pts_instance_mask,
                                         aggregated_points)

        # pad targets as original code of votenet: every sample is padded
        # to ``max_gt_num`` GTs so the per-sample targets can be stacked
        for index in range(len(gt_labels_3d)):
            pad_num = max_gt_num - gt_labels_3d[index].shape[0]
            center_targets[index] = F.pad(center_targets[index],
                                          (0, 0, 0, pad_num))
            valid_gt_masks[index] = F.pad(valid_gt_masks[index], (0, pad_num))

        vote_targets = torch.stack(vote_targets)
        vote_target_masks = torch.stack(vote_target_masks)
        center_targets = torch.stack(center_targets)
        valid_gt_masks = torch.stack(valid_gt_masks)

        assigned_center_targets = torch.stack(assigned_center_targets)
        objectness_targets = torch.stack(objectness_targets)
        # normalize the weights so each loss averages over its valid entries
        # (the 1e-6 guards against division by zero)
        objectness_weights = torch.stack(objectness_masks)
        objectness_weights /= (torch.sum(objectness_weights) + 1e-6)
        box_loss_weights = objectness_targets.float() / (
            torch.sum(objectness_targets).float() + 1e-6)
        valid_gt_weights = valid_gt_masks.float() / (
            torch.sum(valid_gt_masks.float()) + 1e-6)
        dir_class_targets = torch.stack(dir_class_targets)
        dir_res_targets = torch.stack(dir_res_targets)
        size_class_targets = torch.stack(size_class_targets)
        size_res_targets = torch.stack(size_res_targets)
        mask_targets = torch.stack(mask_targets)

        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets, dir_class_targets, dir_res_targets,
                center_targets, assigned_center_targets, mask_targets,
                valid_gt_masks, objectness_targets, objectness_weights,
                box_loss_weights, valid_gt_weights)
    def get_targets_single(self,
                           points,
                           gt_bboxes_3d,
                           gt_labels_3d,
                           pts_semantic_mask=None,
                           pts_instance_mask=None,
                           aggregated_points=None):
        """Generate targets of vote head for single batch.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch.
            pts_semantic_mask (None | torch.Tensor): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | torch.Tensor): Point-wise instance
                label of each batch.
            aggregated_points (torch.Tensor): Aggregated points from
                vote aggregation layer.

        Returns:
            tuple[torch.Tensor]: Targets of vote head.
        """
        assert self.bbox_coder.with_rot or pts_semantic_mask is not None

        gt_bboxes_3d = gt_bboxes_3d.to(points.device)

        # generate votes target
        num_points = points.shape[0]
        if self.bbox_coder.with_rot:
            # rotation-aware branch: every point inside a GT box votes for
            # the offset to that box's gravity center; up to ``gt_per_seed``
            # vote slots are stored per point
            vote_targets = points.new_zeros([num_points, 3 * self.gt_per_seed])
            vote_target_masks = points.new_zeros([num_points],
                                                 dtype=torch.long)
            # per-point count of how many boxes have claimed it so far
            vote_target_idx = points.new_zeros([num_points], dtype=torch.long)
            box_indices_all = gt_bboxes_3d.points_in_boxes(points)
            for i in range(gt_labels_3d.shape[0]):
                box_indices = box_indices_all[:, i]
                indices = torch.nonzero(
                    box_indices, as_tuple=False).squeeze(-1)
                selected_points = points[indices]
                vote_target_masks[indices] = 1
                vote_targets_tmp = vote_targets[indices]
                votes = gt_bboxes_3d.gravity_center[i].unsqueeze(
                    0) - selected_points[:, :3]

                for j in range(self.gt_per_seed):
                    # fill the j-th vote slot of points whose next free
                    # slot is j
                    column_indices = torch.nonzero(
                        vote_target_idx[indices] == j,
                        as_tuple=False).squeeze(-1)
                    vote_targets_tmp[column_indices,
                                     int(j * 3):int(j * 3 +
                                                    3)] = votes[column_indices]
                    if j == 0:
                        # the first vote is replicated into all slots so
                        # unused slots still carry a valid target
                        vote_targets_tmp[column_indices] = votes[
                            column_indices].repeat(1, self.gt_per_seed)

                vote_targets[indices] = vote_targets_tmp
                vote_target_idx[indices] = torch.clamp(
                    vote_target_idx[indices] + 1, max=2)
        elif pts_semantic_mask is not None:
            # mask-based branch: each point of a foreground instance votes
            # for the offset to the instance's axis-aligned bbox center
            vote_targets = points.new_zeros([num_points, 3])
            vote_target_masks = points.new_zeros([num_points],
                                                 dtype=torch.long)

            for i in torch.unique(pts_instance_mask):
                indices = torch.nonzero(
                    pts_instance_mask == i, as_tuple=False).squeeze(-1)
                # only foreground instances (semantic label < num_classes)
                if pts_semantic_mask[indices[0]] < self.num_classes:
                    selected_points = points[indices, :3]
                    center = 0.5 * (
                        selected_points.min(0)[0] + selected_points.max(0)[0])
                    vote_targets[indices, :] = center - selected_points
                    vote_target_masks[indices] = 1
            vote_targets = vote_targets.repeat((1, self.gt_per_seed))
        else:
            raise NotImplementedError

        (center_targets, size_class_targets, size_res_targets,
         dir_class_targets,
         dir_res_targets) = self.bbox_coder.encode(gt_bboxes_3d, gt_labels_3d)

        # assign each aggregated proposal to its nearest GT center
        proposal_num = aggregated_points.shape[0]
        distance1, _, assignment, _ = chamfer_distance(
            aggregated_points.unsqueeze(0),
            center_targets.unsqueeze(0),
            reduction='none')
        assignment = assignment.squeeze(0)
        euclidean_distance1 = torch.sqrt(distance1.squeeze(0) + 1e-6)

        # proposals closer than pos_distance_thr are positives; those
        # between the pos and neg thresholds are ignored for objectness
        # (their mask entry stays 0)
        objectness_targets = points.new_zeros((proposal_num), dtype=torch.long)
        objectness_targets[
            euclidean_distance1 < self.train_cfg['pos_distance_thr']] = 1

        objectness_masks = points.new_zeros((proposal_num))
        objectness_masks[
            euclidean_distance1 < self.train_cfg['pos_distance_thr']] = 1.0
        objectness_masks[
            euclidean_distance1 > self.train_cfg['neg_distance_thr']] = 1.0

        dir_class_targets = dir_class_targets[assignment]
        dir_res_targets = dir_res_targets[assignment]
        # normalize the angle residual by the direction-bin width
        dir_res_targets /= (np.pi / self.num_dir_bins)
        size_class_targets = size_class_targets[assignment]
        size_res_targets = size_res_targets[assignment]

        one_hot_size_targets = gt_bboxes_3d.tensor.new_zeros(
            (proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(1, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets = one_hot_size_targets.unsqueeze(-1).repeat(
            1, 1, 3)
        # normalize the size residual by the mean size of the assigned class
        mean_sizes = size_res_targets.new_tensor(
            self.bbox_coder.mean_sizes).unsqueeze(0)
        pos_mean_sizes = torch.sum(one_hot_size_targets * mean_sizes, 1)
        size_res_targets /= pos_mean_sizes

        mask_targets = gt_labels_3d[assignment]
        assigned_center_targets = center_targets[assignment]

        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets, dir_class_targets,
                dir_res_targets, center_targets, assigned_center_targets,
                mask_targets.long(), objectness_targets, objectness_masks)
[docs] def get_bboxes(self,
points,
bbox_preds,
input_metas,
rescale=False,
use_nms=True):
"""Generate bboxes from vote head predictions.
Args:
points (torch.Tensor): Input points.
bbox_preds (dict): Predictions from vote head.
input_metas (list[dict]): Point cloud and image's meta info.
rescale (bool): Whether to rescale bboxes.
use_nms (bool): Whether to apply NMS, skip nms postprocessing
while using vote head in rpn stage.
Returns:
list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
"""
# decode boxes
obj_scores = F.softmax(bbox_preds['obj_scores'], dim=-1)[..., -1]
sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
bbox3d = self.bbox_coder.decode(bbox_preds)
if use_nms:
batch_size = bbox3d.shape[0]
results = list()
for b in range(batch_size):
bbox_selected, score_selected, labels = \
self.multiclass_nms_single(obj_scores[b], sem_scores[b],
bbox3d[b], points[b, ..., :3],
input_metas[b])
bbox = input_metas[b]['box_type_3d'](
bbox_selected,
box_dim=bbox_selected.shape[-1],
with_yaw=self.bbox_coder.with_rot)
results.append((bbox, score_selected, labels))
return results
else:
return bbox3d
    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                              input_meta):
        """Multi-class nms in single batch.

        Args:
            obj_scores (torch.Tensor): Objectness score of bounding boxes.
            sem_scores (torch.Tensor): Semantic class score of bounding boxes.
            bbox (torch.Tensor): Predicted bounding boxes.
            points (torch.Tensor): Input points.
            input_meta (dict): Point cloud and image's meta info.

        Returns:
            tuple[torch.Tensor]: Bounding boxes, scores and labels.
        """
        bbox = input_meta['box_type_3d'](
            bbox,
            box_dim=bbox.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))
        box_indices = bbox.points_in_boxes(points)

        # build the axis-aligned (min/max corner) representation required
        # by aligned_3d_nms
        corner3d = bbox.corners
        minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6)))
        minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
        minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]

        # a box is considered non-empty only if it contains more than
        # 5 input points
        nonempty_box_mask = box_indices.T.sum(1) > 5

        bbox_classes = torch.argmax(sem_scores, -1)
        nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask],
                                      obj_scores[nonempty_box_mask],
                                      bbox_classes[nonempty_box_mask],
                                      self.test_cfg.nms_thr)

        # filter empty boxes and boxes with low score: map the NMS-kept
        # indices (defined on non-empty boxes) back to the full box set
        scores_mask = (obj_scores > self.test_cfg.score_thr)
        nonempty_box_inds = torch.nonzero(
            nonempty_box_mask, as_tuple=False).flatten()
        nonempty_mask = torch.zeros_like(bbox_classes).scatter(
            0, nonempty_box_inds[nms_selected], 1)
        selected = (nonempty_mask.bool() & scores_mask.bool())

        if self.test_cfg.per_class_proposal:
            # replicate each kept box once per class, scored by
            # objectness * per-class semantic probability
            bbox_selected, score_selected, labels = [], [], []
            for k in range(sem_scores.shape[-1]):
                bbox_selected.append(bbox[selected].tensor)
                score_selected.append(obj_scores[selected] *
                                      sem_scores[selected][:, k])
                labels.append(
                    torch.zeros_like(bbox_classes[selected]).fill_(k))
            bbox_selected = torch.cat(bbox_selected, 0)
            score_selected = torch.cat(score_selected, 0)
            labels = torch.cat(labels, 0)
        else:
            bbox_selected = bbox[selected].tensor
            score_selected = obj_scores[selected]
            labels = bbox_classes[selected]

        return bbox_selected, score_selected, labels