Source code for mmdet3d.models.roi_heads.bbox_heads.h3d_bbox_head
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn
from torch.nn import functional as F
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply
@HEADS.register_module()
class H3DBboxHead(BaseModule):
r"""Bbox head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.
Args:
num_classes (int): The number of classes.
surface_matching_cfg (dict): Config for surface primitive matching.
line_matching_cfg (dict): Config for line primitive matching.
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
decoding boxes.
train_cfg (dict): Config for training.
test_cfg (dict): Config for testing.
gt_per_seed (int): Number of ground truth votes generated
from each seed point.
num_proposal (int): Number of proposal votes generated.
feat_channels (tuple[int]): Convolution channels of
prediction layer.
primitive_feat_refine_streams (int): The number of mlps to
refine primitive feature.
primitive_refine_channels (tuple[int]): Convolution channels of
the proposal refinement layer.
upper_thresh (float): Upper bound of the distance threshold used in matching.
surface_thresh (float): Threshold for surface matching.
line_thresh (float): Threshold for line matching.
conv_cfg (dict): Config of convolution in prediction layer.
norm_cfg (dict): Config of BN in prediction layer.
objectness_loss (dict): Config of objectness loss.
center_loss (dict): Config of center loss.
dir_class_loss (dict): Config of direction classification loss.
dir_res_loss (dict): Config of direction residual regression loss.
size_class_loss (dict): Config of size classification loss.
size_res_loss (dict): Config of size residual regression loss.
semantic_loss (dict): Config of point-wise semantic segmentation loss.
cues_objectness_loss (dict): Config of cues objectness loss.
cues_semantic_loss (dict): Config of cues semantic loss.
proposal_objectness_loss (dict): Config of proposal objectness
loss.
primitive_center_loss (dict): Config of primitive center regression
loss.
"""
def __init__(self,
num_classes,
surface_matching_cfg,
line_matching_cfg,
bbox_coder,
train_cfg=None,
test_cfg=None,
gt_per_seed=1,
num_proposal=256,
feat_channels=(128, 128),
primitive_feat_refine_streams=2,
primitive_refine_channels=(128, 128, 128),
upper_thresh=100.0,
surface_thresh=0.5,
line_thresh=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=None,
center_loss=None,
dir_class_loss=None,
dir_res_loss=None,
size_class_loss=None,
size_res_loss=None,
semantic_loss=None,
cues_objectness_loss=None,
cues_semantic_loss=None,
proposal_objectness_loss=None,
primitive_center_loss=None,
init_cfg=None):
super(H3DBboxHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.gt_per_seed = gt_per_seed
self.num_proposal = num_proposal
self.with_angle = bbox_coder['with_rot']
self.upper_thresh = upper_thresh
self.surface_thresh = surface_thresh
self.line_thresh = line_thresh
self.objectness_loss = build_loss(objectness_loss)
self.center_loss = build_loss(center_loss)
self.dir_class_loss = build_loss(dir_class_loss)
self.dir_res_loss = build_loss(dir_res_loss)
self.size_class_loss = build_loss(size_class_loss)
self.size_res_loss = build_loss(size_res_loss)
self.semantic_loss = build_loss(semantic_loss)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_sizes = self.bbox_coder.num_sizes
self.num_dir_bins = self.bbox_coder.num_dir_bins
self.cues_objectness_loss = build_loss(cues_objectness_loss)
self.cues_semantic_loss = build_loss(cues_semantic_loss)
self.proposal_objectness_loss = build_loss(proposal_objectness_loss)
self.primitive_center_loss = build_loss(primitive_center_loss)
assert surface_matching_cfg['mlp_channels'][-1] == \
line_matching_cfg['mlp_channels'][-1]
# surface center matching
self.surface_center_matcher = build_sa_module(surface_matching_cfg)
# line center matching
self.line_center_matcher = build_sa_module(line_matching_cfg)
# Compute the matching scores
matching_feat_dims = surface_matching_cfg['mlp_channels'][-1]
self.matching_conv = ConvModule(
matching_feat_dims,
matching_feat_dims,
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=True)
self.matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)
# Compute the semantic matching scores
self.semantic_matching_conv = ConvModule(
matching_feat_dims,
matching_feat_dims,
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=True)
self.semantic_matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)
# Surface feature aggregation
self.surface_feats_aggregation = list()
for k in range(primitive_feat_refine_streams):
self.surface_feats_aggregation.append(
ConvModule(
matching_feat_dims,
matching_feat_dims,
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=True))
self.surface_feats_aggregation = nn.Sequential(
*self.surface_feats_aggregation)
# Line feature aggregation
self.line_feats_aggregation = list()
for k in range(primitive_feat_refine_streams):
self.line_feats_aggregation.append(
ConvModule(
matching_feat_dims,
matching_feat_dims,
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=True))
self.line_feats_aggregation = nn.Sequential(
*self.line_feats_aggregation)
# surface center(6) + line center(12)
prev_channel = 18 * matching_feat_dims
self.bbox_pred = nn.ModuleList()
for k in range(len(primitive_refine_channels)):
self.bbox_pred.append(
ConvModule(
prev_channel,
primitive_refine_channels[k],
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=False))
prev_channel = primitive_refine_channels[k]
# Final object detection
# Objectness scores (2), center residual (3),
# heading class + residual (num_dir_bins * 2),
# size class + residual (num_sizes * 4), semantic scores (num_classes)
conv_out_channel = (2 + 3 + bbox_coder['num_dir_bins'] * 2 +
bbox_coder['num_sizes'] * 4 + self.num_classes)
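# For example (illustrative numbers only): with num_dir_bins=12,
# num_sizes=10 and num_classes=10, this gives 2 + 3 + 24 + 40 + 10 = 79.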
self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1))
def forward(self, feats_dict, sample_mod):
"""Forward pass.
Args:
feats_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for the vote aggregation layer.
Valid modes are "vote", "seed" and "random".
Returns:
dict: Predictions of the H3D bbox head.
"""
ret_dict = {}
aggregated_points = feats_dict['aggregated_points']
original_feature = feats_dict['aggregated_features']
batch_size = original_feature.shape[0]
object_proposal = original_feature.shape[2]
# Extract surface center, features and semantic predictions
z_center = feats_dict['pred_z_center']
xy_center = feats_dict['pred_xy_center']
z_semantic = feats_dict['sem_cls_scores_z']
xy_semantic = feats_dict['sem_cls_scores_xy']
z_feature = feats_dict['aggregated_features_z']
xy_feature = feats_dict['aggregated_features_xy']
# Extract line points and features
line_center = feats_dict['pred_line_center']
line_feature = feats_dict['aggregated_features_line']
surface_center_pred = torch.cat((z_center, xy_center), dim=1)
ret_dict['surface_center_pred'] = surface_center_pred
ret_dict['surface_sem_pred'] = torch.cat((z_semantic, xy_semantic),
dim=1)
# Extract the surface and line centers of rpn proposals
rpn_proposals = feats_dict['proposal_list']
rpn_proposals_bbox = DepthInstance3DBoxes(
rpn_proposals.reshape(-1, 7).clone(),
box_dim=rpn_proposals.shape[-1],
with_yaw=self.with_angle,
origin=(0.5, 0.5, 0.5))
obj_surface_center, obj_line_center = \
rpn_proposals_bbox.get_surface_line_center()
obj_surface_center = obj_surface_center.reshape(
batch_size, -1, 6, 3).transpose(1, 2).reshape(batch_size, -1, 3)
obj_line_center = obj_line_center.reshape(batch_size, -1, 12,
3).transpose(1, 2).reshape(
batch_size, -1, 3)
ret_dict['surface_center_object'] = obj_surface_center
ret_dict['line_center_object'] = obj_line_center
# aggregate primitive z and xy features to rpn proposals
surface_center_feature_pred = torch.cat((z_feature, xy_feature), dim=2)
surface_center_feature_pred = torch.cat(
(surface_center_feature_pred.new_zeros(
(batch_size, 6, surface_center_feature_pred.shape[2])),
surface_center_feature_pred),
dim=1)
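# The 6 zero channels prepended here (and the 12 for lines below) pad the
# primitive features to the matchers' expected input width; in typical
# configs mlp_channels[0] is feat_dim + 6 for surfaces and feat_dim + 12
# for lines (an assumption about the matching configs, not enforced here).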
surface_xyz, surface_features, _ = self.surface_center_matcher(
surface_center_pred,
surface_center_feature_pred,
target_xyz=obj_surface_center)
# aggregate primitive line features to rpn proposals
line_feature = torch.cat((line_feature.new_zeros(
(batch_size, 12, line_feature.shape[2])), line_feature),
dim=1)
line_xyz, line_features, _ = self.line_center_matcher(
line_center, line_feature, target_xyz=obj_line_center)
# combine the surface and line features
combine_features = torch.cat((surface_features, line_features), dim=2)
matching_features = self.matching_conv(combine_features)
matching_score = self.matching_pred(matching_features)
ret_dict['matching_score'] = matching_score.transpose(2, 1)
semantic_matching_features = self.semantic_matching_conv(
combine_features)
semantic_matching_score = self.semantic_matching_pred(
semantic_matching_features)
ret_dict['semantic_matching_score'] = \
semantic_matching_score.transpose(2, 1)
surface_features = self.surface_feats_aggregation(surface_features)
line_features = self.line_feats_aggregation(line_features)
# Combine all surface and line features
surface_features = surface_features.view(batch_size, -1,
object_proposal)
line_features = line_features.view(batch_size, -1, object_proposal)
combine_feature = torch.cat((surface_features, line_features), dim=1)
# Final bbox predictions
bbox_predictions = self.bbox_pred[0](combine_feature)
bbox_predictions += original_feature
for conv_module in self.bbox_pred[1:]:
bbox_predictions = conv_module(bbox_predictions)
refine_decode_res = self.bbox_coder.split_pred(
bbox_predictions[:, :self.num_classes + 2],
bbox_predictions[:, self.num_classes + 2:], aggregated_points)
for key in refine_decode_res.keys():
ret_dict[key + '_optimized'] = refine_decode_res[key]
return ret_dict
def loss(self,
bbox_preds,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
img_metas=None,
rpn_targets=None,
gt_bboxes_ignore=None):
"""Compute loss.
Args:
bbox_preds (dict): Predictions from the forward pass of the H3D bbox head.
points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each sample.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (list[torch.Tensor]): Point-wise
instance mask.
img_metas (list[dict]): Meta info of point clouds and images.
rpn_targets (tuple): Targets generated by the rpn head.
gt_bboxes_ignore (list[torch.Tensor]): Specify which bounding
boxes to ignore.
Returns:
dict: Losses of H3DNet.
"""
(vote_targets, vote_target_masks, size_class_targets, size_res_targets,
dir_class_targets, dir_res_targets, center_targets, _, mask_targets,
valid_gt_masks, objectness_targets, objectness_weights,
box_loss_weights, valid_gt_weights) = rpn_targets
losses = {}
# calculate refined proposal loss
refined_proposal_loss = self.get_proposal_stage_loss(
bbox_preds,
size_class_targets,
size_res_targets,
dir_class_targets,
dir_res_targets,
center_targets,
mask_targets,
objectness_targets,
objectness_weights,
box_loss_weights,
valid_gt_weights,
suffix='_optimized')
for key in refined_proposal_loss.keys():
losses[key + '_optimized'] = refined_proposal_loss[key]
bbox3d_optimized = self.bbox_coder.decode(
bbox_preds, suffix='_optimized')
targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask,
bbox_preds)
(cues_objectness_label, cues_sem_label, proposal_objectness_label,
cues_mask, cues_match_mask, proposal_objectness_mask,
cues_matching_label, obj_surface_line_center) = targets
# match scores for each geometric primitive
objectness_scores = bbox_preds['matching_score']
# match scores for the semantics of primitives
objectness_scores_sem = bbox_preds['semantic_matching_score']
primitive_objectness_loss = self.cues_objectness_loss(
objectness_scores.transpose(2, 1),
cues_objectness_label,
weight=cues_mask,
avg_factor=cues_mask.sum() + 1e-6)
primitive_sem_loss = self.cues_semantic_loss(
objectness_scores_sem.transpose(2, 1),
cues_sem_label,
weight=cues_mask,
avg_factor=cues_mask.sum() + 1e-6)
objectness_scores = bbox_preds['obj_scores_optimized']
objectness_loss_refine = self.proposal_objectness_loss(
objectness_scores.transpose(2, 1), proposal_objectness_label)
primitive_matching_loss = (objectness_loss_refine *
cues_match_mask).sum() / (
cues_match_mask.sum() + 1e-6) * 0.5
primitive_sem_matching_loss = (
objectness_loss_refine * proposal_objectness_mask).sum() / (
proposal_objectness_mask.sum() + 1e-6) * 0.5
# Get the object surface center here
batch_size, object_proposal = bbox3d_optimized.shape[:2]
refined_bbox = DepthInstance3DBoxes(
bbox3d_optimized.reshape(-1, 7).clone(),
box_dim=bbox3d_optimized.shape[-1],
with_yaw=self.with_angle,
origin=(0.5, 0.5, 0.5))
pred_obj_surface_center, pred_obj_line_center = \
refined_bbox.get_surface_line_center()
pred_obj_surface_center = pred_obj_surface_center.reshape(
batch_size, -1, 6, 3).transpose(1, 2).reshape(batch_size, -1, 3)
pred_obj_line_center = pred_obj_line_center.reshape(
batch_size, -1, 12, 3).transpose(1, 2).reshape(batch_size, -1, 3)
pred_surface_line_center = torch.cat(
(pred_obj_surface_center, pred_obj_line_center), 1)
square_dist = self.primitive_center_loss(pred_surface_line_center,
obj_surface_line_center)
match_dist = torch.sqrt(square_dist.sum(dim=-1) + 1e-6)
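# primitive_center_loss is expected to return unreduced squared errors
# (e.g. MSELoss(reduction='none') in typical configs), so summing over
# the coordinate dim and taking sqrt gives a per-center Euclidean
# distance.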
primitive_centroid_reg_loss = torch.sum(
match_dist * cues_matching_label) / (
cues_matching_label.sum() + 1e-6)
refined_loss = dict(
primitive_objectness_loss=primitive_objectness_loss,
primitive_sem_loss=primitive_sem_loss,
primitive_matching_loss=primitive_matching_loss,
primitive_sem_matching_loss=primitive_sem_matching_loss,
primitive_centroid_reg_loss=primitive_centroid_reg_loss)
losses.update(refined_loss)
return losses
def get_bboxes(self,
points,
bbox_preds,
input_metas,
rescale=False,
suffix=''):
"""Generate bboxes from vote head predictions.
Args:
points (torch.Tensor): Input points.
bbox_preds (dict): Predictions from vote head.
input_metas (list[dict]): Point cloud and image's meta info.
rescale (bool): Whether to rescale bboxes.
suffix (str): Suffix for selecting prediction keys in bbox_preds
(e.g. '_optimized'). Defaults to ''.
Returns:
list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
"""
# decode boxes
obj_scores = F.softmax(
bbox_preds['obj_scores' + suffix], dim=-1)[..., -1]
sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
prediction_collection = {}
prediction_collection['center'] = bbox_preds['center' + suffix]
prediction_collection['dir_class'] = bbox_preds['dir_class']
prediction_collection['dir_res'] = bbox_preds['dir_res' + suffix]
prediction_collection['size_class'] = bbox_preds['size_class']
prediction_collection['size_res'] = bbox_preds['size_res' + suffix]
bbox3d = self.bbox_coder.decode(prediction_collection)
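# bbox3d is expected to hold one box per proposal with layout
# (B, num_proposal, 7): center (3), size (3) and yaw (1).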
batch_size = bbox3d.shape[0]
results = list()
for b in range(batch_size):
bbox_selected, score_selected, labels = self.multiclass_nms_single(
obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
input_metas[b])
bbox = input_metas[b]['box_type_3d'](
bbox_selected,
box_dim=bbox_selected.shape[-1],
with_yaw=self.bbox_coder.with_rot)
results.append((bbox, score_selected, labels))
return results
def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
input_meta):
"""Multi-class nms in single batch.
Args:
obj_scores (torch.Tensor): Objectness score of bounding boxes.
sem_scores (torch.Tensor): semantic class score of bounding boxes.
bbox (torch.Tensor): Predicted bounding boxes.
points (torch.Tensor): Input points.
input_meta (dict): Point cloud and image's meta info.
Returns:
tuple[torch.Tensor]: Bounding boxes, scores and labels.
"""
bbox = input_meta['box_type_3d'](
bbox,
box_dim=bbox.shape[-1],
with_yaw=self.bbox_coder.with_rot,
origin=(0.5, 0.5, 0.5))
box_indices = bbox.points_in_boxes_all(points)
corner3d = bbox.corners
minmax_box3d = corner3d.new_empty((corner3d.shape[0], 6))
minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]
nonempty_box_mask = box_indices.T.sum(1) > 5
bbox_classes = torch.argmax(sem_scores, -1)
nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask],
obj_scores[nonempty_box_mask],
bbox_classes[nonempty_box_mask],
self.test_cfg.nms_thr)
# filter empty boxes and boxes with low score
scores_mask = (obj_scores > self.test_cfg.score_thr)
nonempty_box_inds = torch.nonzero(
nonempty_box_mask, as_tuple=False).flatten()
nonempty_mask = torch.zeros_like(bbox_classes).scatter(
0, nonempty_box_inds[nms_selected], 1)
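# nonempty_mask marks the NMS survivors: nonempty_box_inds maps the NMS
# output (indices among non-empty boxes) back to global box indices, and
# scatter writes 1 at those positions.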
selected = (nonempty_mask.bool() & scores_mask.bool())
if self.test_cfg.per_class_proposal:
bbox_selected, score_selected, labels = [], [], []
for k in range(sem_scores.shape[-1]):
bbox_selected.append(bbox[selected].tensor)
score_selected.append(obj_scores[selected] *
sem_scores[selected][:, k])
labels.append(
torch.zeros_like(bbox_classes[selected]).fill_(k))
bbox_selected = torch.cat(bbox_selected, 0)
score_selected = torch.cat(score_selected, 0)
labels = torch.cat(labels, 0)
else:
bbox_selected = bbox[selected].tensor
score_selected = obj_scores[selected]
labels = bbox_classes[selected]
return bbox_selected, score_selected, labels
def get_proposal_stage_loss(self,
bbox_preds,
size_class_targets,
size_res_targets,
dir_class_targets,
dir_res_targets,
center_targets,
mask_targets,
objectness_targets,
objectness_weights,
box_loss_weights,
valid_gt_weights,
suffix=''):
"""Compute loss for the aggregation module.
Args:
bbox_preds (dict): Predictions from forward of vote head.
size_class_targets (torch.Tensor): Ground truth
size class of each prediction bounding box.
size_res_targets (torch.Tensor): Ground truth
size residual of each prediction bounding box.
dir_class_targets (torch.Tensor): Ground truth
direction class of each prediction bounding box.
dir_res_targets (torch.Tensor): Ground truth
direction residual of each prediction bounding box.
center_targets (torch.Tensor): Ground truth center
of each prediction bounding box.
mask_targets (torch.Tensor): Semantic class targets of each
prediction bounding box.
objectness_targets (torch.Tensor): Ground truth
objectness label of each prediction bounding box.
objectness_weights (torch.Tensor): Weights of objectness
loss for each prediction bounding box.
box_loss_weights (torch.Tensor): Weights of regression
loss for each prediction bounding box.
valid_gt_weights (torch.Tensor): Validation of each
ground truth bounding box.
suffix (str): Suffix for selecting prediction keys in bbox_preds
(e.g. '_optimized'). Defaults to ''.
Returns:
dict: Losses of aggregation module.
"""
# calculate objectness loss
objectness_loss = self.objectness_loss(
bbox_preds['obj_scores' + suffix].transpose(2, 1),
objectness_targets,
weight=objectness_weights)
# calculate center loss
source2target_loss, target2source_loss = self.center_loss(
bbox_preds['center' + suffix],
center_targets,
src_weight=box_loss_weights,
dst_weight=valid_gt_weights)
center_loss = source2target_loss + target2source_loss
# calculate direction class loss
dir_class_loss = self.dir_class_loss(
bbox_preds['dir_class' + suffix].transpose(2, 1),
dir_class_targets,
weight=box_loss_weights)
# calculate direction residual loss
batch_size, proposal_num = size_class_targets.shape[:2]
heading_label_one_hot = dir_class_targets.new_zeros(
(batch_size, proposal_num, self.num_dir_bins))
heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
dir_res_norm = (bbox_preds['dir_res_norm' + suffix] *
heading_label_one_hot).sum(dim=-1)
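# The one-hot over heading bins picks, for each proposal, the predicted
# residual of its ground-truth bin, so only that bin receives gradient.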
dir_res_loss = self.dir_res_loss(
dir_res_norm, dir_res_targets, weight=box_loss_weights)
# calculate size class loss
size_class_loss = self.size_class_loss(
bbox_preds['size_class' + suffix].transpose(2, 1),
size_class_targets,
weight=box_loss_weights)
# calculate size residual loss
one_hot_size_targets = box_loss_weights.new_zeros(
(batch_size, proposal_num, self.num_sizes))
one_hot_size_targets.scatter_(2, size_class_targets.unsqueeze(-1), 1)
one_hot_size_targets_expand = one_hot_size_targets.unsqueeze(
-1).repeat(1, 1, 1, 3)
size_residual_norm = (bbox_preds['size_res_norm' + suffix] *
one_hot_size_targets_expand).sum(dim=2)
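# Same selection for sizes: the one-hot (expanded over the 3 box dims)
# keeps only the residual of each proposal's ground-truth size class.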
box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
1, 1, 3)
size_res_loss = self.size_res_loss(
size_residual_norm,
size_res_targets,
weight=box_loss_weights_expand)
# calculate semantic loss
semantic_loss = self.semantic_loss(
bbox_preds['sem_scores' + suffix].transpose(2, 1),
mask_targets,
weight=box_loss_weights)
losses = dict(
objectness_loss=objectness_loss,
semantic_loss=semantic_loss,
center_loss=center_loss,
dir_class_loss=dir_class_loss,
dir_res_loss=dir_res_loss,
size_class_loss=size_class_loss,
size_res_loss=size_res_loss)
return losses
def get_targets(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
bbox_preds=None):
"""Generate targets of proposal module.
Args:
points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): Labels of each batch.
pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
label of each batch.
pts_instance_mask (list[torch.Tensor]): Point-wise instance
label of each batch.
bbox_preds (dict): Bounding box predictions of the vote head.
Returns:
tuple[torch.Tensor]: Targets of proposal module.
"""
# find empty example
valid_gt_masks = list()
gt_num = list()
for index in range(len(gt_labels_3d)):
if len(gt_labels_3d[index]) == 0:
fake_box = gt_bboxes_3d[index].tensor.new_zeros(
1, gt_bboxes_3d[index].tensor.shape[-1])
gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
gt_num.append(1)
else:
valid_gt_masks.append(gt_labels_3d[index].new_ones(
gt_labels_3d[index].shape))
gt_num.append(gt_labels_3d[index].shape[0])
if pts_semantic_mask is None:
pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
pts_instance_mask = [None for i in range(len(gt_labels_3d))]
aggregated_points = [
bbox_preds['aggregated_points'][i]
for i in range(len(gt_labels_3d))
]
surface_center_pred = [
bbox_preds['surface_center_pred'][i]
for i in range(len(gt_labels_3d))
]
line_center_pred = [
bbox_preds['pred_line_center'][i]
for i in range(len(gt_labels_3d))
]
surface_center_object = [
bbox_preds['surface_center_object'][i]
for i in range(len(gt_labels_3d))
]
line_center_object = [
bbox_preds['line_center_object'][i]
for i in range(len(gt_labels_3d))
]
surface_sem_pred = [
bbox_preds['surface_sem_pred'][i]
for i in range(len(gt_labels_3d))
]
line_sem_pred = [
bbox_preds['sem_cls_scores_line'][i]
for i in range(len(gt_labels_3d))
]
(cues_objectness_label, cues_sem_label, proposal_objectness_label,
cues_mask, cues_match_mask, proposal_objectness_mask,
cues_matching_label, obj_surface_line_center) = multi_apply(
self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, aggregated_points,
surface_center_pred, line_center_pred, surface_center_object,
line_center_object, surface_sem_pred, line_sem_pred)
cues_objectness_label = torch.stack(cues_objectness_label)
cues_sem_label = torch.stack(cues_sem_label)
proposal_objectness_label = torch.stack(proposal_objectness_label)
cues_mask = torch.stack(cues_mask)
cues_match_mask = torch.stack(cues_match_mask)
proposal_objectness_mask = torch.stack(proposal_objectness_mask)
cues_matching_label = torch.stack(cues_matching_label)
obj_surface_line_center = torch.stack(obj_surface_line_center)
return (cues_objectness_label, cues_sem_label,
proposal_objectness_label, cues_mask, cues_match_mask,
proposal_objectness_mask, cues_matching_label,
obj_surface_line_center)
def get_targets_single(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
aggregated_points=None,
pred_surface_center=None,
pred_line_center=None,
pred_obj_surface_center=None,
pred_obj_line_center=None,
pred_surface_sem=None,
pred_line_sem=None):
"""Generate targets for primitive cues for single batch.
Args:
points (torch.Tensor): Points of each batch.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
boxes of each batch.
gt_labels_3d (torch.Tensor): Labels of each batch.
pts_semantic_mask (torch.Tensor): Point-wise semantic
label of each batch.
pts_instance_mask (torch.Tensor): Point-wise instance
label of each batch.
aggregated_points (torch.Tensor): Aggregated points from
vote aggregation layer.
pred_surface_center (torch.Tensor): Prediction of surface center.
pred_line_center (torch.Tensor): Prediction of line center.
pred_obj_surface_center (torch.Tensor): Predicted surface centers
of object proposals.
pred_obj_line_center (torch.Tensor): Predicted line centers of
object proposals.
pred_surface_sem (torch.Tensor): Semantic prediction of
surface center.
pred_line_sem (torch.Tensor): Semantic prediction of line center.
Returns:
tuple[torch.Tensor]: Targets for primitive cues.
"""
device = points.device
gt_bboxes_3d = gt_bboxes_3d.to(device)
num_proposals = aggregated_points.shape[0]
gt_center = gt_bboxes_3d.gravity_center
dist1, dist2, ind1, _ = chamfer_distance(
aggregated_points.unsqueeze(0),
gt_center.unsqueeze(0),
reduction='none')
# Set assignment
object_assignment = ind1.squeeze(0)
# Generate objectness label and mask
# objectness_label: 1 if pred object center is within
# self.train_cfg['near_threshold'] of any GT object
# objectness_mask: 0 if pred object center is in gray
# zone (DONOTCARE), 1 otherwise
euclidean_dist1 = torch.sqrt(dist1.squeeze(0) + 1e-6)
proposal_objectness_label = euclidean_dist1.new_zeros(
num_proposals, dtype=torch.long)
proposal_objectness_mask = euclidean_dist1.new_zeros(num_proposals)
gt_sem = gt_labels_3d[object_assignment]
obj_surface_center, obj_line_center = \
gt_bboxes_3d.get_surface_line_center()
obj_surface_center = obj_surface_center.reshape(-1, 6,
3).transpose(0, 1)
obj_line_center = obj_line_center.reshape(-1, 12, 3).transpose(0, 1)
obj_surface_center = obj_surface_center[:, object_assignment].reshape(
1, -1, 3)
obj_line_center = obj_line_center[:,
object_assignment].reshape(1, -1, 3)
surface_sem = torch.argmax(pred_surface_sem, dim=1).float()
line_sem = torch.argmax(pred_line_sem, dim=1).float()
dist_surface, _, surface_ind, _ = chamfer_distance(
obj_surface_center,
pred_surface_center.unsqueeze(0),
reduction='none')
dist_line, _, line_ind, _ = chamfer_distance(
obj_line_center, pred_line_center.unsqueeze(0), reduction='none')
surface_sel = pred_surface_center[surface_ind.squeeze(0)]
line_sel = pred_line_center[line_ind.squeeze(0)]
surface_sel_sem = surface_sem[surface_ind.squeeze(0)]
line_sel_sem = line_sem[line_ind.squeeze(0)]
surface_sel_sem_gt = gt_sem.repeat(6).float()
line_sel_sem_gt = gt_sem.repeat(12).float()
euclidean_dist_surface = torch.sqrt(dist_surface.squeeze(0) + 1e-6)
euclidean_dist_line = torch.sqrt(dist_line.squeeze(0) + 1e-6)
objectness_label_surface = euclidean_dist_line.new_zeros(
num_proposals * 6, dtype=torch.long)
objectness_mask_surface = euclidean_dist_line.new_zeros(num_proposals *
6)
objectness_label_line = euclidean_dist_line.new_zeros(
num_proposals * 12, dtype=torch.long)
objectness_mask_line = euclidean_dist_line.new_zeros(num_proposals *
12)
objectness_label_surface_sem = euclidean_dist_line.new_zeros(
num_proposals * 6, dtype=torch.long)
objectness_label_line_sem = euclidean_dist_line.new_zeros(
num_proposals * 12, dtype=torch.long)
euclidean_dist_obj_surface = torch.sqrt((
(pred_obj_surface_center - surface_sel)**2).sum(dim=-1) + 1e-6)
euclidean_dist_obj_line = torch.sqrt(
torch.sum((pred_obj_line_center - line_sel)**2, dim=-1) + 1e-6)
# Objectness score just with centers
proposal_objectness_label[
euclidean_dist1 < self.train_cfg['near_threshold']] = 1
proposal_objectness_mask[
euclidean_dist1 < self.train_cfg['near_threshold']] = 1
proposal_objectness_mask[
euclidean_dist1 > self.train_cfg['far_threshold']] = 1
objectness_label_surface[
(euclidean_dist_obj_surface <
self.train_cfg['label_surface_threshold']) *
(euclidean_dist_surface <
self.train_cfg['mask_surface_threshold'])] = 1
objectness_label_surface_sem[
(euclidean_dist_obj_surface <
self.train_cfg['label_surface_threshold']) *
(euclidean_dist_surface < self.train_cfg['mask_surface_threshold'])
* (surface_sel_sem == surface_sel_sem_gt)] = 1
objectness_label_line[
(euclidean_dist_obj_line < self.train_cfg['label_line_threshold'])
*
(euclidean_dist_line < self.train_cfg['mask_line_threshold'])] = 1
objectness_label_line_sem[
(euclidean_dist_obj_line < self.train_cfg['label_line_threshold'])
* (euclidean_dist_line < self.train_cfg['mask_line_threshold']) *
(line_sel_sem == line_sel_sem_gt)] = 1
objectness_label_surface_obj = proposal_objectness_label.repeat(6)
objectness_mask_surface_obj = proposal_objectness_mask.repeat(6)
objectness_label_line_obj = proposal_objectness_label.repeat(12)
objectness_mask_line_obj = proposal_objectness_mask.repeat(12)
objectness_mask_surface = objectness_mask_surface_obj
objectness_mask_line = objectness_mask_line_obj
cues_objectness_label = torch.cat(
(objectness_label_surface, objectness_label_line), 0)
cues_sem_label = torch.cat(
(objectness_label_surface_sem, objectness_label_line_sem), 0)
cues_mask = torch.cat((objectness_mask_surface, objectness_mask_line),
0)
objectness_label_surface *= objectness_label_surface_obj
objectness_label_line *= objectness_label_line_obj
cues_matching_label = torch.cat(
(objectness_label_surface, objectness_label_line), 0)
objectness_label_surface_sem *= objectness_label_surface_obj
objectness_label_line_sem *= objectness_label_line_obj
cues_match_mask = (torch.sum(
cues_objectness_label.view(18, num_proposals), dim=0) >=
1).float()
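# view(18, num_proposals) stacks the 6 surface rows and 12 line rows of
# cues; a proposal receives matching supervision if at least one of its
# 18 cues is positive.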
obj_surface_line_center = torch.cat(
(obj_surface_center, obj_line_center), 1).squeeze(0)
return (cues_objectness_label, cues_sem_label,
proposal_objectness_label, cues_mask, cues_match_mask,
proposal_objectness_mask, cues_matching_label,
obj_surface_line_center)