
Source code for mmdet3d.models.dense_heads.imvoxel_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import Scale, bias_init_with_prob, normal_init
from mmcv.ops import nms3d, nms3d_normal
from mmcv.runner import BaseModule
from torch import nn

from mmdet3d.core import build_prior_generator
from mmdet3d.core.bbox.structures import rotation_3d_in_axis
from mmdet.core import multi_apply, reduce_mean
from ..builder import HEADS, build_loss


@HEADS.register_module()
class ImVoxelHead(BaseModule):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_ head for indoor datasets.

    Args:
        n_classes (int): Number of classes.
        n_levels (int): Number of feature levels.
        n_channels (int): Number of channels in input tensors.
        n_reg_outs (int): Number of regression layer channels.
        pts_assign_threshold (int): Minimum number of locations inside a box
            for the box to be assigned to a feature level.
        pts_center_threshold (int): Maximum number of locations per box to
            keep as positives, ranked by centerness.
        prior_generator (dict): Config of the prior (anchor) generator used
            to generate final locations.
        center_loss (dict, optional): Config of centerness loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
        bbox_loss (dict, optional): Config of bbox loss.
            Default: dict(type='RotatedIoU3DLoss').
        cls_loss (dict, optional): Config of classification loss.
            Default: dict(type='FocalLoss').
        train_cfg (dict, optional): Config for train stage. Defaults to None.
        test_cfg (dict, optional): Config for test stage. Defaults to None.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 n_classes,
                 n_levels,
                 n_channels,
                 n_reg_outs,
                 pts_assign_threshold,
                 pts_center_threshold,
                 prior_generator,
                 center_loss=dict(type='CrossEntropyLoss', use_sigmoid=True),
                 bbox_loss=dict(type='RotatedIoU3DLoss'),
                 cls_loss=dict(type='FocalLoss'),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        super(ImVoxelHead, self).__init__(init_cfg)
        self.pts_assign_threshold = pts_assign_threshold
        self.pts_center_threshold = pts_center_threshold
        self.prior_generator = build_prior_generator(prior_generator)
        self.center_loss = build_loss(center_loss)
        self.bbox_loss = build_loss(bbox_loss)
        self.cls_loss = build_loss(cls_loss)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self._init_layers(n_channels, n_reg_outs, n_classes, n_levels)

    def _init_layers(self, n_channels, n_reg_outs, n_classes, n_levels):
        """Initialize neural network layers of the head."""
        self.conv_center = nn.Conv3d(n_channels, 1, 3, padding=1, bias=False)
        self.conv_reg = nn.Conv3d(
            n_channels, n_reg_outs, 3, padding=1, bias=False)
        self.conv_cls = nn.Conv3d(n_channels, n_classes, 3, padding=1)
        self.scales = nn.ModuleList([Scale(1.) for _ in range(n_levels)])
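For orientation, the sketch below shows how such a head could be built through the HEADS registry. Only the argument names and the "n_reg_outs = 6 face distances + 1 yaw" convention come from the code above; every numeric value and the AlignedAnchor3DRangeGenerator range are assumptions, not values from a released ImVoxelNet config.

# Illustrative sketch only: the numbers below are assumptions.
from mmdet3d.models import build_head

head = build_head(
    dict(
        type='ImVoxelHead',
        n_classes=10,              # hypothetical number of indoor classes
        n_levels=3,                # one entry per 3d neck output level
        n_channels=64,             # channels of each 3d neck output
        n_reg_outs=7,              # 6 face distances + 1 yaw angle
        pts_assign_threshold=27,   # assumed value
        pts_center_threshold=18,   # assumed value
        prior_generator=dict(
            type='AlignedAnchor3DRangeGenerator',
            ranges=[[-3.2, -3.2, -2.28, 3.2, 3.2, 0.28]],  # assumed room range
            rotations=[.0])))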
    def init_weights(self):
        """Initialize all layer weights."""
        normal_init(self.conv_center, std=.01)
        normal_init(self.conv_reg, std=.01)
        normal_init(self.conv_cls, std=.01, bias=bias_init_with_prob(.01))
    def _forward_single(self, x, scale):
        """Forward pass per level.

        Args:
            x (Tensor): Per level 3d neck output tensor.
            scale (mmcv.cnn.Scale): Per level multiplication weight.

        Returns:
            tuple[Tensor]: Centerness, bbox and classification predictions.
        """
        reg_final = self.conv_reg(x)
        reg_distance = torch.exp(scale(reg_final[:, :6]))
        reg_angle = reg_final[:, 6:]
        bbox_pred = torch.cat((reg_distance, reg_angle), dim=1)
        return self.conv_center(x), bbox_pred, self.conv_cls(x)
    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor]): Features from 3d neck.

        Returns:
            tuple[Tensor]: Centerness, bbox and classification predictions.
        """
        return multi_apply(self._forward_single, x, self.scales)
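As a rough shape check (assuming a head built with n_channels=64, n_reg_outs=7 and n_levels=3 as in the sketch above), forward consumes one 5D tensor per level and returns per-level centerness, bbox and classification maps with unchanged spatial size; the feature sizes below are made up.

feats = [torch.rand(1, 64, 40 // s, 40 // s, 16 // s) for s in (1, 2, 4)]
center_preds, bbox_preds, cls_preds = head(feats)
# center_preds[0].shape == (1, 1, 40, 40, 16)          one centerness channel
# bbox_preds[0].shape   == (1, 7, 40, 40, 16)          6 distances + yaw
# cls_preds[0].shape    == (1, n_classes, 40, 40, 16)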
    def _loss_single(self, center_preds, bbox_preds, cls_preds, valid_preds,
                     img_meta, gt_bboxes, gt_labels):
        """Per scene loss function.

        Args:
            center_preds (list[Tensor]): Centerness predictions for all
                levels.
            bbox_preds (list[Tensor]): Bbox predictions for all levels.
            cls_preds (list[Tensor]): Classification predictions for all
                levels.
            valid_preds (list[Tensor]): Valid mask predictions for all levels.
            img_meta (dict): Scene meta info.
            gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
            gt_labels (Tensor): Ground truth labels.

        Returns:
            tuple[Tensor]: Centerness, bbox, and classification loss values.
        """
        points = self._get_points(center_preds)
        center_targets, bbox_targets, cls_targets = self._get_targets(
            points, gt_bboxes, gt_labels)

        center_preds = torch.cat(
            [x.permute(1, 2, 3, 0).reshape(-1) for x in center_preds])
        bbox_preds = torch.cat([
            x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in bbox_preds
        ])
        cls_preds = torch.cat(
            [x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in cls_preds])
        valid_preds = torch.cat(
            [x.permute(1, 2, 3, 0).reshape(-1) for x in valid_preds])
        points = torch.cat(points)

        # cls loss
        pos_inds = torch.nonzero(
            torch.logical_and(cls_targets >= 0, valid_preds)).squeeze(1)
        n_pos = points.new_tensor(len(pos_inds))
        n_pos = max(reduce_mean(n_pos), 1.)
        if torch.any(valid_preds):
            cls_loss = self.cls_loss(
                cls_preds[valid_preds],
                cls_targets[valid_preds],
                avg_factor=n_pos)
        else:
            cls_loss = cls_preds[valid_preds].sum()

        # bbox and centerness losses
        pos_center_preds = center_preds[pos_inds]
        pos_bbox_preds = bbox_preds[pos_inds]
        if len(pos_inds) > 0:
            pos_center_targets = center_targets[pos_inds]
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_points = points[pos_inds]
            center_loss = self.center_loss(
                pos_center_preds, pos_center_targets, avg_factor=n_pos)
            bbox_loss = self.bbox_loss(
                self._bbox_pred_to_bbox(pos_points, pos_bbox_preds),
                pos_bbox_targets,
                weight=pos_center_targets,
                avg_factor=pos_center_targets.sum())
        else:
            center_loss = pos_center_preds.sum()
            bbox_loss = pos_bbox_preds.sum()
        return center_loss, bbox_loss, cls_loss
    def loss(self, center_preds, bbox_preds, cls_preds, valid_pred, gt_bboxes,
             gt_labels, img_metas):
        """Loss function over all scenes.

        Args:
            center_preds (list[list[Tensor]]): Centerness predictions for
                all scenes.
            bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
            cls_preds (list[list[Tensor]]): Classification predictions for
                all scenes.
            valid_pred (Tensor): Valid mask prediction for all scenes.
            gt_bboxes (list[BaseInstance3DBoxes]): Ground truth boxes for all
                scenes.
            gt_labels (list[Tensor]): Ground truth labels for all scenes.
            img_metas (list[dict]): Meta infos for all scenes.

        Returns:
            dict: Centerness, bbox, and classification loss values.
        """
        valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
        center_losses, bbox_losses, cls_losses = [], [], []
        for i in range(len(img_metas)):
            center_loss, bbox_loss, cls_loss = self._loss_single(
                center_preds=[x[i] for x in center_preds],
                bbox_preds=[x[i] for x in bbox_preds],
                cls_preds=[x[i] for x in cls_preds],
                valid_preds=[x[i] for x in valid_preds],
                img_meta=img_metas[i],
                gt_bboxes=gt_bboxes[i],
                gt_labels=gt_labels[i])
            center_losses.append(center_loss)
            bbox_losses.append(bbox_loss)
            cls_losses.append(cls_loss)
        return dict(
            center_loss=torch.mean(torch.stack(center_losses)),
            bbox_loss=torch.mean(torch.stack(bbox_losses)),
            cls_loss=torch.mean(torch.stack(cls_losses)))
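A hedged sketch of how the returned dict is typically consumed during training, assuming the predictions, valid mask and ground truth produced by the surrounding detector are already available. In the full pipeline the dict is reduced by mmdet's standard loss parsing; summing the values here is only a minimal stand-in.

losses = head.loss(center_preds, bbox_preds, cls_preds, valid_pred,
                   gt_bboxes, gt_labels, img_metas)
# losses == dict(center_loss=..., bbox_loss=..., cls_loss=...)
total_loss = sum(losses.values())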
    def _get_bboxes_single(self, center_preds, bbox_preds, cls_preds,
                           valid_preds, img_meta):
        """Generate boxes for a single scene.

        Args:
            center_preds (list[Tensor]): Centerness predictions for all
                levels.
            bbox_preds (list[Tensor]): Bbox predictions for all levels.
            cls_preds (list[Tensor]): Classification predictions for all
                levels.
            valid_preds (list[Tensor]): Valid mask predictions for all levels.
            img_meta (dict): Scene meta info.

        Returns:
            tuple[Tensor]: Predicted bounding boxes, scores and labels.
        """
        points = self._get_points(center_preds)
        mlvl_bboxes, mlvl_scores = [], []
        for center_pred, bbox_pred, cls_pred, valid_pred, point in zip(
                center_preds, bbox_preds, cls_preds, valid_preds, points):
            center_pred = center_pred.permute(1, 2, 3, 0).reshape(-1, 1)
            bbox_pred = bbox_pred.permute(1, 2, 3,
                                          0).reshape(-1, bbox_pred.shape[0])
            cls_pred = cls_pred.permute(1, 2, 3,
                                        0).reshape(-1, cls_pred.shape[0])
            valid_pred = valid_pred.permute(1, 2, 3, 0).reshape(-1, 1)
            scores = cls_pred.sigmoid() * center_pred.sigmoid() * valid_pred
            max_scores, _ = scores.max(dim=1)

            if len(scores) > self.test_cfg.nms_pre > 0:
                _, ids = max_scores.topk(self.test_cfg.nms_pre)
                bbox_pred = bbox_pred[ids]
                scores = scores[ids]
                point = point[ids]

            bboxes = self._bbox_pred_to_bbox(point, bbox_pred)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        bboxes = torch.cat(mlvl_bboxes)
        scores = torch.cat(mlvl_scores)
        bboxes, scores, labels = self._single_scene_multiclass_nms(
            bboxes, scores, img_meta)
        return bboxes, scores, labels
    def get_bboxes(self, center_preds, bbox_preds, cls_preds, valid_pred,
                   img_metas):
        """Generate boxes for all scenes.

        Args:
            center_preds (list[list[Tensor]]): Centerness predictions for
                all scenes.
            bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
            cls_preds (list[list[Tensor]]): Classification predictions for
                all scenes.
            valid_pred (Tensor): Valid mask prediction for all scenes.
            img_metas (list[dict]): Meta infos for all scenes.

        Returns:
            list[tuple[Tensor]]: Predicted bboxes, scores, and labels for
                all scenes.
        """
        valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
        results = []
        for i in range(len(img_metas)):
            results.append(
                self._get_bboxes_single(
                    center_preds=[x[i] for x in center_preds],
                    bbox_preds=[x[i] for x in bbox_preds],
                    cls_preds=[x[i] for x in cls_preds],
                    valid_preds=[x[i] for x in valid_preds],
                    img_meta=img_metas[i]))
        return results
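The inference path above reads three fields from self.test_cfg (nms_pre, score_thr, iou_thr) via attribute access, so the test config is normally a ConfigDict. The field names come from the code; the values below are illustrative.

from mmcv.utils import ConfigDict

test_cfg = ConfigDict(
    nms_pre=1000,   # keep at most this many boxes per level before NMS
    score_thr=.01,  # drop per-class scores below this threshold
    iou_thr=.25)    # 3d IoU threshold passed to nms3d / nms3d_normal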
    @staticmethod
    def _upsample_valid_preds(valid_pred, features):
        """Upsample valid mask predictions.

        Args:
            valid_pred (Tensor): Valid mask prediction.
            features (list[Tensor]): Feature tensors for all levels.

        Returns:
            list[Tensor]: Upsampled valid masks for all feature levels.
        """
        return [
            nn.Upsample(size=x.shape[-3:],
                        mode='trilinear')(valid_pred).round().bool()
            for x in features
        ]

    def _get_points(self, features):
        """Generate final locations.

        Args:
            features (list[Tensor]): Feature tensors for all feature levels.

        Returns:
            list[Tensor]: Final locations for all feature levels.
        """
        points = []
        for x in features:
            n_voxels = x.size()[-3:][::-1]
            points.append(
                self.prior_generator.grid_anchors(
                    [n_voxels],
                    device=x.device)[0][:, :3].reshape(n_voxels +
                                                       (3, )).permute(
                                                           2, 1, 0,
                                                           3).reshape(-1, 3))
        return points

    @staticmethod
    def _bbox_pred_to_bbox(points, bbox_pred):
        """Transform predicted bbox parameters to bbox.

        Args:
            points (Tensor): Final locations of shape (N, 3).
            bbox_pred (Tensor): Predicted bbox parameters of shape (N, 7).

        Returns:
            Tensor: Transformed 3D box of shape (N, 7).
        """
        if bbox_pred.shape[0] == 0:
            return bbox_pred

        # dx_min, dx_max, dy_min, dy_max, dz_min, dz_max, alpha ->
        # x_center, y_center, z_center, w, l, h, alpha
        shift = torch.stack(((bbox_pred[:, 1] - bbox_pred[:, 0]) / 2,
                             (bbox_pred[:, 3] - bbox_pred[:, 2]) / 2,
                             (bbox_pred[:, 5] - bbox_pred[:, 4]) / 2),
                            dim=-1).view(-1, 1, 3)
        shift = rotation_3d_in_axis(shift, bbox_pred[:, 6], axis=2)[:, 0, :]
        center = points + shift
        size = torch.stack(
            (bbox_pred[:, 0] + bbox_pred[:, 1],
             bbox_pred[:, 2] + bbox_pred[:, 3],
             bbox_pred[:, 4] + bbox_pred[:, 5]),
            dim=-1)
        return torch.cat((center, size, bbox_pred[:, 6:7]), dim=-1)

    # The function is directly copied from FCAF3DHead.
    @staticmethod
    def _get_face_distances(points, boxes):
        """Calculate distances from point to box faces.

        Args:
            points (Tensor): Final locations of shape (N_points, N_boxes, 3).
            boxes (Tensor): 3D boxes of shape (N_points, N_boxes, 7).

        Returns:
            Tensor: Face distances of shape (N_points, N_boxes, 6),
                (dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).
        """
        shift = torch.stack(
            (points[..., 0] - boxes[..., 0], points[..., 1] - boxes[..., 1],
             points[..., 2] - boxes[..., 2]),
            dim=-1).permute(1, 0, 2)
        shift = rotation_3d_in_axis(
            shift, -boxes[0, :, 6], axis=2).permute(1, 0, 2)
        centers = boxes[..., :3] + shift
        dx_min = centers[..., 0] - boxes[..., 0] + boxes[..., 3] / 2
        dx_max = boxes[..., 0] + boxes[..., 3] / 2 - centers[..., 0]
        dy_min = centers[..., 1] - boxes[..., 1] + boxes[..., 4] / 2
        dy_max = boxes[..., 1] + boxes[..., 4] / 2 - centers[..., 1]
        dz_min = centers[..., 2] - boxes[..., 2] + boxes[..., 5] / 2
        dz_max = boxes[..., 2] + boxes[..., 5] / 2 - centers[..., 2]
        return torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max),
                           dim=-1)

    # The function is directly copied from FCAF3DHead.
    @staticmethod
    def _get_centerness(face_distances):
        """Compute point centerness w.r.t containing box.

        Args:
            face_distances (Tensor): Face distances of shape (B, N, 6),
                (dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).

        Returns:
            Tensor: Centerness of shape (B, N).
        """
        x_dims = face_distances[..., [0, 1]]
        y_dims = face_distances[..., [2, 3]]
        z_dims = face_distances[..., [4, 5]]
        centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \
            y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \
            z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0]
        return torch.sqrt(centerness_targets)
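A quick numeric check of the centerness definition above: a location at the box center has equal opposite face distances, so every min/max ratio is 1, while a location on a face has one zero distance. The toy tensors are made up.

d_center = torch.tensor([[[1., 1., 2., 2., .5, .5]]])  # shape (B=1, N=1, 6)
d_face = torch.tensor([[[0., 2., 2., 2., .5, .5]]])
ImVoxelHead._get_centerness(d_center)  # tensor([[1.]])
ImVoxelHead._get_centerness(d_face)    # tensor([[0.]])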
    # The function is directly copied from FCAF3DHead.
    @torch.no_grad()
    def _get_targets(self, points, gt_bboxes, gt_labels):
        """Compute targets for final locations for a single scene.

        Args:
            points (list[Tensor]): Final locations for all levels.
            gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
            gt_labels (Tensor): Ground truth labels.

        Returns:
            tuple[Tensor]: Centerness, bbox and classification targets for
                all locations.
        """
        float_max = points[0].new_tensor(1e8)
        n_levels = len(points)
        levels = torch.cat([
            points[i].new_tensor(i).expand(len(points[i]))
            for i in range(len(points))
        ])
        points = torch.cat(points)
        gt_bboxes = gt_bboxes.to(points.device)
        n_points = len(points)
        n_boxes = len(gt_bboxes)
        volumes = gt_bboxes.volume.unsqueeze(0).expand(n_points, n_boxes)

        # condition 1: point inside box
        boxes = torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
                          dim=1)
        boxes = boxes.expand(n_points, n_boxes, 7)
        points = points.unsqueeze(1).expand(n_points, n_boxes, 3)
        face_distances = self._get_face_distances(points, boxes)
        inside_box_condition = face_distances.min(dim=-1).values > 0

        # condition 2: positive points per level >= limit
        # calculate positive points per scale
        n_pos_points_per_level = []
        for i in range(n_levels):
            n_pos_points_per_level.append(
                torch.sum(inside_box_condition[levels == i], dim=0))
        # find best level
        n_pos_points_per_level = torch.stack(n_pos_points_per_level, dim=0)
        lower_limit_mask = n_pos_points_per_level < self.pts_assign_threshold
        lower_index = torch.argmax(lower_limit_mask.int(), dim=0) - 1
        lower_index = torch.where(lower_index < 0, 0, lower_index)
        all_upper_limit_mask = torch.all(
            torch.logical_not(lower_limit_mask), dim=0)
        best_level = torch.where(all_upper_limit_mask, n_levels - 1,
                                 lower_index)
        # keep only points with best level
        best_level = best_level.expand(n_points, n_boxes)
        levels = torch.unsqueeze(levels, 1).expand(n_points, n_boxes)
        level_condition = best_level == levels

        # condition 3: limit topk points per box by centerness
        centerness = self._get_centerness(face_distances)
        centerness = torch.where(inside_box_condition, centerness,
                                 torch.ones_like(centerness) * -1)
        centerness = torch.where(level_condition, centerness,
                                 torch.ones_like(centerness) * -1)
        top_centerness = torch.topk(
            centerness,
            min(self.pts_center_threshold + 1, len(centerness)),
            dim=0).values[-1]
        topk_condition = centerness > top_centerness.unsqueeze(0)

        # condition 4: min volume box per point
        volumes = torch.where(inside_box_condition, volumes, float_max)
        volumes = torch.where(level_condition, volumes, float_max)
        volumes = torch.where(topk_condition, volumes, float_max)
        min_volumes, min_inds = volumes.min(dim=1)

        center_targets = centerness[torch.arange(n_points), min_inds]
        bbox_targets = boxes[torch.arange(n_points), min_inds]
        if not gt_bboxes.with_yaw:
            bbox_targets = bbox_targets[:, :-1]
        cls_targets = gt_labels[min_inds]
        cls_targets = torch.where(min_volumes == float_max, -1, cls_targets)
        return center_targets, bbox_targets, cls_targets

    # Originally ImVoxelNet utilizes 2d nms as mmdetection3d didn't
    # support 3d nms. But since mmcv==1.5.2 we simply use nms3d here.
    # The function is directly copied from FCAF3DHead.
    def _single_scene_multiclass_nms(self, bboxes, scores, input_meta):
        """Multi-class nms for a single scene.

        Args:
            bboxes (Tensor): Predicted boxes of shape (N_boxes, 6) or
                (N_boxes, 7).
            scores (Tensor): Predicted scores of shape (N_boxes, N_classes).
            input_meta (dict): Scene meta data.

        Returns:
            tuple[Tensor]: Predicted bboxes, scores and labels.
        """
        n_classes = scores.shape[1]
        with_yaw = bboxes.shape[1] == 7
        nms_bboxes, nms_scores, nms_labels = [], [], []
        for i in range(n_classes):
            ids = scores[:, i] > self.test_cfg.score_thr
            if not ids.any():
                continue

            class_scores = scores[ids, i]
            class_bboxes = bboxes[ids]
            if with_yaw:
                nms_function = nms3d
            else:
                class_bboxes = torch.cat(
                    (class_bboxes, torch.zeros_like(class_bboxes[:, :1])),
                    dim=1)
                nms_function = nms3d_normal

            nms_ids = nms_function(class_bboxes, class_scores,
                                   self.test_cfg.iou_thr)
            nms_bboxes.append(class_bboxes[nms_ids])
            nms_scores.append(class_scores[nms_ids])
            nms_labels.append(
                bboxes.new_full(
                    class_scores[nms_ids].shape, i, dtype=torch.long))

        if len(nms_bboxes):
            nms_bboxes = torch.cat(nms_bboxes, dim=0)
            nms_scores = torch.cat(nms_scores, dim=0)
            nms_labels = torch.cat(nms_labels, dim=0)
        else:
            nms_bboxes = bboxes.new_zeros((0, bboxes.shape[1]))
            nms_scores = bboxes.new_zeros((0, ))
            nms_labels = bboxes.new_zeros((0, ))

        if with_yaw:
            box_dim = 7
        else:
            box_dim = 6
            nms_bboxes = nms_bboxes[:, :6]
        nms_bboxes = input_meta['box_type_3d'](
            nms_bboxes,
            box_dim=box_dim,
            with_yaw=with_yaw,
            origin=(.5, .5, .5))

        return nms_bboxes, nms_scores, nms_labels