
Source code for mmdet3d.models.dense_heads.vote_head

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule, force_fp32
from torch.nn import functional as F

from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply
from ..builder import HEADS, build_loss
from .base_conv_bbox_head import BaseConvBboxHead


@HEADS.register_module()
class VoteHead(BaseModule):
    r"""Bbox head of `Votenet <https://arxiv.org/abs/1904.09664>`_.

    Args:
        num_classes (int): The number of classes.
        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
            decoding boxes.
        train_cfg (dict): Config for training.
        test_cfg (dict): Config for testing.
        vote_module_cfg (dict): Config of VoteModule for point-wise votes.
        vote_aggregation_cfg (dict): Config of vote aggregation layer.
        pred_layer_cfg (dict): Config of classification and regression
            prediction layers.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        objectness_loss (dict): Config of objectness loss.
        center_loss (dict): Config of center loss.
        dir_class_loss (dict): Config of direction classification loss.
        dir_res_loss (dict): Config of direction residual regression loss.
        size_class_loss (dict): Config of size classification loss.
        size_res_loss (dict): Config of size residual regression loss.
        semantic_loss (dict): Config of point-wise semantic segmentation loss.
    """

    def __init__(self,
                 num_classes,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 vote_module_cfg=None,
                 vote_aggregation_cfg=None,
                 pred_layer_cfg=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None,
                 iou_loss=None,
                 init_cfg=None):
        super(VoteHead, self).__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_module_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        if size_class_loss is not None:
            self.size_class_loss = build_loss(size_class_loss)
        if semantic_loss is not None:
            self.semantic_loss = build_loss(semantic_loss)
        if iou_loss is not None:
            self.iou_loss = build_loss(iou_loss)
        else:
            self.iou_loss = None

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.vote_module = VoteModule(**vote_module_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
        self.fp16_enabled = False

        # Bbox classification and regression
        self.conv_pred = BaseConvBboxHead(
            **pred_layer_cfg,
            num_cls_out_channels=self._get_cls_out_channels(),
            num_reg_out_channels=self._get_reg_out_channels())

    def _get_cls_out_channels(self):
        """Return the channel number of classification outputs."""
        # Class numbers (k) + objectness (2)
        return self.num_classes + 2

    def _get_reg_out_channels(self):
        """Return the channel number of regression outputs."""
        # Center residual (3),
        # heading class + residual (num_dir_bins*2),
        # size class + residual (num_sizes*4).
        # Objectness is predicted by the classification branch above.
        return 3 + self.num_dir_bins * 2 + self.num_sizes * 4

    def _extract_input(self, feat_dict):
        """Extract inputs from features dictionary.

        Args:
            feat_dict (dict): Feature dict from backbone.

        Returns:
            torch.Tensor: Coordinates of input points.
            torch.Tensor: Features of input points.
            torch.Tensor: Indices of input points.
        """
        # for imvotenet
        if 'seed_points' in feat_dict and \
                'seed_features' in feat_dict and \
                'seed_indices' in feat_dict:
            seed_points = feat_dict['seed_points']
            seed_features = feat_dict['seed_features']
            seed_indices = feat_dict['seed_indices']
        # for votenet
        else:
            seed_points = feat_dict['fp_xyz'][-1]
            seed_features = feat_dict['fp_features'][-1]
            seed_indices = feat_dict['fp_indices'][-1]

        return seed_points, seed_features, seed_indices

    def forward(self, feat_dict, sample_mod):
        """Forward pass.

        Note:
            The forward of VoteHead is divided into 4 steps:

                1. Generate vote_points from seed_points.
                2. Aggregate vote_points.
                3. Predict bbox and score.
                4. Decode predictions.

        Args:
            feat_dict (dict): Feature dict from backbone.
            sample_mod (str): Sample mode for vote aggregation layer.
                Valid modes are "vote", "seed", "random" and "spec".

        Returns:
            dict: Predictions of vote head.
        """
        assert sample_mod in ['vote', 'seed', 'random', 'spec']

        seed_points, seed_features, seed_indices = self._extract_input(
            feat_dict)

        # 1. generate vote_points from seed_points
        vote_points, vote_features, vote_offset = self.vote_module(
            seed_points, seed_features)
        results = dict(
            seed_points=seed_points,
            seed_indices=seed_indices,
            vote_points=vote_points,
            vote_features=vote_features,
            vote_offset=vote_offset)

        # 2. aggregate vote_points
        if sample_mod == 'vote':
            # use fps in vote_aggregation
            aggregation_inputs = dict(
                points_xyz=vote_points, features=vote_features)
        elif sample_mod == 'seed':
            # FPS on seed and choose the votes corresponding to the seeds
            sample_indices = furthest_point_sample(seed_points,
                                                   self.num_proposal)
            aggregation_inputs = dict(
                points_xyz=vote_points,
                features=vote_features,
                indices=sample_indices)
        elif sample_mod == 'random':
            # Random sampling from the votes
            batch_size, num_seed = seed_points.shape[:2]
            sample_indices = seed_points.new_tensor(
                torch.randint(0, num_seed, (batch_size, self.num_proposal)),
                dtype=torch.int32)
            aggregation_inputs = dict(
                points_xyz=vote_points,
                features=vote_features,
                indices=sample_indices)
        elif sample_mod == 'spec':
            # Specify the new center in vote_aggregation
            aggregation_inputs = dict(
                points_xyz=seed_points,
                features=seed_features,
                target_xyz=vote_points)
        else:
            raise NotImplementedError(
                f'Sample mode {sample_mod} is not supported!')

        vote_aggregation_ret = self.vote_aggregation(**aggregation_inputs)
        aggregated_points, features, aggregated_indices = vote_aggregation_ret

        results['aggregated_points'] = aggregated_points
        results['aggregated_features'] = features
        results['aggregated_indices'] = aggregated_indices

        # 3. predict bbox and score
        cls_predictions, reg_predictions = self.conv_pred(features)

        # 4. decode predictions
        decode_res = self.bbox_coder.split_pred(cls_predictions,
                                                reg_predictions,
                                                aggregated_points)

        results.update(decode_res)

        return results
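
As a minimal sketch of step 2 for the 'random' sample mode: one proposal index per slot is drawn uniformly from the seeds for every sample in the batch. The shapes below are toy assumptions, and plain torch is used; the real head converts the indices with seed_points.new_tensor and passes them to the SA module.

import torch

# Assumed toy shapes for illustration only.
batch_size, num_seed, num_proposal = 2, 1024, 256
seed_points = torch.rand(batch_size, num_seed, 3)

# 'random' mode: draw num_proposal vote indices per sample at random.
sample_indices = torch.randint(
    0, num_seed, (batch_size, num_proposal), dtype=torch.int32)
print(sample_indices.shape)  # torch.Size([2, 256])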

    @force_fp32(apply_to=('bbox_preds', ))
    def loss(self,
             bbox_preds,
             points,
             gt_bboxes_3d,
             gt_labels_3d,
             pts_semantic_mask=None,
             pts_instance_mask=None,
             img_metas=None,
             gt_bboxes_ignore=None,
             ret_target=False):
        """Compute loss.

        Args:
            bbox_preds (dict): Predictions from forward of vote head.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (list[torch.Tensor]): Point-wise
                instance mask.
            img_metas (list[dict]): Contain pcd and img's meta info.
            gt_bboxes_ignore (list[torch.Tensor]): Specify which bounding
                boxes to ignore.
            ret_target (bool): Whether to return targets.

        Returns:
            dict: Losses of Votenet.
        """
        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
                                   pts_semantic_mask, pts_instance_mask,
                                   bbox_preds)
        (vote_targets, vote_target_masks, size_class_targets,
         size_res_targets, dir_class_targets, dir_res_targets, center_targets,
         assigned_center_targets, mask_targets, valid_gt_masks,
         objectness_targets, objectness_weights, box_loss_weights,
         valid_gt_weights) = targets

        # calculate vote loss
        vote_loss = self.vote_module.get_loss(bbox_preds['seed_points'],
                                              bbox_preds['vote_points'],
                                              bbox_preds['seed_indices'],
                                              vote_target_masks, vote_targets)

        # calculate objectness loss
        objectness_loss = self.objectness_loss(
            bbox_preds['obj_scores'].transpose(2, 1),
            objectness_targets,
            weight=objectness_weights)

        # calculate center loss
        source2target_loss, target2source_loss = self.center_loss(
            bbox_preds['center'],
            center_targets,
            src_weight=box_loss_weights,
            dst_weight=valid_gt_weights)
        center_loss = source2target_loss + target2source_loss

        # calculate direction class loss
        dir_class_loss = self.dir_class_loss(
            bbox_preds['dir_class'].transpose(2, 1),
            dir_class_targets,
            weight=box_loss_weights)

        # calculate direction residual loss
        batch_size, proposal_num = size_class_targets.shape[:2]
        heading_label_one_hot = vote_targets.new_zeros(
            (batch_size, proposal_num, self.num_dir_bins))
        heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
        dir_res_norm = torch.sum(
            bbox_preds['dir_res_norm'] * heading_label_one_hot, -1)
        dir_res_loss = self.dir_res_loss(
            dir_res_norm, dir_res_targets, weight=box_loss_weights)

        # calculate size class loss
        size_class_loss = self.size_class_loss(
            bbox_preds['size_class'].transpose(2, 1),
            size_class_targets,
            weight=box_loss_weights)

        # calculate size residual loss
        one_hot_size_targets = vote_targets.new_zeros(
            (batch_size, proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(2, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets_expand = one_hot_size_targets.unsqueeze(
            -1).repeat(1, 1, 1, 3).contiguous()
        size_residual_norm = torch.sum(
            bbox_preds['size_res_norm'] * one_hot_size_targets_expand, 2)
        box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
            1, 1, 3)
        size_res_loss = self.size_res_loss(
            size_residual_norm,
            size_res_targets,
            weight=box_loss_weights_expand)

        # calculate semantic loss
        semantic_loss = self.semantic_loss(
            bbox_preds['sem_scores'].transpose(2, 1),
            mask_targets,
            weight=box_loss_weights)

        losses = dict(
            vote_loss=vote_loss,
            objectness_loss=objectness_loss,
            semantic_loss=semantic_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_class_loss=size_class_loss,
            size_res_loss=size_res_loss)

        if self.iou_loss:
            corners_pred = self.bbox_coder.decode_corners(
                bbox_preds['center'], size_residual_norm,
                one_hot_size_targets_expand)
            corners_target = self.bbox_coder.decode_corners(
                assigned_center_targets, size_res_targets,
                one_hot_size_targets_expand)
            iou_loss = self.iou_loss(
                corners_pred, corners_target, weight=box_loss_weights)
            losses['iou_loss'] = iou_loss

        if ret_target:
            losses['targets'] = targets

        return losses
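
The direction and size residual losses above only supervise the residual predicted for the ground-truth bin; the one-hot scatter followed by a sum picks that bin out of the prediction tensor. A toy sketch of that selection, with shapes that are assumptions chosen for illustration:

import torch

# Assumed toy shapes: 1 sample, 4 proposals, 12 heading bins.
batch_size, proposal_num, num_dir_bins = 1, 4, 12
dir_res_norm_pred = torch.rand(batch_size, proposal_num, num_dir_bins)
dir_class_targets = torch.randint(0, num_dir_bins, (batch_size, proposal_num))

# One-hot encode the target bin, then keep only that bin's residual.
one_hot = torch.zeros(batch_size, proposal_num, num_dir_bins)
one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
selected_res = torch.sum(dir_res_norm_pred * one_hot, -1)
print(selected_res.shape)  # torch.Size([1, 4])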

    def get_targets(self,
                    points,
                    gt_bboxes_3d,
                    gt_labels_3d,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    bbox_preds=None):
        """Generate targets of vote head.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
            pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
                label of each batch.
            pts_instance_mask (list[torch.Tensor]): Point-wise instance
                label of each batch.
            bbox_preds (dict): Bounding box predictions of vote head.

        Returns:
            tuple[torch.Tensor]: Targets of vote head.
        """
        # find empty example
        valid_gt_masks = list()
        gt_num = list()
        for index in range(len(gt_labels_3d)):
            if len(gt_labels_3d[index]) == 0:
                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
                    1, gt_bboxes_3d[index].tensor.shape[-1])
                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
                valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
                gt_num.append(1)
            else:
                valid_gt_masks.append(gt_labels_3d[index].new_ones(
                    gt_labels_3d[index].shape))
                gt_num.append(gt_labels_3d[index].shape[0])
        max_gt_num = max(gt_num)

        if pts_semantic_mask is None:
            pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
            pts_instance_mask = [None for i in range(len(gt_labels_3d))]

        aggregated_points = [
            bbox_preds['aggregated_points'][i]
            for i in range(len(gt_labels_3d))
        ]

        (vote_targets, vote_target_masks, size_class_targets,
         size_res_targets, dir_class_targets, dir_res_targets, center_targets,
         assigned_center_targets, mask_targets, objectness_targets,
         objectness_masks) = multi_apply(self.get_targets_single, points,
                                         gt_bboxes_3d, gt_labels_3d,
                                         pts_semantic_mask, pts_instance_mask,
                                         aggregated_points)

        # pad targets as original code of votenet.
        for index in range(len(gt_labels_3d)):
            pad_num = max_gt_num - gt_labels_3d[index].shape[0]
            center_targets[index] = F.pad(center_targets[index],
                                          (0, 0, 0, pad_num))
            valid_gt_masks[index] = F.pad(valid_gt_masks[index], (0, pad_num))

        vote_targets = torch.stack(vote_targets)
        vote_target_masks = torch.stack(vote_target_masks)
        center_targets = torch.stack(center_targets)
        valid_gt_masks = torch.stack(valid_gt_masks)

        assigned_center_targets = torch.stack(assigned_center_targets)
        objectness_targets = torch.stack(objectness_targets)
        objectness_weights = torch.stack(objectness_masks)
        objectness_weights /= (torch.sum(objectness_weights) + 1e-6)
        box_loss_weights = objectness_targets.float() / (
            torch.sum(objectness_targets).float() + 1e-6)
        valid_gt_weights = valid_gt_masks.float() / (
            torch.sum(valid_gt_masks.float()) + 1e-6)
        dir_class_targets = torch.stack(dir_class_targets)
        dir_res_targets = torch.stack(dir_res_targets)
        size_class_targets = torch.stack(size_class_targets)
        size_res_targets = torch.stack(size_res_targets)
        mask_targets = torch.stack(mask_targets)

        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets, dir_class_targets, dir_res_targets,
                center_targets, assigned_center_targets, mask_targets,
                valid_gt_masks, objectness_targets, objectness_weights,
                box_loss_weights, valid_gt_weights)
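
The weight normalisation at the end of get_targets makes the objectness weights sum to (approximately) one over all proposals in the batch and the box loss weights sum to (approximately) one over the positive proposals. A toy sketch of that normalisation, with assumed target values:

import torch

# Assumed toy targets: proposal 0 is positive, proposal 1 is ignored
# (between the two distance thresholds), proposals 2 and 3 are negatives.
objectness_targets = torch.tensor([1, 0, 0, 0])
objectness_masks = torch.tensor([1.0, 0.0, 1.0, 1.0])

objectness_weights = objectness_masks / (objectness_masks.sum() + 1e-6)
box_loss_weights = objectness_targets.float() / (
    objectness_targets.sum().float() + 1e-6)
print(objectness_weights)  # roughly tensor([0.3333, 0.0000, 0.3333, 0.3333])
print(box_loss_weights)    # roughly tensor([1.0000, 0.0000, 0.0000, 0.0000])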

    def get_targets_single(self,
                           points,
                           gt_bboxes_3d,
                           gt_labels_3d,
                           pts_semantic_mask=None,
                           pts_instance_mask=None,
                           aggregated_points=None):
        """Generate targets of vote head for single batch.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch.
            pts_semantic_mask (torch.Tensor): Point-wise semantic
                label of each batch.
            pts_instance_mask (torch.Tensor): Point-wise instance
                label of each batch.
            aggregated_points (torch.Tensor): Aggregated points from
                vote aggregation layer.

        Returns:
            tuple[torch.Tensor]: Targets of vote head.
        """
        assert self.bbox_coder.with_rot or pts_semantic_mask is not None

        gt_bboxes_3d = gt_bboxes_3d.to(points.device)

        # generate votes target
        num_points = points.shape[0]
        if self.bbox_coder.with_rot:
            vote_targets = points.new_zeros([num_points, 3 * self.gt_per_seed])
            vote_target_masks = points.new_zeros([num_points],
                                                 dtype=torch.long)
            vote_target_idx = points.new_zeros([num_points], dtype=torch.long)
            box_indices_all = gt_bboxes_3d.points_in_boxes_all(points)
            for i in range(gt_labels_3d.shape[0]):
                box_indices = box_indices_all[:, i]
                indices = torch.nonzero(
                    box_indices, as_tuple=False).squeeze(-1)
                selected_points = points[indices]
                vote_target_masks[indices] = 1
                vote_targets_tmp = vote_targets[indices]
                votes = gt_bboxes_3d.gravity_center[i].unsqueeze(
                    0) - selected_points[:, :3]

                for j in range(self.gt_per_seed):
                    column_indices = torch.nonzero(
                        vote_target_idx[indices] == j,
                        as_tuple=False).squeeze(-1)
                    vote_targets_tmp[column_indices,
                                     int(j * 3):int(j * 3 +
                                                    3)] = votes[column_indices]
                    if j == 0:
                        vote_targets_tmp[column_indices] = votes[
                            column_indices].repeat(1, self.gt_per_seed)

                vote_targets[indices] = vote_targets_tmp
                vote_target_idx[indices] = torch.clamp(
                    vote_target_idx[indices] + 1, max=2)
        elif pts_semantic_mask is not None:
            vote_targets = points.new_zeros([num_points, 3])
            vote_target_masks = points.new_zeros([num_points],
                                                 dtype=torch.long)

            for i in torch.unique(pts_instance_mask):
                indices = torch.nonzero(
                    pts_instance_mask == i, as_tuple=False).squeeze(-1)
                if pts_semantic_mask[indices[0]] < self.num_classes:
                    selected_points = points[indices, :3]
                    center = 0.5 * (
                        selected_points.min(0)[0] + selected_points.max(0)[0])
                    vote_targets[indices, :] = center - selected_points
                    vote_target_masks[indices] = 1
            vote_targets = vote_targets.repeat((1, self.gt_per_seed))
        else:
            raise NotImplementedError

        (center_targets, size_class_targets, size_res_targets,
         dir_class_targets,
         dir_res_targets) = self.bbox_coder.encode(gt_bboxes_3d, gt_labels_3d)

        proposal_num = aggregated_points.shape[0]
        distance1, _, assignment, _ = chamfer_distance(
            aggregated_points.unsqueeze(0),
            center_targets.unsqueeze(0),
            reduction='none')
        assignment = assignment.squeeze(0)
        euclidean_distance1 = torch.sqrt(distance1.squeeze(0) + 1e-6)

        objectness_targets = points.new_zeros((proposal_num),
                                              dtype=torch.long)
        objectness_targets[
            euclidean_distance1 < self.train_cfg['pos_distance_thr']] = 1

        objectness_masks = points.new_zeros((proposal_num))
        objectness_masks[
            euclidean_distance1 < self.train_cfg['pos_distance_thr']] = 1.0
        objectness_masks[
            euclidean_distance1 > self.train_cfg['neg_distance_thr']] = 1.0

        dir_class_targets = dir_class_targets[assignment]
        dir_res_targets = dir_res_targets[assignment]
        dir_res_targets /= (np.pi / self.num_dir_bins)
        size_class_targets = size_class_targets[assignment]
        size_res_targets = size_res_targets[assignment]

        one_hot_size_targets = gt_bboxes_3d.tensor.new_zeros(
            (proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(1, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets = one_hot_size_targets.unsqueeze(-1).repeat(
            1, 1, 3)
        mean_sizes = size_res_targets.new_tensor(
            self.bbox_coder.mean_sizes).unsqueeze(0)
        pos_mean_sizes = torch.sum(one_hot_size_targets * mean_sizes, 1)
        size_res_targets /= pos_mean_sizes

        mask_targets = gt_labels_3d[assignment]
        assigned_center_targets = center_targets[assignment]

        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets, dir_class_targets, dir_res_targets,
                center_targets, assigned_center_targets,
                mask_targets.long(), objectness_targets, objectness_masks)
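
A minimal sketch of the proposal-to-GT assignment above: each aggregated point is matched to its nearest ground-truth centre and labelled positive or negative according to the distance thresholds. Here torch.cdist stands in for the chamfer_distance call (it returns Euclidean distances directly), and the points and thresholds are assumptions for illustration:

import torch

# Assumed toy data: 3 proposal centres, 2 GT centres, assumed thresholds.
aggregated_points = torch.tensor([[0.1, 0.0, 0.0],
                                  [2.0, 2.0, 0.0],
                                  [0.0, 0.9, 0.0]])
center_targets = torch.tensor([[0.0, 0.0, 0.0],
                               [0.0, 1.0, 0.0]])
pos_distance_thr, neg_distance_thr = 0.3, 0.6

# Nearest GT centre per proposal (stand-in for the chamfer_distance output).
dist, assignment = torch.cdist(aggregated_points, center_targets).min(dim=1)

objectness_targets = (dist < pos_distance_thr).long()          # [1, 0, 1]
objectness_masks = ((dist < pos_distance_thr) |
                    (dist > neg_distance_thr)).float()         # [1., 1., 1.]
print(objectness_targets, assignment, objectness_masks)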

    def get_bboxes(self,
                   points,
                   bbox_preds,
                   input_metas,
                   rescale=False,
                   use_nms=True):
        """Generate bboxes from vote head predictions.

        Args:
            points (torch.Tensor): Input points.
            bbox_preds (dict): Predictions from vote head.
            input_metas (list[dict]): Point cloud and image's meta info.
            rescale (bool): Whether to rescale bboxes.
            use_nms (bool): Whether to apply NMS, skip nms postprocessing
                while using vote head in rpn stage.

        Returns:
            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
        """
        # decode boxes
        obj_scores = F.softmax(bbox_preds['obj_scores'], dim=-1)[..., -1]
        sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
        bbox3d = self.bbox_coder.decode(bbox_preds)

        if use_nms:
            batch_size = bbox3d.shape[0]
            results = list()
            for b in range(batch_size):
                bbox_selected, score_selected, labels = \
                    self.multiclass_nms_single(obj_scores[b], sem_scores[b],
                                               bbox3d[b], points[b, ..., :3],
                                               input_metas[b])
                bbox = input_metas[b]['box_type_3d'](
                    bbox_selected,
                    box_dim=bbox_selected.shape[-1],
                    with_yaw=self.bbox_coder.with_rot)
                results.append((bbox, score_selected, labels))

            return results
        else:
            return bbox3d

    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                              input_meta):
        """Multi-class nms in single batch.

        Args:
            obj_scores (torch.Tensor): Objectness score of bounding boxes.
            sem_scores (torch.Tensor): Semantic class score of bounding boxes.
            bbox (torch.Tensor): Predicted bounding boxes.
            points (torch.Tensor): Input points.
            input_meta (dict): Point cloud and image's meta info.

        Returns:
            tuple[torch.Tensor]: Bounding boxes, scores and labels.
        """
        bbox = input_meta['box_type_3d'](
            bbox,
            box_dim=bbox.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))
        box_indices = bbox.points_in_boxes_all(points)

        corner3d = bbox.corners
        minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6)))
        minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
        minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]

        nonempty_box_mask = box_indices.T.sum(1) > 5

        bbox_classes = torch.argmax(sem_scores, -1)
        nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask],
                                      obj_scores[nonempty_box_mask],
                                      bbox_classes[nonempty_box_mask],
                                      self.test_cfg.nms_thr)

        # filter empty boxes and boxes with low score
        scores_mask = (obj_scores > self.test_cfg.score_thr)
        nonempty_box_inds = torch.nonzero(
            nonempty_box_mask, as_tuple=False).flatten()
        nonempty_mask = torch.zeros_like(bbox_classes).scatter(
            0, nonempty_box_inds[nms_selected], 1)
        selected = (nonempty_mask.bool() & scores_mask.bool())

        if self.test_cfg.per_class_proposal:
            bbox_selected, score_selected, labels = [], [], []
            for k in range(sem_scores.shape[-1]):
                bbox_selected.append(bbox[selected].tensor)
                score_selected.append(obj_scores[selected] *
                                      sem_scores[selected][:, k])
                labels.append(
                    torch.zeros_like(bbox_classes[selected]).fill_(k))
            bbox_selected = torch.cat(bbox_selected, 0)
            score_selected = torch.cat(score_selected, 0)
            labels = torch.cat(labels, 0)
        else:
            bbox_selected = bbox[selected].tensor
            score_selected = obj_scores[selected]
            labels = bbox_classes[selected]

        return bbox_selected, score_selected, labels
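
When test_cfg.per_class_proposal is enabled above, every kept box is duplicated once per class and scored with obj_score * sem_score[:, k]. A toy sketch of that expansion, with assumed inputs and no box class involved:

import torch

# Assumed toy inputs: 2 kept boxes, 3 classes.
obj_scores = torch.tensor([0.9, 0.6])
sem_scores = torch.tensor([[0.7, 0.2, 0.1],
                           [0.1, 0.8, 0.1]])
boxes = torch.rand(2, 7)  # (x, y, z, dx, dy, dz, yaw)

bbox_selected, score_selected, labels = [], [], []
for k in range(sem_scores.shape[-1]):
    bbox_selected.append(boxes)
    score_selected.append(obj_scores * sem_scores[:, k])
    labels.append(torch.full_like(obj_scores, k, dtype=torch.long))

print(torch.cat(bbox_selected).shape)  # torch.Size([6, 7])
print(torch.cat(score_selected))       # 0.63, 0.06, 0.18, 0.48, 0.09, 0.06
print(torch.cat(labels))               # tensor([0, 0, 1, 1, 2, 2])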