import torch
from mmcv.runner import BaseModule
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.core.bbox.structures import rotation_3d_in_axis
from mmdet3d.models.builder import build_loss
from mmdet.core import multi_apply
from mmdet.models import HEADS
[docs]@HEADS.register_module()
class PointwiseSemanticHead(BaseModule):
"""Semantic segmentation head for point-wise segmentation.
Predict point-wise segmentation and part regression results for PartA2.
See `paper <https://arxiv.org/abs/1907.03670>`_ for more details.
Args:
in_channels (int): The number of input channel.
num_classes (int): The number of class.
extra_width (float): Boxes enlarge width.
loss_seg (dict): Config of segmentation loss.
loss_part (dict): Config of part prediction loss.
"""
def __init__(self,
in_channels,
num_classes=3,
extra_width=0.2,
seg_score_thr=0.3,
init_cfg=None,
loss_seg=dict(
type='FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_part=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0)):
super(PointwiseSemanticHead, self).__init__(init_cfg=init_cfg)
self.extra_width = extra_width
self.num_classes = num_classes
self.seg_score_thr = seg_score_thr
self.seg_cls_layer = nn.Linear(in_channels, 1, bias=True)
self.seg_reg_layer = nn.Linear(in_channels, 3, bias=True)
self.loss_seg = build_loss(loss_seg)
self.loss_part = build_loss(loss_part)
[docs] def forward(self, x):
"""Forward pass.
Args:
x (torch.Tensor): Features from the first stage.
Returns:
dict: Part features, segmentation and part predictions.
- seg_preds (torch.Tensor): Segment predictions.
- part_preds (torch.Tensor): Part predictions.
- part_feats (torch.Tensor): Feature predictions.
"""
seg_preds = self.seg_cls_layer(x) # (N, 1)
part_preds = self.seg_reg_layer(x) # (N, 3)
seg_scores = torch.sigmoid(seg_preds).detach()
seg_mask = (seg_scores > self.seg_score_thr)
part_offsets = torch.sigmoid(part_preds).clone().detach()
part_offsets[seg_mask.view(-1) == 0] = 0
part_feats = torch.cat((part_offsets, seg_scores),
dim=-1) # shape (npoints, 4)
return dict(
seg_preds=seg_preds, part_preds=part_preds, part_feats=part_feats)
[docs] def get_targets_single(self, voxel_centers, gt_bboxes_3d, gt_labels_3d):
"""generate segmentation and part prediction targets for a single
sample.
Args:
voxel_centers (torch.Tensor): The center of voxels in shape \
(voxel_num, 3).
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth boxes in \
shape (box_num, 7).
gt_labels_3d (torch.Tensor): Class labels of ground truths in \
shape (box_num).
Returns:
tuple[torch.Tensor]: Segmentation targets with shape [voxel_num] \
part prediction targets with shape [voxel_num, 3]
"""
gt_bboxes_3d = gt_bboxes_3d.to(voxel_centers.device)
enlarged_gt_boxes = gt_bboxes_3d.enlarged_box(self.extra_width)
part_targets = voxel_centers.new_zeros((voxel_centers.shape[0], 3),
dtype=torch.float32)
box_idx = gt_bboxes_3d.points_in_boxes(voxel_centers)
enlarge_box_idx = enlarged_gt_boxes.points_in_boxes(
voxel_centers).long()
gt_labels_pad = F.pad(
gt_labels_3d, (1, 0), mode='constant', value=self.num_classes)
seg_targets = gt_labels_pad[(box_idx.long() + 1)]
fg_pt_flag = box_idx > -1
ignore_flag = fg_pt_flag ^ (enlarge_box_idx > -1)
seg_targets[ignore_flag] = -1
for k in range(len(gt_bboxes_3d)):
k_box_flag = box_idx == k
# no point in current box (caused by velodyne reduce)
if not k_box_flag.any():
continue
fg_voxels = voxel_centers[k_box_flag]
transformed_voxels = fg_voxels - gt_bboxes_3d.bottom_center[k]
transformed_voxels = rotation_3d_in_axis(
transformed_voxels.unsqueeze(0),
-gt_bboxes_3d.yaw[k].view(1),
axis=2)
part_targets[k_box_flag] = transformed_voxels / gt_bboxes_3d.dims[
k] + voxel_centers.new_tensor([0.5, 0.5, 0])
part_targets = torch.clamp(part_targets, min=0)
return seg_targets, part_targets
[docs] def get_targets(self, voxels_dict, gt_bboxes_3d, gt_labels_3d):
"""generate segmentation and part prediction targets.
Args:
voxel_centers (torch.Tensor): The center of voxels in shape \
(voxel_num, 3).
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth boxes in \
shape (box_num, 7).
gt_labels_3d (torch.Tensor): Class labels of ground truths in \
shape (box_num).
Returns:
dict: Prediction targets
- seg_targets (torch.Tensor): Segmentation targets \
with shape [voxel_num].
- part_targets (torch.Tensor): Part prediction targets \
with shape [voxel_num, 3].
"""
batch_size = len(gt_labels_3d)
voxel_center_list = []
for idx in range(batch_size):
coords_idx = voxels_dict['coors'][:, 0] == idx
voxel_center_list.append(voxels_dict['voxel_centers'][coords_idx])
seg_targets, part_targets = multi_apply(self.get_targets_single,
voxel_center_list,
gt_bboxes_3d, gt_labels_3d)
seg_targets = torch.cat(seg_targets, dim=0)
part_targets = torch.cat(part_targets, dim=0)
return dict(seg_targets=seg_targets, part_targets=part_targets)
[docs] def loss(self, semantic_results, semantic_targets):
"""Calculate point-wise segmentation and part prediction losses.
Args:
semantic_results (dict): Results from semantic head.
- seg_preds: Segmentation predictions.
- part_preds: Part predictions.
semantic_targets (dict): Targets of semantic results.
- seg_preds: Segmentation targets.
- part_preds: Part targets.
Returns:
dict: Loss of segmentation and part prediction.
- loss_seg (torch.Tensor): Segmentation prediction loss.
- loss_part (torch.Tensor): Part prediction loss.
"""
seg_preds = semantic_results['seg_preds']
part_preds = semantic_results['part_preds']
seg_targets = semantic_targets['seg_targets']
part_targets = semantic_targets['part_targets']
pos_mask = (seg_targets > -1) & (seg_targets < self.num_classes)
binary_seg_target = pos_mask.long()
pos = pos_mask.float()
neg = (seg_targets == self.num_classes).float()
seg_weights = pos + neg
pos_normalizer = pos.sum()
seg_weights = seg_weights / torch.clamp(pos_normalizer, min=1.0)
loss_seg = self.loss_seg(seg_preds, binary_seg_target, seg_weights)
if pos_normalizer > 0:
loss_part = self.loss_part(part_preds[pos_mask],
part_targets[pos_mask])
else:
# fake a part loss
loss_part = loss_seg.new_tensor(0)
return dict(loss_seg=loss_seg, loss_part=loss_part)