import numpy as np
import torch
from mmcv.ops.nms import batched_nms
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
LiDARInstance3DBoxes,
rotation_3d_in_axis)
from mmdet3d.models.builder import build_loss
from mmdet.core import multi_apply
from mmdet.models import HEADS
from .vote_head import VoteHead
@HEADS.register_module()
class SSD3DHead(VoteHead):
    r"""Bbox head of `3DSSD <https://arxiv.org/abs/2002.10187>`_.

    Args:
        num_classes (int): The number of class.
        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
            decoding boxes.
        in_channels (int): The number of input feature channel.
        train_cfg (dict): Config for training.
        test_cfg (dict): Config for testing.
        vote_module_cfg (dict): Config of VoteModule for point-wise votes.
        vote_aggregation_cfg (dict): Config of vote aggregation layer.
        pred_layer_cfg (dict): Config of classification and regression
            prediction layers.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        act_cfg (dict): Config of activation in prediction layer.
        objectness_loss (dict): Config of objectness loss.
        center_loss (dict): Config of center loss.
        dir_class_loss (dict): Config of direction classification loss.
        dir_res_loss (dict): Config of direction residual regression loss.
        size_res_loss (dict): Config of size residual regression loss.
        corner_loss (dict): Config of bbox corners regression loss.
        vote_loss (dict): Config of candidate points regression loss.
    """

    def __init__(self,
                 num_classes,
                 bbox_coder,
                 in_channels=256,
                 train_cfg=None,
                 test_cfg=None,
                 vote_module_cfg=None,
                 vote_aggregation_cfg=None,
                 pred_layer_cfg=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_res_loss=None,
                 corner_loss=None,
                 vote_loss=None,
                 init_cfg=None):
        # 3DSSD regresses box sizes directly and has no separate semantic
        # head, so the VoteHead size-classification and semantic losses are
        # explicitly disabled here.
        super(SSD3DHead, self).__init__(
            num_classes,
            bbox_coder,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            vote_module_cfg=vote_module_cfg,
            vote_aggregation_cfg=vote_aggregation_cfg,
            pred_layer_cfg=pred_layer_cfg,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            objectness_loss=objectness_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_class_loss=None,
            size_res_loss=size_res_loss,
            semantic_loss=None,
            init_cfg=init_cfg)
        self.corner_loss = build_loss(corner_loss)
        self.vote_loss = build_loss(vote_loss)
        # Number of candidate points produced by the vote module; used to
        # slice the seed points when building vote targets.
        self.num_candidates = vote_module_cfg['num_points']
    def _get_cls_out_channels(self):
        """Return the channel number of classification outputs."""
        # One channel per class: 3DSSD predicts class-wise centerness scores
        # rather than VoteHead's "k classes + 1 objectness" layout, so no
        # extra objectness channel is appended.
        return self.num_classes
    def _get_reg_out_channels(self):
        """Return the channel number of regression outputs."""
        # center residual (3) + size regression (3)
        # + heading class and residual (num_dir_bins * 2)
        return 3 + 3 + self.num_dir_bins * 2
