Shortcuts

Source code for mmdet3d.models.dense_heads.ssd_3d_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops.nms import batched_nms
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
                                          LiDARInstance3DBoxes,
                                          rotation_3d_in_axis)
from mmdet.core import multi_apply
from ..builder import HEADS, build_loss
from .vote_head import VoteHead


@HEADS.register_module()
class SSD3DHead(VoteHead):
    r"""Bbox head of `3DSSD <https://arxiv.org/abs/2002.10187>`_.

    Args:
        num_classes (int): The number of classes.
        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
            decoding boxes.
        in_channels (int): The number of input feature channels. Accepted
            for config compatibility; this class does not forward it to
            the parent ``VoteHead``.
        train_cfg (dict): Config for training. ``pos_distance_thr`` and
            ``expand_dims_length`` are read from it during target
            generation.
        test_cfg (dict): Config for testing. ``nms_cfg``,
            ``max_output_num``, ``score_thr`` and ``per_class_proposal``
            are read from it during inference.
        vote_module_cfg (dict): Config of VoteModule for point-wise votes.
            Its ``num_points`` entry sets the number of candidate points.
        vote_aggregation_cfg (dict): Config of vote aggregation layer.
        pred_layer_cfg (dict): Config of classification and regression
            prediction layers.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        act_cfg (dict): Config of activation in prediction layer. Accepted
            for config compatibility; this class does not forward it to
            the parent ``VoteHead``.
        objectness_loss (dict): Config of objectness (centerness) loss.
        center_loss (dict): Config of center loss.
        dir_class_loss (dict): Config of direction classification loss.
        dir_res_loss (dict): Config of direction residual regression loss.
        size_res_loss (dict): Config of size residual regression loss.
        corner_loss (dict): Config of bbox corners regression loss.
        vote_loss (dict): Config of candidate points regression loss.
        init_cfg (dict, optional): Initialization config forwarded to the
            parent class.
    """

    def __init__(self,
                 num_classes,
                 bbox_coder,
                 in_channels=256,
                 train_cfg=None,
                 test_cfg=None,
                 vote_module_cfg=None,
                 vote_aggregation_cfg=None,
                 pred_layer_cfg=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_res_loss=None,
                 corner_loss=None,
                 vote_loss=None,
                 init_cfg=None):
        # 3DSSD has no size-class or semantic branch, so those losses are
        # explicitly disabled on the parent.
        super(SSD3DHead, self).__init__(
            num_classes,
            bbox_coder,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            vote_module_cfg=vote_module_cfg,
            vote_aggregation_cfg=vote_aggregation_cfg,
            pred_layer_cfg=pred_layer_cfg,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            objectness_loss=objectness_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_class_loss=None,
            size_res_loss=size_res_loss,
            semantic_loss=None,
            init_cfg=init_cfg)
        self.corner_loss = build_loss(corner_loss)
        self.vote_loss = build_loss(vote_loss)
        self.num_candidates = vote_module_cfg['num_points']

    def _get_cls_out_channels(self):
        """Return the channel number of classification outputs."""
        # One score per class (k); no extra objectness channel is added.
        return self.num_classes

    def _get_reg_out_channels(self):
        """Return the channel number of regression outputs."""
        # Bbox classification and regression
        # (center residual (3), size regression (3)
        # heading class+residual (num_dir_bins*2)),
        return 3 + 3 + self.num_dir_bins * 2

    def _extract_input(self, feat_dict):
        """Extract inputs from features dictionary.

        Args:
            feat_dict (dict): Feature dict from backbone.

        Returns:
            torch.Tensor: Coordinates of input points.
            torch.Tensor: Features of input points.
            torch.Tensor: Indices of input points.
        """
        # Only the last set-abstraction level is used as seeds.
        seed_points = feat_dict['sa_xyz'][-1]
        seed_features = feat_dict['sa_features'][-1]
        seed_indices = feat_dict['sa_indices'][-1]

        return seed_points, seed_features, seed_indices

    @force_fp32(apply_to=('bbox_preds', ))
    def loss(self,
             bbox_preds,
             points,
             gt_bboxes_3d,
             gt_labels_3d,
             pts_semantic_mask=None,
             pts_instance_mask=None,
             img_metas=None,
             gt_bboxes_ignore=None):
        """Compute loss.

        Args:
            bbox_preds (dict): Predictions from forward of SSD3DHead.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (list[torch.Tensor]): Point-wise
                instance mask.
            img_metas (list[dict]): Contain pcd and img's meta info. The
                ``box_type_3d`` of the first entry is used to build boxes
                for the corner loss, so this must not be ``None``.
            gt_bboxes_ignore (list[torch.Tensor]): Specify which bounding
                boxes to ignore. Unused by this head.

        Returns:
            dict: Losses of 3DSSD.
        """
        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
                                   pts_semantic_mask, pts_instance_mask,
                                   bbox_preds)
        (vote_targets, center_targets, size_res_targets, dir_class_targets,
         dir_res_targets, mask_targets, centerness_targets, corner3d_targets,
         vote_mask, positive_mask, negative_mask, centerness_weights,
         box_loss_weights, heading_res_loss_weight) = targets

        # calculate centerness loss
        centerness_loss = self.objectness_loss(
            bbox_preds['obj_scores'].transpose(2, 1),
            centerness_targets,
            weight=centerness_weights)

        # calculate center loss
        center_loss = self.center_loss(
            bbox_preds['center_offset'],
            center_targets,
            weight=box_loss_weights.unsqueeze(-1))

        # calculate direction class loss
        dir_class_loss = self.dir_class_loss(
            bbox_preds['dir_class'].transpose(1, 2),
            dir_class_targets,
            weight=box_loss_weights)

        # calculate direction residual loss
        # (the residual target is broadcast to every bin; the weight keeps
        # only the ground-truth bin)
        dir_res_loss = self.dir_res_loss(
            bbox_preds['dir_res_norm'],
            dir_res_targets.unsqueeze(-1).repeat(1, 1, self.num_dir_bins),
            weight=heading_res_loss_weight)

        # calculate size residual loss
        size_loss = self.size_res_loss(
            bbox_preds['size'],
            size_res_targets,
            weight=box_loss_weights.unsqueeze(-1))

        # calculate corner loss
        # Decode the predicted boxes with the ground-truth direction class
        # (one-hot) so the corners are built from the GT heading bin.
        one_hot_dir_class_targets = dir_class_targets.new_zeros(
            bbox_preds['dir_class'].shape)
        one_hot_dir_class_targets.scatter_(2, dir_class_targets.unsqueeze(-1),
                                           1)
        pred_bbox3d = self.bbox_coder.decode(
            dict(
                center=bbox_preds['center'],
                dir_res=bbox_preds['dir_res'],
                dir_class=one_hot_dir_class_targets,
                size=bbox_preds['size']))
        pred_bbox3d = pred_bbox3d.reshape(-1, pred_bbox3d.shape[-1])
        pred_bbox3d = img_metas[0]['box_type_3d'](
            pred_bbox3d.clone(),
            box_dim=pred_bbox3d.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))
        pred_corners3d = pred_bbox3d.corners.reshape(-1, 8, 3)
        corner_loss = self.corner_loss(
            pred_corners3d,
            corner3d_targets.reshape(-1, 8, 3),
            weight=box_loss_weights.view(-1, 1, 1))

        # calculate vote loss
        vote_loss = self.vote_loss(
            bbox_preds['vote_offset'].transpose(1, 2),
            vote_targets,
            weight=vote_mask.unsqueeze(-1))

        losses = dict(
            centerness_loss=centerness_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_res_loss=size_loss,
            corner_loss=corner_loss,
            vote_loss=vote_loss)

        return losses

    def get_targets(self,
                    points,
                    gt_bboxes_3d,
                    gt_labels_3d,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    bbox_preds=None):
        """Generate targets of ssd3d head.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each batch. Not modified by this method.
            gt_labels_3d (list[torch.Tensor]): Labels of each batch. Not
                modified by this method.
            pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
                label of each batch.
            pts_instance_mask (list[torch.Tensor]): Point-wise instance
                label of each batch.
            bbox_preds (dict): Predictions from forward of ssd3d head;
                ``aggregated_points`` and ``seed_points`` are read from it.

        Returns:
            tuple[torch.Tensor]: Targets of ssd3d head (vote, center, size,
                direction class/residual, mask, centerness and corner
                targets, the vote/positive/negative masks, and the
                centerness, box and heading-residual loss weights).
        """
        # Shallow copies: padding empty samples with a fake box below must
        # not leak into the caller's lists.
        gt_bboxes_3d = list(gt_bboxes_3d)
        gt_labels_3d = list(gt_labels_3d)
        num_samples = len(gt_labels_3d)

        # find empty example
        for index in range(num_samples):
            if len(gt_labels_3d[index]) == 0:
                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
                    1, gt_bboxes_3d[index].tensor.shape[-1])
                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)

        # multi_apply iterates every argument, so missing masks must be
        # expanded to per-sample ``None`` lists (each one independently).
        if pts_semantic_mask is None:
            pts_semantic_mask = [None] * num_samples
            pts_instance_mask = [None] * num_samples
        if pts_instance_mask is None:
            pts_instance_mask = [None] * num_samples

        aggregated_points = [
            bbox_preds['aggregated_points'][i] for i in range(num_samples)
        ]

        seed_points = [
            bbox_preds['seed_points'][i, :self.num_candidates].detach()
            for i in range(num_samples)
        ]

        (vote_targets, center_targets, size_res_targets, dir_class_targets,
         dir_res_targets, mask_targets, centerness_targets, corner3d_targets,
         vote_mask, positive_mask, negative_mask) = multi_apply(
             self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
             pts_semantic_mask, pts_instance_mask, aggregated_points,
             seed_points)

        center_targets = torch.stack(center_targets)
        positive_mask = torch.stack(positive_mask)
        negative_mask = torch.stack(negative_mask)
        dir_class_targets = torch.stack(dir_class_targets)
        dir_res_targets = torch.stack(dir_res_targets)
        size_res_targets = torch.stack(size_res_targets)
        mask_targets = torch.stack(mask_targets)
        centerness_targets = torch.stack(centerness_targets).detach()
        corner3d_targets = torch.stack(corner3d_targets)
        vote_targets = torch.stack(vote_targets)
        vote_mask = torch.stack(vote_mask)

        # Center targets become offsets relative to the aggregated points.
        center_targets -= bbox_preds['aggregated_points']

        # Normalise every weight map so it sums to ~1 over the batch
        # (1e-6 avoids division by zero when a mask is empty).
        centerness_weights = (positive_mask +
                              negative_mask).unsqueeze(-1).repeat(
                                  1, 1, self.num_classes).float()
        centerness_weights = centerness_weights / \
            (centerness_weights.sum() + 1e-6)
        vote_mask = vote_mask / (vote_mask.sum() + 1e-6)

        box_loss_weights = positive_mask / (positive_mask.sum() + 1e-6)

        batch_size, proposal_num = dir_class_targets.shape[:2]
        heading_label_one_hot = dir_class_targets.new_zeros(
            (batch_size, proposal_num, self.num_dir_bins))
        heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
        heading_res_loss_weight = heading_label_one_hot * \
            box_loss_weights.unsqueeze(-1)

        return (vote_targets, center_targets, size_res_targets,
                dir_class_targets, dir_res_targets, mask_targets,
                centerness_targets, corner3d_targets, vote_mask,
                positive_mask, negative_mask, centerness_weights,
                box_loss_weights, heading_res_loss_weight)

    @staticmethod
    def _safe_min_max_ratio(dist_a, dist_b):
        """Return ``min(a, b) / max(a, b)`` elementwise, 0 where both are 0.

        Both inputs are clamped at 0 by the caller, so ``max == 0`` implies
        ``min == 0``; dividing by 1 there yields 0 instead of the NaN that
        ``0 / 0`` would produce. All other entries are computed exactly.
        """
        lo = torch.min(dist_a, dist_b)
        hi = torch.max(dist_a, dist_b)
        return lo / torch.where(hi > 0, hi, torch.ones_like(hi))

    def get_targets_single(self,
                           points,
                           gt_bboxes_3d,
                           gt_labels_3d,
                           pts_semantic_mask=None,
                           pts_instance_mask=None,
                           aggregated_points=None,
                           seed_points=None):
        """Generate targets of ssd3d head for single batch.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch. Entries
                equal to -1 are treated as invalid and dropped.
            pts_semantic_mask (torch.Tensor): Point-wise semantic
                label of each batch.
            pts_instance_mask (torch.Tensor): Point-wise instance
                label of each batch.
            aggregated_points (torch.Tensor): Aggregated points from
                candidate points layer.
            seed_points (torch.Tensor): Seed points of candidate points.

        Returns:
            tuple[torch.Tensor]: Targets of ssd3d head.
        """
        assert self.bbox_coder.with_rot or pts_semantic_mask is not None

        gt_bboxes_3d = gt_bboxes_3d.to(points.device)
        valid_gt = gt_labels_3d != -1
        gt_bboxes_3d = gt_bboxes_3d[valid_gt]
        gt_labels_3d = gt_labels_3d[valid_gt]

        # Generate fake GT for empty scene: all candidates are negatives.
        if valid_gt.sum() == 0:
            vote_targets = points.new_zeros(self.num_candidates, 3)
            center_targets = points.new_zeros(self.num_candidates, 3)
            size_res_targets = points.new_zeros(self.num_candidates, 3)
            dir_class_targets = points.new_zeros(
                self.num_candidates, dtype=torch.int64)
            dir_res_targets = points.new_zeros(self.num_candidates)
            mask_targets = points.new_zeros(
                self.num_candidates, dtype=torch.int64)
            centerness_targets = points.new_zeros(self.num_candidates,
                                                  self.num_classes)
            corner3d_targets = points.new_zeros(self.num_candidates, 8, 3)
            vote_mask = points.new_zeros(self.num_candidates, dtype=torch.bool)
            positive_mask = points.new_zeros(
                self.num_candidates, dtype=torch.bool)
            negative_mask = points.new_ones(
                self.num_candidates, dtype=torch.bool)
            return (vote_targets, center_targets, size_res_targets,
                    dir_class_targets, dir_res_targets, mask_targets,
                    centerness_targets, corner3d_targets, vote_mask,
                    positive_mask, negative_mask)

        gt_corner3d = gt_bboxes_3d.corners

        (center_targets, size_targets, dir_class_targets,
         dir_res_targets) = self.bbox_coder.encode(gt_bboxes_3d, gt_labels_3d)

        # Assign each aggregated point to the box it falls inside.
        points_mask, assignment = self._assign_targets_by_points_inside(
            gt_bboxes_3d, aggregated_points)

        center_targets = center_targets[assignment]
        size_res_targets = size_targets[assignment]
        mask_targets = gt_labels_3d[assignment]
        dir_class_targets = dir_class_targets[assignment]
        dir_res_targets = dir_res_targets[assignment]
        corner3d_targets = gt_corner3d[assignment]

        # Positives: inside a box AND close to the box's shifted center.
        # NOTE(review): the z shift uses size_res_targets[:, 2], presumably
        # the half-height produced by the coder -- TODO confirm.
        top_center_targets = center_targets.clone()
        top_center_targets[:, 2] += size_res_targets[:, 2]
        dist = torch.norm(aggregated_points - top_center_targets, dim=1)
        dist_mask = dist < self.train_cfg.pos_distance_thr
        positive_mask = (points_mask.max(1)[0] > 0) * dist_mask
        negative_mask = (points_mask.max(1)[0] == 0)

        # Centerness loss targets
        canonical_xyz = aggregated_points - center_targets
        if self.bbox_coder.with_rot:
            # TODO: Align points rotation implementation of
            # LiDARInstance3DBoxes and DepthInstance3DBoxes
            canonical_xyz = rotation_3d_in_axis(
                canonical_xyz.unsqueeze(0).transpose(0, 1),
                -gt_bboxes_3d.yaw[assignment],
                axis=2).squeeze(1)
        distance_front = torch.clamp(
            size_res_targets[:, 0] - canonical_xyz[:, 0], min=0)
        distance_back = torch.clamp(
            size_res_targets[:, 0] + canonical_xyz[:, 0], min=0)
        distance_left = torch.clamp(
            size_res_targets[:, 1] - canonical_xyz[:, 1], min=0)
        distance_right = torch.clamp(
            size_res_targets[:, 1] + canonical_xyz[:, 1], min=0)
        distance_top = torch.clamp(
            size_res_targets[:, 2] - canonical_xyz[:, 2], min=0)
        distance_bottom = torch.clamp(
            size_res_targets[:, 2] + canonical_xyz[:, 2], min=0)

        # NaN-safe min/max ratios (0/0 would otherwise poison the targets).
        centerness_l = self._safe_min_max_ratio(distance_front, distance_back)
        centerness_w = self._safe_min_max_ratio(distance_left, distance_right)
        centerness_h = self._safe_min_max_ratio(distance_bottom, distance_top)
        centerness_targets = torch.clamp(
            centerness_l * centerness_w * centerness_h, min=0)
        centerness_targets = centerness_targets.pow(1 / 3.0)
        centerness_targets = torch.clamp(centerness_targets, min=0, max=1)

        # Place each point's centerness in the column of its assigned class.
        proposal_num = centerness_targets.shape[0]
        one_hot_centerness_targets = centerness_targets.new_zeros(
            (proposal_num, self.num_classes))
        one_hot_centerness_targets.scatter_(1, mask_targets.unsqueeze(-1), 1)
        centerness_targets = centerness_targets.unsqueeze(
            1) * one_hot_centerness_targets

        # Vote loss targets: seeds inside an enlarged box regress to its
        # gravity center.
        enlarged_gt_bboxes_3d = gt_bboxes_3d.enlarged_box(
            self.train_cfg.expand_dims_length)
        enlarged_gt_bboxes_3d.tensor[:, 2] -= self.train_cfg.expand_dims_length
        vote_mask, vote_assignment = self._assign_targets_by_points_inside(
            enlarged_gt_bboxes_3d, seed_points)

        vote_targets = gt_bboxes_3d.gravity_center
        vote_targets = vote_targets[vote_assignment] - seed_points
        vote_mask = vote_mask.max(1)[0] > 0

        return (vote_targets, center_targets, size_res_targets,
                dir_class_targets, dir_res_targets, mask_targets,
                centerness_targets, corner3d_targets, vote_mask,
                positive_mask, negative_mask)

    def get_bboxes(self, points, bbox_preds, input_metas, rescale=False):
        """Generate bboxes from 3DSSD head predictions.

        Args:
            points (torch.Tensor): Input points.
            bbox_preds (dict): Predictions from ssd3d head.
            input_metas (list[dict]): Point cloud and image's meta info.
            rescale (bool): Whether to rescale bboxes. Unused by this head.

        Returns:
            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
        """
        # decode boxes; the per-box score is the max per-class sigmoid score
        # (torch.sigmoid replaces the deprecated F.sigmoid).
        sem_scores = torch.sigmoid(bbox_preds['obj_scores']).transpose(1, 2)
        obj_scores = sem_scores.max(-1)[0]
        bbox3d = self.bbox_coder.decode(bbox_preds)

        batch_size = bbox3d.shape[0]
        results = list()

        for b in range(batch_size):
            bbox_selected, score_selected, labels = self.multiclass_nms_single(
                obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
                input_metas[b])
            bbox = input_metas[b]['box_type_3d'](
                bbox_selected.clone(),
                box_dim=bbox_selected.shape[-1],
                with_yaw=self.bbox_coder.with_rot)
            results.append((bbox, score_selected, labels))

        return results

    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                              input_meta):
        """Multi-class nms in single batch.

        Args:
            obj_scores (torch.Tensor): Objectness score of bounding boxes.
            sem_scores (torch.Tensor): Semantic class score of bounding boxes.
            bbox (torch.Tensor): Predicted bounding boxes.
            points (torch.Tensor): Input points.
            input_meta (dict): Point cloud and image's meta info.

        Returns:
            tuple[torch.Tensor]: Bounding boxes, scores and labels.

        Raises:
            NotImplementedError: If the box type is neither
                LiDARInstance3DBoxes nor DepthInstance3DBoxes.
        """
        bbox = input_meta['box_type_3d'](
            bbox.clone(),
            box_dim=bbox.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))

        if isinstance(bbox, (LiDARInstance3DBoxes, DepthInstance3DBoxes)):
            box_indices = bbox.points_in_boxes_all(points)
            # NOTE(review): if points_in_boxes_all yields a 0/1 membership
            # mask, ``sum >= 0`` is true for every box, so no box is
            # filtered out here -- confirm whether a positive threshold
            # was intended before changing it.
            nonempty_box_mask = box_indices.T.sum(1) >= 0
        else:
            raise NotImplementedError('Unsupported bbox type!')

        # Axis-aligned (x_min, y_min, z_min, x_max, y_max, z_max) envelope
        # of each box's corners; NMS runs on its BEV (x, y) extent.
        corner3d = bbox.corners
        minmax_box3d = corner3d.new_empty((corner3d.shape[0], 6))
        minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
        minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]

        bbox_classes = torch.argmax(sem_scores, -1)
        nms_keep = batched_nms(
            minmax_box3d[nonempty_box_mask][:, [0, 1, 3, 4]],
            obj_scores[nonempty_box_mask], bbox_classes[nonempty_box_mask],
            self.test_cfg.nms_cfg)[1]

        # Slicing is a no-op when fewer boxes were kept.
        nms_keep = nms_keep[:self.test_cfg.max_output_num]

        # filter empty boxes and boxes with low score
        scores_mask = (obj_scores >= self.test_cfg.score_thr)
        nonempty_box_inds = torch.nonzero(
            nonempty_box_mask, as_tuple=False).flatten()
        # Map NMS indices (into the non-empty subset) back to full indices.
        nonempty_mask = torch.zeros_like(bbox_classes).scatter(
            0, nonempty_box_inds[nms_keep], 1)
        selected = (nonempty_mask.bool() & scores_mask.bool())

        if self.test_cfg.per_class_proposal:
            # NOTE(review): every class copy reuses the same obj_scores;
            # confirm whether a per-class score was intended.
            bbox_selected, score_selected, labels = [], [], []
            for k in range(sem_scores.shape[-1]):
                bbox_selected.append(bbox[selected].tensor)
                score_selected.append(obj_scores[selected])
                labels.append(
                    torch.zeros_like(bbox_classes[selected]).fill_(k))
            bbox_selected = torch.cat(bbox_selected, 0)
            score_selected = torch.cat(score_selected, 0)
            labels = torch.cat(labels, 0)
        else:
            bbox_selected = bbox[selected].tensor
            score_selected = obj_scores[selected]
            labels = bbox_classes[selected]

        return bbox_selected, score_selected, labels

    def _assign_targets_by_points_inside(self, bboxes_3d, points):
        """Compute assignment by checking whether point is inside bbox.

        Args:
            bboxes_3d (BaseInstance3DBoxes): Instance of bounding boxes.
            points (torch.Tensor): Points of a batch.

        Returns:
            tuple[torch.Tensor]: Flags indicating whether each point is
                inside bbox and the index of box where each point are in.

        Raises:
            NotImplementedError: If the box type is neither
                LiDARInstance3DBoxes nor DepthInstance3DBoxes.
        """
        if isinstance(bboxes_3d, (LiDARInstance3DBoxes, DepthInstance3DBoxes)):
            points_mask = bboxes_3d.points_in_boxes_all(points)
            # argmax picks the first box containing the point (index 0 when
            # the point is inside no box).
            assignment = points_mask.argmax(dim=-1)
        else:
            raise NotImplementedError('Unsupported bbox type!')

        return points_mask, assignment
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.