
Source code for mmdet3d.models.roi_heads.bbox_heads.h3d_bbox_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn as nn
from torch.nn import functional as F

from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply


@HEADS.register_module()
class H3DBboxHead(BaseModule):
    r"""Bbox head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.

    Args:
        num_classes (int): The number of classes.
        surface_matching_cfg (dict): Config for surface primitive matching.
        line_matching_cfg (dict): Config for line primitive matching.
        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
            decoding boxes.
        train_cfg (dict): Config for training.
        test_cfg (dict): Config for testing.
        gt_per_seed (int): Number of ground truth votes generated
            from each seed point.
        num_proposal (int): Number of proposal votes generated.
        feat_channels (tuple[int]): Convolution channels of
            prediction layer.
        primitive_feat_refine_streams (int): The number of mlps to
            refine primitive feature.
        primitive_refine_channels (tuple[int]): Convolution channels of
            prediction layer.
        upper_thresh (float): Threshold for line matching.
        surface_thresh (float): Threshold for surface matching.
        line_thresh (float): Threshold for line matching.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        objectness_loss (dict): Config of objectness loss.
        center_loss (dict): Config of center loss.
        dir_class_loss (dict): Config of direction classification loss.
        dir_res_loss (dict): Config of direction residual regression loss.
        size_class_loss (dict): Config of size classification loss.
        size_res_loss (dict): Config of size residual regression loss.
        semantic_loss (dict): Config of point-wise semantic segmentation loss.
        cues_objectness_loss (dict): Config of cues objectness loss.
        cues_semantic_loss (dict): Config of cues semantic loss.
        proposal_objectness_loss (dict): Config of proposal objectness loss.
        primitive_center_loss (dict): Config of primitive center regression
            loss.
    """

    def __init__(self,
                 num_classes,
                 suface_matching_cfg,
                 line_matching_cfg,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 gt_per_seed=1,
                 num_proposal=256,
                 feat_channels=(128, 128),
                 primitive_feat_refine_streams=2,
                 primitive_refine_channels=[128, 128, 128],
                 upper_thresh=100.0,
                 surface_thresh=0.5,
                 line_thresh=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None,
                 cues_objectness_loss=None,
                 cues_semantic_loss=None,
                 proposal_objectness_loss=None,
                 primitive_center_loss=None,
                 init_cfg=None):
        super(H3DBboxHead, self).__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = gt_per_seed
        self.num_proposal = num_proposal
        self.with_angle = bbox_coder['with_rot']
        self.upper_thresh = upper_thresh
        self.surface_thresh = surface_thresh
        self.line_thresh = line_thresh

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.size_class_loss = build_loss(size_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        self.semantic_loss = build_loss(semantic_loss)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.cues_objectness_loss = build_loss(cues_objectness_loss)
        self.cues_semantic_loss = build_loss(cues_semantic_loss)
        self.proposal_objectness_loss = build_loss(proposal_objectness_loss)
        self.primitive_center_loss = build_loss(primitive_center_loss)

        assert suface_matching_cfg['mlp_channels'][-1] == \
            line_matching_cfg['mlp_channels'][-1]

        # surface center matching
        self.surface_center_matcher = build_sa_module(suface_matching_cfg)
        # line center matching
        self.line_center_matcher = build_sa_module(line_matching_cfg)

        # Compute the matching scores
        matching_feat_dims = suface_matching_cfg['mlp_channels'][-1]
        self.matching_conv = ConvModule(
            matching_feat_dims,
            matching_feat_dims,
            1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=True,
            inplace=True)
        self.matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)

        # Compute the semantic matching scores
        self.semantic_matching_conv = ConvModule(
            matching_feat_dims,
            matching_feat_dims,
            1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=True,
            inplace=True)
        self.semantic_matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)

        # Surface feature aggregation
        self.surface_feats_aggregation = list()
        for k in range(primitive_feat_refine_streams):
            self.surface_feats_aggregation.append(
                ConvModule(
                    matching_feat_dims,
                    matching_feat_dims,
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
        self.surface_feats_aggregation = nn.Sequential(
            *self.surface_feats_aggregation)

        # Line feature aggregation
        self.line_feats_aggregation = list()
        for k in range(primitive_feat_refine_streams):
            self.line_feats_aggregation.append(
                ConvModule(
                    matching_feat_dims,
                    matching_feat_dims,
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
        self.line_feats_aggregation = nn.Sequential(
            *self.line_feats_aggregation)

        # surface center(6) + line center(12)
        prev_channel = 18 * matching_feat_dims
        self.bbox_pred = nn.ModuleList()
        for k in range(len(primitive_refine_channels)):
            self.bbox_pred.append(
                ConvModule(
                    prev_channel,
                    primitive_refine_channels[k],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=False))
            prev_channel = primitive_refine_channels[k]

        # Final object detection
        # Objectness scores (2), center residual (3),
        # heading class+residual (num_heading_bin*2), size class +
        # residual(num_size_cluster*4)
        conv_out_channel = (2 + 3 + bbox_coder['num_dir_bins'] * 2 +
                            bbox_coder['num_sizes'] * 4 + self.num_classes)
        self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1))
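    # A worked example of the output-channel arithmetic above, with
    # hypothetical values (not taken from any particular config):
    # num_dir_bins=12, num_sizes=10, num_classes=10 gives
    #   conv_out_channel = 2 + 3 + 12 * 2 + 10 * 4 + 10 = 79
    # i.e. 2 objectness logits, 3 center residuals, 12 direction-bin
    # logits plus 12 normalized direction residuals, 10 size-class
    # logits plus 10 * 3 normalized size residuals, and 10
    # semantic-class logits.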
    def forward(self, feats_dict, sample_mod):
        """Forward pass.

        Args:
            feats_dict (dict): Feature dict from backbone.
            sample_mod (str): Sample mode for vote aggregation layer.
                Valid modes are "vote", "seed" and "random".

        Returns:
            dict: Predictions of the h3d bbox head.
        """
        ret_dict = {}
        aggregated_points = feats_dict['aggregated_points']
        original_feature = feats_dict['aggregated_features']
        batch_size = original_feature.shape[0]
        object_proposal = original_feature.shape[2]

        # Extract surface center, features and semantic predictions
        z_center = feats_dict['pred_z_center']
        xy_center = feats_dict['pred_xy_center']
        z_semantic = feats_dict['sem_cls_scores_z']
        xy_semantic = feats_dict['sem_cls_scores_xy']
        z_feature = feats_dict['aggregated_features_z']
        xy_feature = feats_dict['aggregated_features_xy']
        # Extract line points and features
        line_center = feats_dict['pred_line_center']
        line_feature = feats_dict['aggregated_features_line']

        surface_center_pred = torch.cat((z_center, xy_center), dim=1)
        ret_dict['surface_center_pred'] = surface_center_pred
        ret_dict['surface_sem_pred'] = torch.cat((z_semantic, xy_semantic),
                                                 dim=1)

        # Extract the surface and line centers of rpn proposals
        rpn_proposals = feats_dict['proposal_list']
        rpn_proposals_bbox = DepthInstance3DBoxes(
            rpn_proposals.reshape(-1, 7).clone(),
            box_dim=rpn_proposals.shape[-1],
            with_yaw=self.with_angle,
            origin=(0.5, 0.5, 0.5))

        obj_surface_center, obj_line_center = \
            rpn_proposals_bbox.get_surface_line_center()
        obj_surface_center = obj_surface_center.reshape(
            batch_size, -1, 6, 3).transpose(1, 2).reshape(batch_size, -1, 3)
        obj_line_center = obj_line_center.reshape(
            batch_size, -1, 12, 3).transpose(1, 2).reshape(batch_size, -1, 3)
        ret_dict['surface_center_object'] = obj_surface_center
        ret_dict['line_center_object'] = obj_line_center

        # aggregate primitive z and xy features to rpn proposals
        surface_center_feature_pred = torch.cat((z_feature, xy_feature),
                                                dim=2)
        surface_center_feature_pred = torch.cat(
            (surface_center_feature_pred.new_zeros(
                (batch_size, 6, surface_center_feature_pred.shape[2])),
             surface_center_feature_pred),
            dim=1)

        surface_xyz, surface_features, _ = self.surface_center_matcher(
            surface_center_pred,
            surface_center_feature_pred,
            target_xyz=obj_surface_center)

        # aggregate primitive line features to rpn proposals
        line_feature = torch.cat((line_feature.new_zeros(
            (batch_size, 12, line_feature.shape[2])), line_feature),
                                 dim=1)
        line_xyz, line_features, _ = self.line_center_matcher(
            line_center, line_feature, target_xyz=obj_line_center)

        # combine the surface and line features
        combine_features = torch.cat((surface_features, line_features),
                                     dim=2)

        matching_features = self.matching_conv(combine_features)
        matching_score = self.matching_pred(matching_features)
        ret_dict['matching_score'] = matching_score.transpose(2, 1)

        semantic_matching_features = self.semantic_matching_conv(
            combine_features)
        semantic_matching_score = self.semantic_matching_pred(
            semantic_matching_features)
        ret_dict['semantic_matching_score'] = \
            semantic_matching_score.transpose(2, 1)

        surface_features = self.surface_feats_aggregation(surface_features)
        line_features = self.line_feats_aggregation(line_features)

        # Combine all surface and line features
        surface_features = surface_features.view(batch_size, -1,
                                                 object_proposal)
        line_features = line_features.view(batch_size, -1, object_proposal)

        combine_feature = torch.cat((surface_features, line_features), dim=1)

        # Final bbox predictions
        bbox_predictions = self.bbox_pred[0](combine_feature)
        bbox_predictions += original_feature
        for conv_module in self.bbox_pred[1:]:
            bbox_predictions = conv_module(bbox_predictions)

        refine_decode_res = self.bbox_coder.split_pred(
            bbox_predictions[:, :self.num_classes + 2],
            bbox_predictions[:, self.num_classes + 2:], aggregated_points)
        for key in refine_decode_res.keys():
            ret_dict[key + '_optimized'] = refine_decode_res[key]
        return ret_dict
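    # Shape bookkeeping for ``forward``, derived from the code above:
    # with batch size B and N proposals, every proposal box contributes
    # 6 surface centers and 12 line centers, so ``obj_surface_center``
    # has shape (B, 6 * N, 3) and ``obj_line_center`` has shape
    # (B, 12 * N, 3). After matching and aggregation, the 18 primitive
    # features are stacked per proposal, which is why ``prev_channel``
    # in ``__init__`` starts at 18 * matching_feat_dims.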
    def loss(self,
             bbox_preds,
             points,
             gt_bboxes_3d,
             gt_labels_3d,
             pts_semantic_mask=None,
             pts_instance_mask=None,
             img_metas=None,
             rpn_targets=None,
             gt_bboxes_ignore=None):
        """Compute loss.

        Args:
            bbox_preds (dict): Predictions from forward of h3d bbox head.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (list[torch.Tensor]): Point-wise
                instance mask.
            img_metas (list[dict]): Contain pcd and img's meta info.
            rpn_targets (tuple): Targets generated by rpn head.
            gt_bboxes_ignore (list[torch.Tensor]): Specify which bounding
                boxes to ignore.

        Returns:
            dict: Losses of H3DNet.
        """
        (vote_targets, vote_target_masks, size_class_targets,
         size_res_targets, dir_class_targets, dir_res_targets,
         center_targets, _, mask_targets, valid_gt_masks,
         objectness_targets, objectness_weights, box_loss_weights,
         valid_gt_weights) = rpn_targets

        losses = {}

        # calculate refined proposal loss
        refined_proposal_loss = self.get_proposal_stage_loss(
            bbox_preds,
            size_class_targets,
            size_res_targets,
            dir_class_targets,
            dir_res_targets,
            center_targets,
            mask_targets,
            objectness_targets,
            objectness_weights,
            box_loss_weights,
            valid_gt_weights,
            suffix='_optimized')
        for key in refined_proposal_loss.keys():
            losses[key + '_optimized'] = refined_proposal_loss[key]

        bbox3d_optimized = self.bbox_coder.decode(
            bbox_preds, suffix='_optimized')

        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
                                   pts_semantic_mask, pts_instance_mask,
                                   bbox_preds)

        (cues_objectness_label, cues_sem_label, proposal_objectness_label,
         cues_mask, cues_match_mask, proposal_objectness_mask,
         cues_matching_label, obj_surface_line_center) = targets

        # match scores for each geometric primitive
        objectness_scores = bbox_preds['matching_score']
        # match scores for the semantics of primitives
        objectness_scores_sem = bbox_preds['semantic_matching_score']

        primitive_objectness_loss = self.cues_objectness_loss(
            objectness_scores.transpose(2, 1),
            cues_objectness_label,
            weight=cues_mask,
            avg_factor=cues_mask.sum() + 1e-6)

        primitive_sem_loss = self.cues_semantic_loss(
            objectness_scores_sem.transpose(2, 1),
            cues_sem_label,
            weight=cues_mask,
            avg_factor=cues_mask.sum() + 1e-6)

        objectness_scores = bbox_preds['obj_scores_optimized']
        objectness_loss_refine = self.proposal_objectness_loss(
            objectness_scores.transpose(2, 1), proposal_objectness_label)
        primitive_matching_loss = (objectness_loss_refine *
                                   cues_match_mask).sum() / (
                                       cues_match_mask.sum() + 1e-6) * 0.5
        primitive_sem_matching_loss = (
            objectness_loss_refine * proposal_objectness_mask).sum() / (
                proposal_objectness_mask.sum() + 1e-6) * 0.5

        # Get the object surface center here
        batch_size, object_proposal = bbox3d_optimized.shape[:2]
        refined_bbox = DepthInstance3DBoxes(
            bbox3d_optimized.reshape(-1, 7).clone(),
            box_dim=bbox3d_optimized.shape[-1],
            with_yaw=self.with_angle,
            origin=(0.5, 0.5, 0.5))

        pred_obj_surface_center, pred_obj_line_center = \
            refined_bbox.get_surface_line_center()
        pred_obj_surface_center = pred_obj_surface_center.reshape(
            batch_size, -1, 6, 3).transpose(1, 2).reshape(batch_size, -1, 3)
        pred_obj_line_center = pred_obj_line_center.reshape(
            batch_size, -1, 12, 3).transpose(1, 2).reshape(batch_size, -1, 3)
        pred_surface_line_center = torch.cat(
            (pred_obj_surface_center, pred_obj_line_center), 1)

        square_dist = self.primitive_center_loss(pred_surface_line_center,
                                                 obj_surface_line_center)

        match_dist = torch.sqrt(square_dist.sum(dim=-1) + 1e-6)
        primitive_centroid_reg_loss = torch.sum(
            match_dist * cues_matching_label) / (
                cues_matching_label.sum() + 1e-6)

        refined_loss = dict(
            primitive_objectness_loss=primitive_objectness_loss,
            primitive_sem_loss=primitive_sem_loss,
            primitive_matching_loss=primitive_matching_loss,
            primitive_sem_matching_loss=primitive_sem_matching_loss,
            primitive_centroid_reg_loss=primitive_centroid_reg_loss)

        losses.update(refined_loss)

        return losses
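    # The masked-average reduction used repeatedly above, shown in
    # isolation with toy shapes (purely illustrative values):
    #   loss_map = torch.rand(2, 256)               # per-proposal losses
    #   mask = (torch.rand(2, 256) > 0.5).float()   # selected proposals
    #   reduced = (loss_map * mask).sum() / (mask.sum() + 1e-6)
    # The 1e-6 term keeps the division finite when the mask selects no
    # proposal at all.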
    def get_bboxes(self,
                   points,
                   bbox_preds,
                   input_metas,
                   rescale=False,
                   suffix=''):
        """Generate bboxes from vote head predictions.

        Args:
            points (torch.Tensor): Input points.
            bbox_preds (dict): Predictions from vote head.
            input_metas (list[dict]): Point cloud and image's meta info.
            rescale (bool): Whether to rescale bboxes.

        Returns:
            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
        """
        # decode boxes
        obj_scores = F.softmax(
            bbox_preds['obj_scores' + suffix], dim=-1)[..., -1]
        sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
        prediction_collection = {}
        prediction_collection['center'] = bbox_preds['center' + suffix]
        prediction_collection['dir_class'] = bbox_preds['dir_class']
        prediction_collection['dir_res'] = bbox_preds['dir_res' + suffix]
        prediction_collection['size_class'] = bbox_preds['size_class']
        prediction_collection['size_res'] = bbox_preds['size_res' + suffix]

        bbox3d = self.bbox_coder.decode(prediction_collection)

        batch_size = bbox3d.shape[0]
        results = list()
        for b in range(batch_size):
            bbox_selected, score_selected, labels = \
                self.multiclass_nms_single(obj_scores[b], sem_scores[b],
                                           bbox3d[b], points[b, ..., :3],
                                           input_metas[b])
            bbox = input_metas[b]['box_type_3d'](
                bbox_selected,
                box_dim=bbox_selected.shape[-1],
                with_yaw=self.bbox_coder.with_rot)
            results.append((bbox, score_selected, labels))

        return results
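    # Note: ``suffix`` selects which set of predictions is decoded; for
    # example, passing suffix='_optimized' decodes the refined
    # predictions that ``forward`` stores under the '*_optimized' keys.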
    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                              input_meta):
        """Multi-class nms in single batch.

        Args:
            obj_scores (torch.Tensor): Objectness score of bounding boxes.
            sem_scores (torch.Tensor): Semantic class score of bounding
                boxes.
            bbox (torch.Tensor): Predicted bounding boxes.
            points (torch.Tensor): Input points.
            input_meta (dict): Point cloud and image's meta info.

        Returns:
            tuple[torch.Tensor]: Bounding boxes, scores and labels.
        """
        bbox = input_meta['box_type_3d'](
            bbox,
            box_dim=bbox.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))
        box_indices = bbox.points_in_boxes_all(points)

        corner3d = bbox.corners
        minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6)))
        minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
        minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]

        nonempty_box_mask = box_indices.T.sum(1) > 5

        bbox_classes = torch.argmax(sem_scores, -1)
        nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask],
                                      obj_scores[nonempty_box_mask],
                                      bbox_classes[nonempty_box_mask],
                                      self.test_cfg.nms_thr)

        # filter empty boxes and boxes with low score
        scores_mask = (obj_scores > self.test_cfg.score_thr)
        nonempty_box_inds = torch.nonzero(
            nonempty_box_mask, as_tuple=False).flatten()
        nonempty_mask = torch.zeros_like(bbox_classes).scatter(
            0, nonempty_box_inds[nms_selected], 1)
        selected = (nonempty_mask.bool() & scores_mask.bool())

        if self.test_cfg.per_class_proposal:
            bbox_selected, score_selected, labels = [], [], []
            for k in range(sem_scores.shape[-1]):
                bbox_selected.append(bbox[selected].tensor)
                score_selected.append(obj_scores[selected] *
                                      sem_scores[selected][:, k])
                labels.append(
                    torch.zeros_like(bbox_classes[selected]).fill_(k))
            bbox_selected = torch.cat(bbox_selected, 0)
            score_selected = torch.cat(score_selected, 0)
            labels = torch.cat(labels, 0)
        else:
            bbox_selected = bbox[selected].tensor
            score_selected = obj_scores[selected]
            labels = bbox_classes[selected]

        return bbox_selected, score_selected, labels
    def get_proposal_stage_loss(self,
                                bbox_preds,
                                size_class_targets,
                                size_res_targets,
                                dir_class_targets,
                                dir_res_targets,
                                center_targets,
                                mask_targets,
                                objectness_targets,
                                objectness_weights,
                                box_loss_weights,
                                valid_gt_weights,
                                suffix=''):
        """Compute loss for the aggregation module.

        Args:
            bbox_preds (dict): Predictions from forward of vote head.
            size_class_targets (torch.Tensor): Ground truth
                size class of each prediction bounding box.
            size_res_targets (torch.Tensor): Ground truth
                size residual of each prediction bounding box.
            dir_class_targets (torch.Tensor): Ground truth
                direction class of each prediction bounding box.
            dir_res_targets (torch.Tensor): Ground truth
                direction residual of each prediction bounding box.
            center_targets (torch.Tensor): Ground truth center
                of each prediction bounding box.
            mask_targets (torch.Tensor): Validation of each
                prediction bounding box.
            objectness_targets (torch.Tensor): Ground truth
                objectness label of each prediction bounding box.
            objectness_weights (torch.Tensor): Weights of objectness
                loss for each prediction bounding box.
            box_loss_weights (torch.Tensor): Weights of regression
                loss for each prediction bounding box.
            valid_gt_weights (torch.Tensor): Validation of each
                ground truth bounding box.

        Returns:
            dict: Losses of aggregation module.
        """
        # calculate objectness loss
        objectness_loss = self.objectness_loss(
            bbox_preds['obj_scores' + suffix].transpose(2, 1),
            objectness_targets,
            weight=objectness_weights)

        # calculate center loss
        source2target_loss, target2source_loss = self.center_loss(
            bbox_preds['center' + suffix],
            center_targets,
            src_weight=box_loss_weights,
            dst_weight=valid_gt_weights)
        center_loss = source2target_loss + target2source_loss

        # calculate direction class loss
        dir_class_loss = self.dir_class_loss(
            bbox_preds['dir_class' + suffix].transpose(2, 1),
            dir_class_targets,
            weight=box_loss_weights)

        # calculate direction residual loss
        batch_size, proposal_num = size_class_targets.shape[:2]
        heading_label_one_hot = dir_class_targets.new_zeros(
            (batch_size, proposal_num, self.num_dir_bins))
        heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
        dir_res_norm = (bbox_preds['dir_res_norm' + suffix] *
                        heading_label_one_hot).sum(dim=-1)
        dir_res_loss = self.dir_res_loss(
            dir_res_norm, dir_res_targets, weight=box_loss_weights)

        # calculate size class loss
        size_class_loss = self.size_class_loss(
            bbox_preds['size_class' + suffix].transpose(2, 1),
            size_class_targets,
            weight=box_loss_weights)

        # calculate size residual loss
        one_hot_size_targets = box_loss_weights.new_zeros(
            (batch_size, proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(2, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets_expand = one_hot_size_targets.unsqueeze(
            -1).repeat(1, 1, 1, 3)
        size_residual_norm = (bbox_preds['size_res_norm' + suffix] *
                              one_hot_size_targets_expand).sum(dim=2)
        box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
            1, 1, 3)
        size_res_loss = self.size_res_loss(
            size_residual_norm,
            size_res_targets,
            weight=box_loss_weights_expand)

        # calculate semantic loss
        semantic_loss = self.semantic_loss(
            bbox_preds['sem_scores' + suffix].transpose(2, 1),
            mask_targets,
            weight=box_loss_weights)

        losses = dict(
            objectness_loss=objectness_loss,
            semantic_loss=semantic_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_class_loss=size_class_loss,
            size_res_loss=size_res_loss)

        return losses
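    # The one-hot selection trick used above, in isolation: given
    # per-proposal bin targets of shape (B, N) and per-bin residual
    # predictions of shape (B, N, num_bins), scattering ones into a
    # zero tensor and summing picks exactly the residual of the target
    # bin for every proposal (``preds``/``targets`` are placeholders):
    #   one_hot = preds.new_zeros(B, N, num_bins)
    #   one_hot.scatter_(2, targets.unsqueeze(-1), 1)
    #   picked = (preds * one_hot).sum(dim=-1)   # shape (B, N)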
    def get_targets(self,
                    points,
                    gt_bboxes_3d,
                    gt_labels_3d,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    bbox_preds=None):
        """Generate targets of proposal module.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
            pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
                label of each batch.
            pts_instance_mask (list[torch.Tensor]): Point-wise instance
                label of each batch.
            bbox_preds (torch.Tensor): Bounding box predictions of vote head.

        Returns:
            tuple[torch.Tensor]: Targets of proposal module.
        """
        # find empty example
        valid_gt_masks = list()
        gt_num = list()
        for index in range(len(gt_labels_3d)):
            if len(gt_labels_3d[index]) == 0:
                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
                    1, gt_bboxes_3d[index].tensor.shape[-1])
                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
                valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
                gt_num.append(1)
            else:
                valid_gt_masks.append(gt_labels_3d[index].new_ones(
                    gt_labels_3d[index].shape))
                gt_num.append(gt_labels_3d[index].shape[0])

        if pts_semantic_mask is None:
            pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
            pts_instance_mask = [None for i in range(len(gt_labels_3d))]

        aggregated_points = [
            bbox_preds['aggregated_points'][i]
            for i in range(len(gt_labels_3d))
        ]

        surface_center_pred = [
            bbox_preds['surface_center_pred'][i]
            for i in range(len(gt_labels_3d))
        ]

        line_center_pred = [
            bbox_preds['pred_line_center'][i]
            for i in range(len(gt_labels_3d))
        ]

        surface_center_object = [
            bbox_preds['surface_center_object'][i]
            for i in range(len(gt_labels_3d))
        ]

        line_center_object = [
            bbox_preds['line_center_object'][i]
            for i in range(len(gt_labels_3d))
        ]

        surface_sem_pred = [
            bbox_preds['surface_sem_pred'][i]
            for i in range(len(gt_labels_3d))
        ]

        line_sem_pred = [
            bbox_preds['sem_cls_scores_line'][i]
            for i in range(len(gt_labels_3d))
        ]

        (cues_objectness_label, cues_sem_label, proposal_objectness_label,
         cues_mask, cues_match_mask, proposal_objectness_mask,
         cues_matching_label, obj_surface_line_center) = multi_apply(
             self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
             pts_semantic_mask, pts_instance_mask, aggregated_points,
             surface_center_pred, line_center_pred, surface_center_object,
             line_center_object, surface_sem_pred, line_sem_pred)

        cues_objectness_label = torch.stack(cues_objectness_label)
        cues_sem_label = torch.stack(cues_sem_label)
        proposal_objectness_label = torch.stack(proposal_objectness_label)
        cues_mask = torch.stack(cues_mask)
        cues_match_mask = torch.stack(cues_match_mask)
        proposal_objectness_mask = torch.stack(proposal_objectness_mask)
        cues_matching_label = torch.stack(cues_matching_label)
        obj_surface_line_center = torch.stack(obj_surface_line_center)

        return (cues_objectness_label, cues_sem_label,
                proposal_objectness_label, cues_mask, cues_match_mask,
                proposal_objectness_mask, cues_matching_label,
                obj_surface_line_center)
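    # ``multi_apply`` (from mmdet.core) maps ``get_targets_single``
    # over the per-sample lists above and transposes the results, so
    # each returned element is a list with one tensor per sample;
    # ``torch.stack`` then rebuilds the batch dimension.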
    def get_targets_single(self,
                           points,
                           gt_bboxes_3d,
                           gt_labels_3d,
                           pts_semantic_mask=None,
                           pts_instance_mask=None,
                           aggregated_points=None,
                           pred_surface_center=None,
                           pred_line_center=None,
                           pred_obj_surface_center=None,
                           pred_obj_line_center=None,
                           pred_surface_sem=None,
                           pred_line_sem=None):
        """Generate targets for primitive cues for single batch.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch.
            pts_semantic_mask (torch.Tensor): Point-wise semantic
                label of each batch.
            pts_instance_mask (torch.Tensor): Point-wise instance
                label of each batch.
            aggregated_points (torch.Tensor): Aggregated points from
                vote aggregation layer.
            pred_surface_center (torch.Tensor): Prediction of surface center.
            pred_line_center (torch.Tensor): Prediction of line center.
            pred_obj_surface_center (torch.Tensor): Objectness prediction
                of surface center.
            pred_obj_line_center (torch.Tensor): Objectness prediction of
                line center.
            pred_surface_sem (torch.Tensor): Semantic prediction of
                surface center.
            pred_line_sem (torch.Tensor): Semantic prediction of line center.

        Returns:
            tuple[torch.Tensor]: Targets for primitive cues.
        """
        device = points.device
        gt_bboxes_3d = gt_bboxes_3d.to(device)
        num_proposals = aggregated_points.shape[0]
        gt_center = gt_bboxes_3d.gravity_center

        dist1, dist2, ind1, _ = chamfer_distance(
            aggregated_points.unsqueeze(0),
            gt_center.unsqueeze(0),
            reduction='none')
        # Set assignment
        object_assignment = ind1.squeeze(0)

        # Generate objectness label and mask
        # objectness_label: 1 if pred object center is within
        # self.train_cfg['near_threshold'] of any GT object
        # objectness_mask: 0 if pred object center is in gray
        # zone (DONOTCARE), 1 otherwise
        euclidean_dist1 = torch.sqrt(dist1.squeeze(0) + 1e-6)
        proposal_objectness_label = euclidean_dist1.new_zeros(
            num_proposals, dtype=torch.long)
        proposal_objectness_mask = euclidean_dist1.new_zeros(num_proposals)

        gt_sem = gt_labels_3d[object_assignment]

        obj_surface_center, obj_line_center = \
            gt_bboxes_3d.get_surface_line_center()
        obj_surface_center = obj_surface_center.reshape(-1, 6,
                                                        3).transpose(0, 1)
        obj_line_center = obj_line_center.reshape(-1, 12, 3).transpose(0, 1)
        obj_surface_center = obj_surface_center[:, object_assignment].reshape(
            1, -1, 3)
        obj_line_center = obj_line_center[:, object_assignment].reshape(
            1, -1, 3)

        surface_sem = torch.argmax(pred_surface_sem, dim=1).float()
        line_sem = torch.argmax(pred_line_sem, dim=1).float()

        dist_surface, _, surface_ind, _ = chamfer_distance(
            obj_surface_center,
            pred_surface_center.unsqueeze(0),
            reduction='none')
        dist_line, _, line_ind, _ = chamfer_distance(
            obj_line_center, pred_line_center.unsqueeze(0), reduction='none')

        surface_sel = pred_surface_center[surface_ind.squeeze(0)]
        line_sel = pred_line_center[line_ind.squeeze(0)]
        surface_sel_sem = surface_sem[surface_ind.squeeze(0)]
        line_sel_sem = line_sem[line_ind.squeeze(0)]

        surface_sel_sem_gt = gt_sem.repeat(6).float()
        line_sel_sem_gt = gt_sem.repeat(12).float()

        euclidean_dist_surface = torch.sqrt(dist_surface.squeeze(0) + 1e-6)
        euclidean_dist_line = torch.sqrt(dist_line.squeeze(0) + 1e-6)
        objectness_label_surface = euclidean_dist_line.new_zeros(
            num_proposals * 6, dtype=torch.long)
        objectness_mask_surface = euclidean_dist_line.new_zeros(
            num_proposals * 6)
        objectness_label_line = euclidean_dist_line.new_zeros(
            num_proposals * 12, dtype=torch.long)
        objectness_mask_line = euclidean_dist_line.new_zeros(
            num_proposals * 12)
        objectness_label_surface_sem = euclidean_dist_line.new_zeros(
            num_proposals * 6, dtype=torch.long)
        objectness_label_line_sem = euclidean_dist_line.new_zeros(
            num_proposals * 12, dtype=torch.long)

        euclidean_dist_obj_surface = torch.sqrt((
            (pred_obj_surface_center - surface_sel)**2).sum(dim=-1) + 1e-6)
        euclidean_dist_obj_line = torch.sqrt(
            torch.sum((pred_obj_line_center - line_sel)**2, dim=-1) + 1e-6)

        # Objectness score just with centers
        proposal_objectness_label[
            euclidean_dist1 < self.train_cfg['near_threshold']] = 1
        proposal_objectness_mask[
            euclidean_dist1 < self.train_cfg['near_threshold']] = 1
        proposal_objectness_mask[
            euclidean_dist1 > self.train_cfg['far_threshold']] = 1

        objectness_label_surface[
            (euclidean_dist_obj_surface <
             self.train_cfg['label_surface_threshold']) *
            (euclidean_dist_surface <
             self.train_cfg['mask_surface_threshold'])] = 1
        objectness_label_surface_sem[
            (euclidean_dist_obj_surface <
             self.train_cfg['label_surface_threshold']) *
            (euclidean_dist_surface <
             self.train_cfg['mask_surface_threshold']) *
            (surface_sel_sem == surface_sel_sem_gt)] = 1

        objectness_label_line[
            (euclidean_dist_obj_line <
             self.train_cfg['label_line_threshold']) *
            (euclidean_dist_line <
             self.train_cfg['mask_line_threshold'])] = 1
        objectness_label_line_sem[
            (euclidean_dist_obj_line <
             self.train_cfg['label_line_threshold']) *
            (euclidean_dist_line < self.train_cfg['mask_line_threshold']) *
            (line_sel_sem == line_sel_sem_gt)] = 1

        objectness_label_surface_obj = proposal_objectness_label.repeat(6)
        objectness_mask_surface_obj = proposal_objectness_mask.repeat(6)
        objectness_label_line_obj = proposal_objectness_label.repeat(12)
        objectness_mask_line_obj = proposal_objectness_mask.repeat(12)

        objectness_mask_surface = objectness_mask_surface_obj
        objectness_mask_line = objectness_mask_line_obj

        cues_objectness_label = torch.cat(
            (objectness_label_surface, objectness_label_line), 0)
        cues_sem_label = torch.cat(
            (objectness_label_surface_sem, objectness_label_line_sem), 0)
        cues_mask = torch.cat(
            (objectness_mask_surface, objectness_mask_line), 0)

        objectness_label_surface *= objectness_label_surface_obj
        objectness_label_line *= objectness_label_line_obj
        cues_matching_label = torch.cat(
            (objectness_label_surface, objectness_label_line), 0)

        objectness_label_surface_sem *= objectness_label_surface_obj
        objectness_label_line_sem *= objectness_label_line_obj

        cues_match_mask = (torch.sum(
            cues_objectness_label.view(18, num_proposals), dim=0) >=
                           1).float()

        obj_surface_line_center = torch.cat(
            (obj_surface_center, obj_line_center), 1).squeeze(0)

        return (cues_objectness_label, cues_sem_label,
                proposal_objectness_label, cues_mask, cues_match_mask,
                proposal_objectness_mask, cues_matching_label,
                obj_surface_line_center)
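
A minimal, self-contained sketch of the axis-aligned NMS pattern used by
``multiclass_nms_single`` above, with toy tensors. The (N, 8, 3) corner
layout follows ``bbox.corners``; the box count, scores, class ids and the
0.25 IoU threshold are made-up illustration values, not taken from any
H3DNet config.

import torch

from mmdet3d.core.post_processing import aligned_3d_nms

# Hypothetical corners of 4 predicted boxes, shape (N, 8, 3).
corner3d = torch.rand(4, 8, 3)

# Collapse each box to its axis-aligned (x1, y1, z1, x2, y2, z2) extent,
# as multiclass_nms_single does before calling aligned_3d_nms.
minmax_box3d = corner3d.new_empty((corner3d.shape[0], 6))
minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]

obj_scores = torch.rand(4)                # objectness score per box
bbox_classes = torch.randint(0, 3, (4,))  # argmax semantic class per box

# Indices of the boxes kept after class-aware aligned 3D NMS.
keep = aligned_3d_nms(minmax_box3d, obj_scores, bbox_classes, 0.25)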