def _extract_input(self, feat_dict):
"""Extract inputs from features dictionary.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
torch.Tensor: Coordinates of input points.
torch.Tensor: Features of input points.
torch.Tensor: Indices of input points.
"""
seed_points = feat_dict['sa_xyz'][-1]
seed_features = feat_dict['sa_features'][-1]
seed_indices = feat_dict['sa_indices'][-1]
return seed_points, seed_features, seed_indices
[docs] @force_fp32(apply_to=('bbox_preds', ))
def loss(self,
bbox_preds,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
img_metas=None,
gt_bboxes_ignore=None):
"""Compute loss.
Args:
bbox_preds (dict): Predictions from forward of SSD3DHead.
points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
bboxes of each sample.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (None | list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (None | list[torch.Tensor]): Point-wise
instance mask.
img_metas (list[dict]): Contain pcd and img's meta info.
gt_bboxes_ignore (None | list[torch.Tensor]): Specify
which bounding.
Returns:
dict: Losses of 3DSSD.
"""
targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask,
bbox_preds)
(vote_targets, center_targets, size_res_targets, dir_class_targets,
dir_res_targets, mask_targets, centerness_targets, corner3d_targets,
vote_mask, positive_mask, negative_mask, centerness_weights,
box_loss_weights, heading_res_loss_weight) = targets
# calculate centerness loss
centerness_loss = self.objectness_loss(
bbox_preds['obj_scores'].transpose(2, 1),
centerness_targets,
weight=centerness_weights)
# calculate center loss
center_loss = self.center_loss(
bbox_preds['center_offset'],
center_targets,
weight=box_loss_weights.unsqueeze(-1))
# calculate direction class loss
dir_class_loss = self.dir_class_loss(
bbox_preds['dir_class'].transpose(1, 2),
dir_class_targets,
weight=box_loss_weights)
# calculate direction residual loss
dir_res_loss = self.dir_res_loss(
bbox_preds['dir_res_norm'],
dir_res_targets.unsqueeze(-1).repeat(1, 1, self.num_dir_bins),
weight=heading_res_loss_weight)
# calculate size residual loss
size_loss = self.size_res_loss(
bbox_preds['size'],
size_res_targets,
weight=box_loss_weights.unsqueeze(-1))
# calculate corner loss
one_hot_dir_class_targets = dir_class_targets.new_zeros(
bbox_preds['dir_class'].shape)
one_hot_dir_class_targets.scatter_(2, dir_class_targets.unsqueeze(-1),
1)
pred_bbox3d = self.bbox_coder.decode(
dict(
center=bbox_preds['center'],
dir_res=bbox_preds['dir_res'],
dir_class=one_hot_dir_class_targets,
size=bbox_preds['size']))
pred_bbox3d = pred_bbox3d.reshape(-1, pred_bbox3d.shape[-1])
pred_bbox3d = img_metas[0]['box_type_3d'](
pred_bbox3d.clone(),
box_dim=pred_bbox3d.shape[-1],
with_yaw=self.bbox_coder.with_rot,
origin=(0.5, 0.5, 0.5))
pred_corners3d = pred_bbox3d.corners.reshape(-1, 8, 3)
corner_loss = self.corner_loss(
pred_corners3d,
corner3d_targets.reshape(-1, 8, 3),
weight=box_loss_weights.view(-1, 1, 1))
# calculate vote loss
vote_loss = self.vote_loss(
bbox_preds['vote_offset'].transpose(1, 2),
vote_targets,
weight=vote_mask.unsqueeze(-1))
losses = dict(
centerness_loss=centerness_loss,
center_loss=center_loss,
dir_class_loss=dir_class_loss,
dir_res_loss=dir_res_loss,
size_res_loss=size_loss,
corner_loss=corner_loss,
vote_loss=vote_loss)
return losses
[docs] def get_targets(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
bbox_preds=None):
"""Generate targets of ssd3d head.
Args:
points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): Labels of each batch.
pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic
label of each batch.
pts_instance_mask (None | list[torch.Tensor]): Point-wise instance
label of each batch.
bbox_preds (torch.Tensor): Bounding box predictions of ssd3d head.
Returns:
tuple[torch.Tensor]: Targets of ssd3d head.
"""
# find empty example
for index in range(len(gt_labels_3d)):
if len(gt_labels_3d[index]) == 0:
fake_box = gt_bboxes_3d[index].tensor.new_zeros(
1, gt_bboxes_3d[index].tensor.shape[-1])
gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
if pts_semantic_mask is None:
pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
pts_instance_mask = [None for i in range(len(gt_labels_3d))]
aggregated_points = [
bbox_preds['aggregated_points'][i]
for i in range(len(gt_labels_3d))
]
seed_points = [
bbox_preds['seed_points'][i, :self.num_candidates].detach()
for i in range(len(gt_labels_3d))
]
(vote_targets, center_targets, size_res_targets, dir_class_targets,
dir_res_targets, mask_targets, centerness_targets, corner3d_targets,
vote_mask, positive_mask, negative_mask) = multi_apply(
self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, aggregated_points,
seed_points)
center_targets = torch.stack(center_targets)
positive_mask = torch.stack(positive_mask)
negative_mask = torch.stack(negative_mask)
dir_class_targets = torch.stack(dir_class_targets)
dir_res_targets = torch.stack(dir_res_targets)
size_res_targets = torch.stack(size_res_targets)
mask_targets = torch.stack(mask_targets)
centerness_targets = torch.stack(centerness_targets).detach()
corner3d_targets = torch.stack(corner3d_targets)
vote_targets = torch.stack(vote_targets)
vote_mask = torch.stack(vote_mask)
center_targets -= bbox_preds['aggregated_points']
centerness_weights = (positive_mask +
negative_mask).unsqueeze(-1).repeat(
1, 1, self.num_classes).float()
centerness_weights = centerness_weights / \
(centerness_weights.sum() + 1e-6)
vote_mask = vote_mask / (vote_mask.sum() + 1e-6)
box_loss_weights = positive_mask / (positive_mask.sum() + 1e-6)
batch_size, proposal_num = dir_class_targets.shape[:2]
heading_label_one_hot = dir_class_targets.new_zeros(
(batch_size, proposal_num, self.num_dir_bins))
heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
heading_res_loss_weight = heading_label_one_hot * \
box_loss_weights.unsqueeze(-1)
return (vote_targets, center_targets, size_res_targets,
dir_class_targets, dir_res_targets, mask_targets,
centerness_targets, corner3d_targets, vote_mask, positive_mask,
negative_mask, centerness_weights, box_loss_weights,
heading_res_loss_weight)
[docs] def get_targets_single(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
aggregated_points=None,
seed_points=None):
"""Generate targets of ssd3d head for single batch.
Args:
points (torch.Tensor): Points of each batch.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \
boxes of each batch.
gt_labels_3d (torch.Tensor): Labels of each batch.
pts_semantic_mask (None | torch.Tensor): Point-wise semantic
label of each batch.
pts_instance_mask (None | torch.Tensor): Point-wise instance
label of each batch.
aggregated_points (torch.Tensor): Aggregated points from
candidate points layer.
seed_points (torch.Tensor): Seed points of candidate points.
Returns:
tuple[torch.Tensor]: Targets of ssd3d head.
"""
assert self.bbox_coder.with_rot or pts_semantic_mask is not None
gt_bboxes_3d = gt_bboxes_3d.to(points.device)
valid_gt = gt_labels_3d != -1
gt_bboxes_3d = gt_bboxes_3d[valid_gt]
gt_labels_3d = gt_labels_3d[valid_gt]
# Generate fake GT for empty scene
if valid_gt.sum() == 0:
vote_targets = points.new_zeros(self.num_candidates, 3)
center_targets = points.new_zeros(self.num_candidates, 3)
size_res_targets = points.new_zeros(self.num_candidates, 3)
dir_class_targets = points.new_zeros(
self.num_candidates, dtype=torch.int64)
dir_res_targets = points.new_zeros(self.num_candidates)
mask_targets = points.new_zeros(
self.num_candidates, dtype=torch.int64)
centerness_targets = points.new_zeros(self.num_candidates,
self.num_classes)
corner3d_targets = points.new_zeros(self.num_candidates, 8, 3)
vote_mask = points.new_zeros(self.num_candidates, dtype=torch.bool)
positive_mask = points.new_zeros(
self.num_candidates, dtype=torch.bool)
negative_mask = points.new_ones(
self.num_candidates, dtype=torch.bool)
return (vote_targets, center_targets, size_res_targets,
dir_class_targets, dir_res_targets, mask_targets,
centerness_targets, corner3d_targets, vote_mask,
positive_mask, negative_mask)
gt_corner3d = gt_bboxes_3d.corners
(center_targets, size_targets, dir_class_targets,
dir_res_targets) = self.bbox_coder.encode(gt_bboxes_3d, gt_labels_3d)
points_mask, assignment = self._assign_targets_by_points_inside(
gt_bboxes_3d, aggregated_points)
center_targets = center_targets[assignment]
size_res_targets = size_targets[assignment]
mask_targets = gt_labels_3d[assignment]
dir_class_targets = dir_class_targets[assignment]
dir_res_targets = dir_res_targets[assignment]
corner3d_targets = gt_corner3d[assignment]
top_center_targets = center_targets.clone()
top_center_targets[:, 2] += size_res_targets[:, 2]
dist = torch.norm(aggregated_points - top_center_targets, dim=1)
dist_mask = dist < self.train_cfg.pos_distance_thr
positive_mask = (points_mask.max(1)[0] > 0) * dist_mask
negative_mask = (points_mask.max(1)[0] == 0)
# Centerness loss targets
canonical_xyz = aggregated_points - center_targets
if self.bbox_coder.with_rot:
# TODO: Align points rotation implementation of
# LiDARInstance3DBoxes and DepthInstance3DBoxes
canonical_xyz = rotation_3d_in_axis(
canonical_xyz.unsqueeze(0).transpose(0, 1),
-gt_bboxes_3d.yaw[assignment], 2).squeeze(1)
distance_front = torch.clamp(
size_res_targets[:, 0] - canonical_xyz[:, 0], min=0)
distance_back = torch.clamp(
size_res_targets[:, 0] + canonical_xyz[:, 0], min=0)
distance_left = torch.clamp(
size_res_targets[:, 1] - canonical_xyz[:, 1], min=0)
distance_right = torch.clamp(
size_res_targets[:, 1] + canonical_xyz[:, 1], min=0)
distance_top = torch.clamp(
size_res_targets[:, 2] - canonical_xyz[:, 2], min=0)
distance_bottom = torch.clamp(
size_res_targets[:, 2] + canonical_xyz[:, 2], min=0)
centerness_l = torch.min(distance_front, distance_back) / torch.max(
distance_front, distance_back)
centerness_w = torch.min(distance_left, distance_right) / torch.max(
distance_left, distance_right)
centerness_h = torch.min(distance_bottom, distance_top) / torch.max(
distance_bottom, distance_top)
centerness_targets = torch.clamp(
centerness_l * centerness_w * centerness_h, min=0)
centerness_targets = centerness_targets.pow(1 / 3.0)
centerness_targets = torch.clamp(centerness_targets, min=0, max=1)
proposal_num = centerness_targets.shape[0]
one_hot_centerness_targets = centerness_targets.new_zeros(
(proposal_num, self.num_classes))
one_hot_centerness_targets.scatter_(1, mask_targets.unsqueeze(-1), 1)
centerness_targets = centerness_targets.unsqueeze(
1) * one_hot_centerness_targets
# Vote loss targets
enlarged_gt_bboxes_3d = gt_bboxes_3d.enlarged_box(
self.train_cfg.expand_dims_length)
enlarged_gt_bboxes_3d.tensor[:, 2] -= self.train_cfg.expand_dims_length
vote_mask, vote_assignment = self._assign_targets_by_points_inside(
enlarged_gt_bboxes_3d, seed_points)
vote_targets = gt_bboxes_3d.gravity_center
vote_targets = vote_targets[vote_assignment] - seed_points
vote_mask = vote_mask.max(1)[0] > 0
return (vote_targets, center_targets, size_res_targets,
dir_class_targets, dir_res_targets, mask_targets,
centerness_targets, corner3d_targets, vote_mask, positive_mask,
negative_mask)
[docs] def get_bboxes(self, points, bbox_preds, input_metas, rescale=False):
"""Generate bboxes from sdd3d head predictions.
Args:
points (torch.Tensor): Input points.
bbox_preds (dict): Predictions from sdd3d head.
input_metas (list[dict]): Point cloud and image's meta info.
rescale (bool): Whether to rescale bboxes.
Returns:
list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
"""
# decode boxes
sem_scores = F.sigmoid(bbox_preds['obj_scores']).transpose(1, 2)
obj_scores = sem_scores.max(-1)[0]
bbox3d = self.bbox_coder.decode(bbox_preds)
batch_size = bbox3d.shape[0]
results = list()
for b in range(batch_size):
bbox_selected, score_selected, labels = self.multiclass_nms_single(
obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
input_metas[b])
# fix the wrong direction
# To do: remove this ops
bbox_selected[..., 6] += np.pi
bbox = input_metas[b]['box_type_3d'](
bbox_selected.clone(),
box_dim=bbox_selected.shape[-1],
with_yaw=self.bbox_coder.with_rot)
results.append((bbox, score_selected, labels))
return results
[docs] def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
input_meta):
"""Multi-class nms in single batch.
Args:
obj_scores (torch.Tensor): Objectness score of bounding boxes.
sem_scores (torch.Tensor): semantic class score of bounding boxes.
bbox (torch.Tensor): Predicted bounding boxes.
points (torch.Tensor): Input points.
input_meta (dict): Point cloud and image's meta info.
Returns:
tuple[torch.Tensor]: Bounding boxes, scores and labels.
"""
num_bbox = bbox.shape[0]
bbox = input_meta['box_type_3d'](
bbox.clone(),
box_dim=bbox.shape[-1],
with_yaw=self.bbox_coder.with_rot,
origin=(0.5, 0.5, 0.5))
if isinstance(bbox, LiDARInstance3DBoxes):
box_idx = bbox.points_in_boxes(points)
box_indices = box_idx.new_zeros([num_bbox + 1])
box_idx[box_idx == -1] = num_bbox
box_indices.scatter_add_(0, box_idx.long(),
box_idx.new_ones(box_idx.shape))
box_indices = box_indices[:-1]
nonempty_box_mask = box_indices >= 0
elif isinstance(bbox, DepthInstance3DBoxes):
box_indices = bbox.points_in_boxes(points)
nonempty_box_mask = box_indices.T.sum(1) >= 0
else:
raise NotImplementedError('Unsupported bbox type!')
corner3d = bbox.corners
minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6)))
minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]
bbox_classes = torch.argmax(sem_scores, -1)
nms_selected = batched_nms(
minmax_box3d[nonempty_box_mask][:, [0, 1, 3, 4]],
obj_scores[nonempty_box_mask], bbox_classes[nonempty_box_mask],
self.test_cfg.nms_cfg)[1]
if nms_selected.shape[0] > self.test_cfg.max_output_num:
nms_selected = nms_selected[:self.test_cfg.max_output_num]
# filter empty boxes and boxes with low score
scores_mask = (obj_scores >= self.test_cfg.score_thr)
nonempty_box_inds = torch.nonzero(
nonempty_box_mask, as_tuple=False).flatten()
nonempty_mask = torch.zeros_like(bbox_classes).scatter(
0, nonempty_box_inds[nms_selected], 1)
selected = (nonempty_mask.bool() & scores_mask.bool())
if self.test_cfg.per_class_proposal:
bbox_selected, score_selected, labels = [], [], []
for k in range(sem_scores.shape[-1]):
bbox_selected.append(bbox[selected].tensor)
score_selected.append(obj_scores[selected])
labels.append(
torch.zeros_like(bbox_classes[selected]).fill_(k))
bbox_selected = torch.cat(bbox_selected, 0)
score_selected = torch.cat(score_selected, 0)
labels = torch.cat(labels, 0)
else:
bbox_selected = bbox[selected].tensor
score_selected = obj_scores[selected]
labels = bbox_classes[selected]
return bbox_selected, score_selected, labels
    def _assign_targets_by_points_inside(self, bboxes_3d, points):
        """Compute assignment by checking whether point is inside bbox.

        Args:
            bboxes_3d (BaseInstance3DBoxes): Instance of bounding boxes.
            points (torch.Tensor): Points of a batch.

        Returns:
            tuple[torch.Tensor]: Flags indicating whether each point is
                inside bbox and the index of box where each point are in.
        """
        # TODO: align points_in_boxes function in each box_structures
        num_bbox = bboxes_3d.tensor.shape[0]
        if isinstance(bboxes_3d, LiDARInstance3DBoxes):
            # points_in_boxes gives, per point, the index of the containing
            # box (-1 when the point lies outside every box)
            assignment = bboxes_3d.points_in_boxes(points).long()
            # extra column absorbs outside points so the scatter below stays
            # in range; it is stripped again right after
            points_mask = assignment.new_zeros(
                [assignment.shape[0], num_bbox + 1])
            assignment[assignment == -1] = num_bbox
            points_mask.scatter_(1, assignment.unsqueeze(1), 1)
            points_mask = points_mask[:, :-1]
            # outside points get an arbitrary valid index; their mask row is
            # all-zero so downstream weighting ignores them
            assignment[assignment == num_bbox] = num_bbox - 1
        elif isinstance(bboxes_3d, DepthInstance3DBoxes):
            # Depth boxes return a (num_points, num_bbox) flag matrix directly
            points_mask = bboxes_3d.points_in_boxes(points)
            assignment = points_mask.argmax(dim=-1)
        else:
            raise NotImplementedError('Unsupported bbox type!')
        return points_mask, assignment