# Source code for mmdet3d.models.roi_heads.mask_heads.pointwise_semantic_head

import torch
from mmcv.runner import BaseModule
from torch import nn as nn
from torch.nn import functional as F

from mmdet3d.core.bbox.structures import rotation_3d_in_axis
from mmdet3d.models.builder import build_loss
from mmdet.core import multi_apply
from mmdet.models import HEADS


@HEADS.register_module()
class PointwiseSemanticHead(BaseModule):
    """Semantic segmentation head for point-wise segmentation.

    Predict point-wise segmentation and part regression results for PartA2.
    See `paper <https://arxiv.org/abs/1907.03670>`_ for more details.

    Args:
        in_channels (int): The number of input channels.
        num_classes (int): The number of classes. Internally the label value
            ``num_classes`` marks background voxels and ``-1`` marks voxels
            ignored by the segmentation loss.
        extra_width (float): Width by which ground-truth boxes are enlarged
            when deciding which voxels to ignore.
        seg_score_thr (float): Sigmoid score above which a point keeps its
            part offsets in ``part_feats``; below it they are zeroed.
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``.
        loss_seg (dict): Config of segmentation loss.
        loss_part (dict): Config of part prediction loss.
    """

    def __init__(self,
                 in_channels,
                 num_classes=3,
                 extra_width=0.2,
                 seg_score_thr=0.3,
                 init_cfg=None,
                 loss_seg=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     reduction='sum',
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_part=dict(
                     type='CrossEntropyLoss', use_sigmoid=True,
                     loss_weight=1.0)):
        super(PointwiseSemanticHead, self).__init__(init_cfg=init_cfg)
        self.extra_width = extra_width
        self.num_classes = num_classes
        self.seg_score_thr = seg_score_thr
        # One class-agnostic foreground logit per point.
        self.seg_cls_layer = nn.Linear(in_channels, 1, bias=True)
        # Three part-location logits per point.
        self.seg_reg_layer = nn.Linear(in_channels, 3, bias=True)

        self.loss_seg = build_loss(loss_seg)
        self.loss_part = build_loss(loss_part)

    def forward(self, x):
        """Forward pass.

        Args:
            x (torch.Tensor): Features from the first stage; the last
                dimension must be ``in_channels`` (presumably shape
                (N, in_channels) — the mask indexing below assumes 2-D).

        Returns:
            dict: Part features, segmentation and part predictions.

                - seg_preds (torch.Tensor): Segmentation logits, (N, 1).
                - part_preds (torch.Tensor): Part logits, (N, 3).
                - part_feats (torch.Tensor): Detached sigmoid part offsets
                  concatenated with detached segmentation scores, (N, 4).
        """
        seg_preds = self.seg_cls_layer(x)  # (N, 1)
        part_preds = self.seg_reg_layer(x)  # (N, 3)

        # part_feats are built from detached tensors, so no gradient flows
        # from them back into the prediction layers.
        seg_scores = torch.sigmoid(seg_preds).detach()
        seg_mask = seg_scores > self.seg_score_thr
        part_offsets = torch.sigmoid(part_preds).clone().detach()
        # Zero the part offsets of points not scored above the threshold.
        part_offsets[~seg_mask.view(-1)] = 0
        part_feats = torch.cat((part_offsets, seg_scores), dim=-1)  # (N, 4)

        return dict(
            seg_preds=seg_preds, part_preds=part_preds, part_feats=part_feats)

    def get_targets_single(self, voxel_centers, gt_bboxes_3d, gt_labels_3d):
        """Generate segmentation and part prediction targets for one sample.

        Args:
            voxel_centers (torch.Tensor): The center of voxels in shape
                (voxel_num, 3).
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth boxes in
                shape (box_num, 7).
            gt_labels_3d (torch.Tensor): Class labels of ground truths in
                shape (box_num).

        Returns:
            tuple[torch.Tensor]: Segmentation targets with shape [voxel_num]
                (the box label for voxels inside a box, ``num_classes`` for
                background, ``-1`` for ignored voxels) and part prediction
                targets with shape [voxel_num, 3], clamped to be >= 0.
        """
        gt_bboxes_3d = gt_bboxes_3d.to(voxel_centers.device)
        enlarged_gt_boxes = gt_bboxes_3d.enlarged_box(self.extra_width)

        part_targets = voxel_centers.new_zeros((voxel_centers.shape[0], 3),
                                               dtype=torch.float32)
        box_idx = gt_bboxes_3d.points_in_boxes(voxel_centers)
        enlarge_box_idx = enlarged_gt_boxes.points_in_boxes(
            voxel_centers).long()

        # Prepend the background label so that box index -1 ("in no box")
        # maps to ``num_classes`` once indices are shifted by one.
        gt_labels_pad = F.pad(
            gt_labels_3d, (1, 0), mode='constant', value=self.num_classes)
        seg_targets = gt_labels_pad[(box_idx.long() + 1)]
        fg_pt_flag = box_idx > -1
        # Voxels whose in-box status differs between the original and the
        # enlarged boxes (presumably the enlarged margin) get label -1,
        # which the loss weights to zero.
        ignore_flag = fg_pt_flag ^ (enlarge_box_idx > -1)
        seg_targets[ignore_flag] = -1

        for k in range(len(gt_bboxes_3d)):
            k_box_flag = box_idx == k
            # no point in current box (caused by velodyne reduce)
            if not k_box_flag.any():
                continue
            fg_voxels = voxel_centers[k_box_flag]
            # Move voxels into the box frame: translate to the box bottom
            # center, then rotate by the negated yaw around axis 2.
            transformed_voxels = fg_voxels - gt_bboxes_3d.bottom_center[k]
            transformed_voxels = rotation_3d_in_axis(
                transformed_voxels.unsqueeze(0),
                -gt_bboxes_3d.yaw[k].view(1),
                axis=2)
            # Normalize by box dims and shift the first two coordinates by
            # 0.5 (the third is measured from the bottom center).
            part_targets[k_box_flag] = transformed_voxels / gt_bboxes_3d.dims[
                k] + voxel_centers.new_tensor([0.5, 0.5, 0])

        part_targets = torch.clamp(part_targets, min=0)
        return seg_targets, part_targets

    def get_targets(self, voxels_dict, gt_bboxes_3d, gt_labels_3d):
        """Generate segmentation and part prediction targets for a batch.

        Args:
            voxels_dict (dict): Voxel information of the batch. Must contain
                ``'coors'``, whose first column is the sample index of each
                voxel, and ``'voxel_centers'``, the voxel centers in shape
                (voxel_num, 3).
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Class labels of ground truths
                of each sample.

        Returns:
            dict: Prediction targets

                - seg_targets (torch.Tensor): Segmentation targets
                  with shape [voxel_num].
                - part_targets (torch.Tensor): Part prediction targets
                  with shape [voxel_num, 3].
        """
        batch_size = len(gt_labels_3d)
        sample_idx = voxels_dict['coors'][:, 0]
        voxel_centers = voxels_dict['voxel_centers']
        # NOTE(review): targets are concatenated in ascending sample order;
        # this matches the voxel order of the predictions only if voxels are
        # grouped by sample index — confirm with the voxelization code.
        voxel_center_list = [
            voxel_centers[sample_idx == idx] for idx in range(batch_size)
        ]
        seg_targets, part_targets = multi_apply(self.get_targets_single,
                                                voxel_center_list,
                                                gt_bboxes_3d, gt_labels_3d)
        seg_targets = torch.cat(seg_targets, dim=0)
        part_targets = torch.cat(part_targets, dim=0)
        return dict(seg_targets=seg_targets, part_targets=part_targets)

    def loss(self, semantic_results, semantic_targets):
        """Calculate point-wise segmentation and part prediction losses.

        Args:
            semantic_results (dict): Results from semantic head.

                - seg_preds: Segmentation predictions.
                - part_preds: Part predictions.

            semantic_targets (dict): Targets of semantic results.

                - seg_targets: Segmentation targets.
                - part_targets: Part targets.

        Returns:
            dict: Loss of segmentation and part prediction.

                - loss_seg (torch.Tensor): Segmentation prediction loss.
                - loss_part (torch.Tensor): Part prediction loss.
        """
        seg_preds = semantic_results['seg_preds']
        part_preds = semantic_results['part_preds']
        seg_targets = semantic_targets['seg_targets']
        part_targets = semantic_targets['part_targets']

        # Foreground: a real class label. Background: ``num_classes``.
        # Ignored voxels (-1) belong to neither and get zero weight.
        pos_mask = (seg_targets > -1) & (seg_targets < self.num_classes)
        # NOTE(review): this is 1 for foreground, 0 otherwise. mmdet's sigmoid
        # FocalLoss treats label ``pred.size(1)`` (here 1) as background —
        # confirm the configured ``loss_seg`` uses the intended convention.
        binary_seg_target = pos_mask.long()
        pos = pos_mask.float()
        neg = (seg_targets == self.num_classes).float()
        seg_weights = pos + neg
        pos_normalizer = pos.sum()
        seg_weights = seg_weights / torch.clamp(pos_normalizer, min=1.0)
        loss_seg = self.loss_seg(seg_preds, binary_seg_target, seg_weights)

        if pos_normalizer > 0:
            loss_part = self.loss_part(part_preds[pos_mask],
                                       part_targets[pos_mask])
        else:
            # No foreground voxel: fake a zero part loss, but keep it attached
            # to ``part_preds`` so ``seg_reg_layer`` still takes part in the
            # backward pass (DistributedDataParallel without
            # ``find_unused_parameters`` fails on parameters with no grad).
            loss_part = part_preds.sum() * 0

        return dict(loss_seg=loss_seg, loss_part=loss_part)