import numpy as np
import torch
from mmcv.cnn import Scale
from mmcv.runner import force_fp32
from torch import nn as nn
from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr
from mmdet.core import multi_apply
from mmdet.models.builder import HEADS, build_loss
from .anchor_free_mono3d_head import AnchorFreeMono3DHead
# Large sentinel value: used as the upper bound of the last regression range
# and as the distance assigned to points that match no ground-truth object.
INF = 1.0e8
@HEADS.register_module()
class FCOSMono3DHead(AnchorFreeMono3DHead):
    """Anchor-free head used in FCOS3D.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        regress_ranges (tuple[tuple[int, int]]): Regress range of multiple
            level points.
        center_sampling (bool): If true, use center sampling. Default: True.
        center_sample_radius (float): Radius of center sampling. Default: 1.5.
        norm_on_bbox (bool): If true, normalize the regression targets
            with FPN strides. Default: True.
        centerness_on_reg (bool): If true, position centerness on the
            regress branch. Please refer to https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
            Default: True.
        centerness_alpha (float): Parameter used to adjust the intensity
            attenuation from the center to the periphery. Default: 2.5.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        loss_dir (dict): Config of direction classification loss.
        loss_attr (dict): Config of attribute classification loss.
        loss_centerness (dict): Config of centerness loss.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: norm_cfg=dict(type='GN', num_groups=32, requires_grad=True).
        centerness_branch (tuple[int]): Channels for centerness branch.
            Default: (64, ).
        init_cfg (dict, optional): Initialization config. When None, Conv2d
            layers use Normal(std=0.01) and ``conv_cls`` additionally gets
            ``bias_prob=0.01``. Default: None.
    """  # noqa: E501

    def __init__(self,
                 num_classes,
                 in_channels,
                 regress_ranges=((-1, 48), (48, 96), (96, 192), (192, 384),
                                 (384, INF)),
                 center_sampling=True,
                 center_sample_radius=1.5,
                 norm_on_bbox=True,
                 centerness_on_reg=True,
                 centerness_alpha=2.5,
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 loss_dir=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_attr=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_centerness=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 centerness_branch=(64, ),
                 init_cfg=None,
                 **kwargs):
        # Set before super().__init__: `_init_layers` below reads
        # `centerness_branch`, and presumably the base constructor builds the
        # layers through it (TODO confirm in AnchorFreeMono3DHead).
        self.regress_ranges = regress_ranges
        self.center_sampling = center_sampling
        self.center_sample_radius = center_sample_radius
        self.norm_on_bbox = norm_on_bbox
        self.centerness_on_reg = centerness_on_reg
        self.centerness_alpha = centerness_alpha
        self.centerness_branch = centerness_branch
        super().__init__(
            num_classes,
            in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_dir=loss_dir,
            loss_attr=loss_attr,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg,
            **kwargs)
        self.loss_centerness = build_loss(loss_centerness)
        if init_cfg is None:
            self.init_cfg = dict(
                type='Normal',
                layer='Conv2d',
                std=0.01,
                override=dict(
                    type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))

    def _init_layers(self):
        """Initialize layers of the head."""
        super()._init_layers()
        # Extra conv branch feeding the single-channel centerness predictor.
        self.conv_centerness_prev = self._init_branch(
            conv_channels=self.centerness_branch,
            conv_strides=(1, ) * len(self.centerness_branch))
        self.conv_centerness = nn.Conv2d(self.centerness_branch[-1], 1, 1)
        # Three learnable Scales per FPN level, only for offset, depth and
        # size regression (applied in that order in forward_single).
        self.scales = nn.ModuleList([
            nn.ModuleList([Scale(1.0) for _ in range(3)]) for _ in self.strides
        ])

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple:
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * bbox_code_size.
                dir_cls_preds (list[Tensor]): Box scores for direction class
                    predictions on each scale level, each is a 4D-tensor,
                    the channel number is num_points * 2. (bin = 2).
                attr_preds (list[Tensor]): Attribute scores for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * num_attrs.
                centernesses (list[Tensor]): Centerness for each scale level,
                    each is a 4D-tensor, the channel number is num_points * 1.
        """
        return multi_apply(self.forward_single, feats, self.scales,
                           self.strides)

    def forward_single(self, x, scale, stride):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (nn.ModuleList): The three learnable
                :obj:`mmcv.cnn.Scale` modules of this level, applied to the
                offset, depth and size predictions respectively.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple: cls_score, bbox_pred, dir_cls_pred, attr_pred and
                centerness predictions of this level.
        """
        cls_score, bbox_pred, dir_cls_pred, attr_pred, cls_feat, reg_feat = \
            super().forward_single(x)

        if self.centerness_on_reg:
            clone_reg_feat = reg_feat.clone()
            for conv_centerness_prev_layer in self.conv_centerness_prev:
                clone_reg_feat = conv_centerness_prev_layer(clone_reg_feat)
            centerness = self.conv_centerness(clone_reg_feat)
        else:
            clone_cls_feat = cls_feat.clone()
            for conv_centerness_prev_layer in self.conv_centerness_prev:
                clone_cls_feat = conv_centerness_prev_layer(clone_cls_feat)
            centerness = self.conv_centerness(clone_cls_feat)

        # scale the bbox_pred of different level
        # only apply to offset (ch 0:2), depth (ch 2) and size (ch 3:6)
        scale_offset, scale_depth, scale_size = scale[0:3]

        clone_bbox_pred = bbox_pred.clone()
        bbox_pred[:, :2] = scale_offset(clone_bbox_pred[:, :2]).float()
        bbox_pred[:, 2] = scale_depth(clone_bbox_pred[:, 2]).float()
        bbox_pred[:, 3:6] = scale_size(clone_bbox_pred[:, 3:6]).float()

        # exp() keeps depth and size strictly positive
        bbox_pred[:, 2] = bbox_pred[:, 2].exp()
        bbox_pred[:, 3:6] = bbox_pred[:, 3:6].exp() + 1e-6  # avoid size=0

        assert self.norm_on_bbox is True, 'Setting norm_on_bbox to False '\
            'has not been thoroughly tested for FCOS3D.'
        if self.norm_on_bbox:
            if not self.training:
                # Note that this line is conducted only when testing: the
                # training targets are divided by the stride instead
                # (see get_targets).
                bbox_pred[:, :2] *= stride

        return cls_score, bbox_pred, dir_cls_pred, attr_pred, centerness

    @staticmethod
    def add_sin_difference(boxes1, boxes2):
        """Convert the rotation difference to difference in sine function.

        Args:
            boxes1 (torch.Tensor): Original Boxes in shape (NxC), where C>=7
                and the 7th dimension is rotation dimension.
            boxes2 (torch.Tensor): Target boxes in shape (NxC), where C>=7 and
                the 7th dimension is rotation dimension.

        Returns:
            tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th \
                dimensions are changed.
        """
        # sin(a)cos(b) - cos(a)sin(b) == sin(a - b), so the regression loss
        # on these two encodings penalizes the sine of the angle difference.
        rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
            boxes2[..., 6:7])
        rad_tg_encoding = torch.cos(boxes1[..., 6:7]) * torch.sin(boxes2[...,
                                                                         6:7])
        boxes1 = torch.cat(
            [boxes1[..., :6], rad_pred_encoding, boxes1[..., 7:]], dim=-1)
        boxes2 = torch.cat([boxes2[..., :6], rad_tg_encoding, boxes2[..., 7:]],
                           dim=-1)
        return boxes1, boxes2

    @staticmethod
    def get_direction_target(reg_targets,
                             dir_offset=0,
                             num_bins=2,
                             one_hot=True):
        """Encode direction to 0 ~ num_bins-1.

        Args:
            reg_targets (torch.Tensor): Bbox regression targets; the yaw is
                read from index 6 of the last dimension.
            dir_offset (float): Direction offset (radians). Default: 0.
            num_bins (int): Number of bins to divide 2*PI.
            one_hot (bool): Whether to encode as one hot.

        Returns:
            torch.Tensor: Encoded direction targets. Integer bin indices of
                shape ``reg_targets.shape[:-1]`` when ``one_hot`` is False,
                otherwise a one-hot tensor with a trailing ``num_bins``
                dimension and the dtype of ``reg_targets``.
        """
        rot_gt = reg_targets[..., 6]
        offset_rot = limit_period(rot_gt - dir_offset, 0, 2 * np.pi)
        dir_cls_targets = torch.floor(offset_rot /
                                      (2 * np.pi / num_bins)).long()
        dir_cls_targets = torch.clamp(dir_cls_targets, min=0, max=num_bins - 1)
        if one_hot:
            dir_targets = torch.zeros(
                *list(dir_cls_targets.shape),
                num_bins,
                dtype=reg_targets.dtype,
                device=dir_cls_targets.device)
            # scatter_(dim, index, value): put 1.0 at each sample's bin along
            # the last (bin) dimension. The dim argument was previously
            # missing, which made the index tensor be taken as `dim`.
            dir_targets.scatter_(-1,
                                 dir_cls_targets.unsqueeze(dim=-1).long(), 1.0)
            dir_cls_targets = dir_targets
        return dir_cls_targets

    @force_fp32(
        apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds', 'attr_preds',
                  'centernesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             dir_cls_preds,
             attr_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             gt_bboxes_3d,
             gt_labels_3d,
             centers2d,
             depths,
             attr_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * bbox_code_size.
            dir_cls_preds (list[Tensor]): Box scores for direction class
                predictions on each scale level, each is a 4D-tensor,
                the channel number is num_points * 2. (bin = 2)
            attr_preds (list[Tensor]): Attribute scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_attrs.
            centernesses (list[Tensor]): Centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_3d (list[Tensor] | list): 3D boxes ground truth, either
                tensors of shape (num_gts, code_size) or box structures that
                expose such a tensor via ``.tensor``.
            gt_labels_3d (list[Tensor]): same as gt_labels
            centers2d (list[Tensor]): 2D centers on the image with shape of
                (num_gts, 2).
            depths (list[Tensor]): Depth ground truth with shape of
                (num_gts, ).
            attr_labels (list[Tensor]): Attributes indices of each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss. Not used by
                this head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components: loss_cls,
                loss_offset, loss_depth, loss_size, loss_rotsin and
                loss_centerness, plus loss_velo / loss_dir / loss_attr when
                velocity, direction or attribute prediction is enabled.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses) == len(
            attr_preds)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels_3d, bbox_targets_3d, centerness_targets, attr_targets = \
            self.get_targets(
                all_level_points, gt_bboxes, gt_labels, gt_bboxes_3d,
                gt_labels_3d, centers2d, depths, attr_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds, dir_cls_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, sum(self.group_reg_dims))
            for bbox_pred in bbox_preds
        ]
        flatten_dir_cls_preds = [
            dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)
            for dir_cls_pred in dir_cls_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_dir_cls_preds = torch.cat(flatten_dir_cls_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels_3d = torch.cat(labels_3d)
        flatten_bbox_targets_3d = torch.cat(bbox_targets_3d)
        flatten_centerness_targets = torch.cat(centerness_targets)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels_3d >= 0)
                    & (flatten_labels_3d < bg_class_ind)).nonzero().reshape(-1)
        num_pos = len(pos_inds)

        loss_cls = self.loss_cls(
            flatten_cls_scores,
            flatten_labels_3d,
            avg_factor=num_pos + num_imgs)  # avoid num_pos is 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_dir_cls_preds = flatten_dir_cls_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        if self.pred_attrs:
            flatten_attr_preds = [
                attr_pred.permute(0, 2, 3, 1).reshape(-1, self.num_attrs)
                for attr_pred in attr_preds
            ]
            flatten_attr_preds = torch.cat(flatten_attr_preds)
            flatten_attr_targets = torch.cat(attr_targets)
            pos_attr_preds = flatten_attr_preds[pos_inds]

        if num_pos > 0:
            pos_bbox_targets_3d = flatten_bbox_targets_3d[pos_inds]
            pos_centerness_targets = flatten_centerness_targets[pos_inds]
            if self.pred_attrs:
                pos_attr_targets = flatten_attr_targets[pos_inds]
            bbox_weights = pos_centerness_targets.new_ones(
                len(pos_centerness_targets), sum(self.group_reg_dims))
            equal_weights = pos_centerness_targets.new_ones(
                pos_centerness_targets.shape)

            # A missing train_cfg (None) is treated as "no code_weight"
            # instead of raising AttributeError on `.get`.
            code_weight = None
            if self.train_cfg is not None:
                code_weight = self.train_cfg.get('code_weight', None)
            if code_weight:
                assert len(code_weight) == sum(self.group_reg_dims)
                bbox_weights = bbox_weights * bbox_weights.new_tensor(
                    code_weight)

            # Direction targets are computed before the sine encoding below
            # overwrites the yaw channel.
            if self.use_direction_classifier:
                pos_dir_cls_targets = self.get_direction_target(
                    pos_bbox_targets_3d, self.dir_offset, one_hot=False)

            if self.diff_rad_by_sin:
                pos_bbox_preds, pos_bbox_targets_3d = self.add_sin_difference(
                    pos_bbox_preds, pos_bbox_targets_3d)

            loss_offset = self.loss_bbox(
                pos_bbox_preds[:, :2],
                pos_bbox_targets_3d[:, :2],
                weight=bbox_weights[:, :2],
                avg_factor=equal_weights.sum())
            loss_depth = self.loss_bbox(
                pos_bbox_preds[:, 2],
                pos_bbox_targets_3d[:, 2],
                weight=bbox_weights[:, 2],
                avg_factor=equal_weights.sum())
            loss_size = self.loss_bbox(
                pos_bbox_preds[:, 3:6],
                pos_bbox_targets_3d[:, 3:6],
                weight=bbox_weights[:, 3:6],
                avg_factor=equal_weights.sum())
            loss_rotsin = self.loss_bbox(
                pos_bbox_preds[:, 6],
                pos_bbox_targets_3d[:, 6],
                weight=bbox_weights[:, 6],
                avg_factor=equal_weights.sum())
            loss_velo = None
            if self.pred_velo:
                loss_velo = self.loss_bbox(
                    pos_bbox_preds[:, 7:9],
                    pos_bbox_targets_3d[:, 7:9],
                    weight=bbox_weights[:, 7:9],
                    avg_factor=equal_weights.sum())

            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)

            # direction classification loss
            loss_dir = None
            # TODO: add more check for use_direction_classifier
            if self.use_direction_classifier:
                loss_dir = self.loss_dir(
                    pos_dir_cls_preds,
                    pos_dir_cls_targets,
                    equal_weights,
                    avg_factor=equal_weights.sum())

            # attribute classification loss
            loss_attr = None
            if self.pred_attrs:
                loss_attr = self.loss_attr(
                    pos_attr_preds,
                    pos_attr_targets,
                    pos_centerness_targets,
                    avg_factor=pos_centerness_targets.sum())

        else:
            # No positive samples: every pos_* tensor is empty, so these sums
            # are 0 while staying attached to the prediction tensors' graph.
            loss_offset = pos_bbox_preds[:, :2].sum()
            loss_depth = pos_bbox_preds[:, 2].sum()
            loss_size = pos_bbox_preds[:, 3:6].sum()
            loss_rotsin = pos_bbox_preds[:, 6].sum()
            loss_velo = None
            if self.pred_velo:
                loss_velo = pos_bbox_preds[:, 7:9].sum()
            loss_centerness = pos_centerness.sum()
            loss_dir = None
            if self.use_direction_classifier:
                loss_dir = pos_dir_cls_preds.sum()
            loss_attr = None
            if self.pred_attrs:
                loss_attr = pos_attr_preds.sum()

        loss_dict = dict(
            loss_cls=loss_cls,
            loss_offset=loss_offset,
            loss_depth=loss_depth,
            loss_size=loss_size,
            loss_rotsin=loss_rotsin,
            loss_centerness=loss_centerness)

        if loss_velo is not None:
            loss_dict['loss_velo'] = loss_velo

        if loss_dir is not None:
            loss_dict['loss_dir'] = loss_dir

        if loss_attr is not None:
            loss_dict['loss_attr'] = loss_attr

        return loss_dict

    @force_fp32(
        apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds', 'attr_preds',
                  'centernesses'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   dir_cls_preds,
                   attr_preds,
                   centernesses,
                   img_metas,
                   cfg=None,
                   rescale=None):
        """Transform network output for a batch into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_points * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_points * sum(group_reg_dims), H, W)
            dir_cls_preds (list[Tensor]): Box scores for direction class
                predictions on each scale level, each is a 4D-tensor,
                the channel number is num_points * 2. (bin = 2)
            attr_preds (list[Tensor]): Attribute scores for each scale level
                Has shape (N, num_points * num_attrs, H, W)
            centernesses (list[Tensor]): Centerness for each scale level with
                shape (N, num_points * 1, H, W)
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used
            rescale (bool): If True, return boxes in original image space

        Returns:
            list[tuple]: One ``(bboxes, scores, labels, attrs)`` tuple per
                image, as produced by ``_get_bboxes_single``. ``bboxes`` is an
                instance of ``img_metas[i]['box_type_3d']`` and ``attrs`` is
                None when attributes are not predicted.
        """
        assert len(cls_scores) == len(bbox_preds) == len(dir_cls_preds) == \
            len(centernesses) == len(attr_preds)
        num_levels = len(cls_scores)

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                      bbox_preds[0].device)
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            if self.use_direction_classifier:
                dir_cls_pred_list = [
                    dir_cls_preds[i][img_id].detach()
                    for i in range(num_levels)
                ]
            else:
                # Placeholder all-zero direction scores (2 bins).
                dir_cls_pred_list = [
                    cls_scores[i][img_id].new_full(
                        [2, *cls_scores[i][img_id].shape[1:]], 0).detach()
                    for i in range(num_levels)
                ]
            if self.pred_attrs:
                attr_pred_list = [
                    attr_preds[i][img_id].detach() for i in range(num_levels)
                ]
            else:
                # Placeholder attribute scores filled with the background id.
                attr_pred_list = [
                    cls_scores[i][img_id].new_full(
                        [self.num_attrs, *cls_scores[i][img_id].shape[1:]],
                        self.attr_background_label).detach()
                    for i in range(num_levels)
                ]
            centerness_pred_list = [
                centernesses[i][img_id].detach() for i in range(num_levels)
            ]
            input_meta = img_metas[img_id]
            det_bboxes = self._get_bboxes_single(
                cls_score_list, bbox_pred_list, dir_cls_pred_list,
                attr_pred_list, centerness_pred_list, mlvl_points, input_meta,
                cfg, rescale)
            result_list.append(det_bboxes)
        return result_list

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           dir_cls_preds,
                           attr_preds,
                           centernesses,
                           mlvl_points,
                           input_meta,
                           cfg,
                           rescale=False):
        """Transform outputs for a single batch item into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Per-level box scores of one image,
                each of shape (num_points * num_classes, H, W).
            bbox_preds (list[Tensor]): Per-level box energies / deltas of one
                image, each of shape (num_points * sum(group_reg_dims), H, W).
            dir_cls_preds (list[Tensor]): Per-level direction class scores of
                one image, each of shape (num_points * 2, H, W).
            attr_preds (list[Tensor]): Per-level attribute scores of one
                image, each of shape (num_points * num_attrs, H, W).
            centernesses (list[Tensor]): Per-level centerness of one image,
                each of shape (num_points, H, W).
            mlvl_points (list[Tensor]): Per-level point coordinates, each of
                shape (num_points_lvl, 2).
            input_meta (dict): Metadata of input image; reads
                ``cam_intrinsic``, ``scale_factor`` and ``box_type_3d``.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            tuple: ``(bboxes, scores, labels, attrs)`` after 3D NMS, where
                ``bboxes`` is an ``input_meta['box_type_3d']`` instance and
                ``attrs`` is None when attributes are not predicted.
        """
        view = np.array(input_meta['cam_intrinsic'])
        scale_factor = input_meta['scale_factor']
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_centers2d = []
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_dir_scores = []
        mlvl_attr_scores = []
        mlvl_centerness = []

        for cls_score, bbox_pred, dir_cls_pred, attr_pred, centerness, \
                points in zip(cls_scores, bbox_preds, dir_cls_preds,
                              attr_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            dir_cls_pred = dir_cls_pred.permute(1, 2, 0).reshape(-1, 2)
            dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1]
            attr_pred = attr_pred.permute(1, 2, 0).reshape(-1, self.num_attrs)
            attr_score = torch.max(attr_pred, dim=-1)[1]
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(
                -1, sum(self.group_reg_dims))
            bbox_pred = bbox_pred[:, :self.bbox_code_size]
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                # keep the nms_pre points with the best centerness-weighted
                # class score
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                dir_cls_pred = dir_cls_pred[topk_inds, :]
                centerness = centerness[topk_inds]
                dir_cls_score = dir_cls_score[topk_inds]
                attr_score = attr_score[topk_inds]
            # change the offset to actual center predictions
            bbox_pred[:, :2] = points - bbox_pred[:, :2]
            if rescale:
                # NOTE(review): assumes scale_factor broadcasts against an
                # (N, 2) tensor (a scalar or an (x, y) pair) — TODO confirm
                # for 4-element scale factors.
                bbox_pred[:, :2] /= bbox_pred[:, :2].new_tensor(scale_factor)
            pred_center2d = bbox_pred[:, :3].clone()
            # lift (u, v, depth) to 3D points with the camera intrinsics
            bbox_pred[:, :3] = self.pts2Dto3D(bbox_pred[:, :3], view)
            mlvl_centers2d.append(pred_center2d)
            mlvl_bboxes.append(bbox_pred)
            mlvl_scores.append(scores)
            mlvl_dir_scores.append(dir_cls_score)
            mlvl_attr_scores.append(attr_score)
            mlvl_centerness.append(centerness)

        mlvl_centers2d = torch.cat(mlvl_centers2d)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        mlvl_dir_scores = torch.cat(mlvl_dir_scores)

        # change local yaw to global yaw for 3D nms
        if mlvl_bboxes.shape[0] > 0:
            # fold the yaw into [dir_offset, dir_offset + pi) and pick the
            # half-turn from the predicted direction bin
            dir_rot = limit_period(mlvl_bboxes[..., 6] - self.dir_offset, 0,
                                   np.pi)
            mlvl_bboxes[..., 6] = (
                dir_rot + self.dir_offset +
                np.pi * mlvl_dir_scores.to(mlvl_bboxes.dtype))

        cam_intrinsic = mlvl_centers2d.new_zeros((4, 4))
        cam_intrinsic[:view.shape[0], :view.shape[1]] = \
            mlvl_centers2d.new_tensor(view)
        # add the viewing-ray angle atan2(u - cx, fx)
        mlvl_bboxes[:, 6] = torch.atan2(
            mlvl_centers2d[:, 0] - cam_intrinsic[0, 2],
            cam_intrinsic[0, 0]) + mlvl_bboxes[:, 6]
        mlvl_bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d'](
            mlvl_bboxes, box_dim=self.bbox_code_size,
            origin=(0.5, 0.5, 0.5)).bev)

        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        mlvl_attr_scores = torch.cat(mlvl_attr_scores)
        mlvl_centerness = torch.cat(mlvl_centerness)
        # no scale_factors in box3d_multiclass_nms
        # Then we multiply it from outside
        mlvl_nms_scores = mlvl_scores * mlvl_centerness[:, None]
        results = box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms,
                                       mlvl_nms_scores, cfg.score_thr,
                                       cfg.max_per_img, cfg, mlvl_dir_scores,
                                       mlvl_attr_scores)
        bboxes, scores, labels, dir_scores, attrs = results
        attrs = attrs.to(labels.dtype)  # change data type to int
        bboxes = input_meta['box_type_3d'](
            bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))
        # Note that the predictions use origin (0.5, 0.5, 0.5)
        # Due to the ground truth centers2d are the gravity center of objects
        # v0.10.0 fix inplace operation to the input tensor of cam_box3d
        # So here we also need to add origin=(0.5, 0.5, 0.5)
        if not self.pred_attrs:
            attrs = None

        return bboxes, scores, labels, attrs

    @staticmethod
    def pts2Dto3D(points, view):
        """Lift image points with depth to 3D points using the intrinsics.

        Args:
            points (torch.Tensor): points in 2D images, [N, 3], \
                3 corresponds with x, y in the image and depth.
            view (np.ndarray): camera instrinsic, [3, 3] (at most [4, 4];
                it is zero-padded into a 4x4 identity).

        Returns:
            torch.Tensor: points in 3D space. [N, 3], \
                3 corresponds with x, y, z in 3D space.
        """
        assert view.shape[0] <= 4
        assert view.shape[1] <= 4
        assert points.shape[1] == 3

        points2D = points[:, :2]
        depths = points[:, 2].view(-1, 1)
        # (u * d, v * d, d): homogeneous image coordinates scaled by depth
        unnorm_points2D = torch.cat([points2D * depths, depths], dim=1)

        viewpad = torch.eye(4, dtype=points2D.dtype, device=points2D.device)
        viewpad[:view.shape[0], :view.shape[1]] = points2D.new_tensor(view)
        inv_viewpad = torch.inverse(viewpad).transpose(0, 1)

        # Do operation in homogenous coordinates.
        nbr_points = unnorm_points2D.shape[0]
        homo_points2D = torch.cat(
            [unnorm_points2D,
             points2D.new_ones((nbr_points, 1))], dim=1)
        points3D = torch.mm(homo_points2D, inv_viewpad)[:, :3]

        return points3D

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points according to feature map sizes.

        Returns:
            Tensor: (H * W, 2) image-plane (x, y) coordinates, computed as
                grid index * stride + stride // 2. ``flatten`` is not used
                here; the result is always flattened.
        """
        y, x = super()._get_points_single(featmap_size, stride, dtype, device)
        points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
                             dim=-1) + stride // 2
        return points

    def get_targets(self, points, gt_bboxes_list, gt_labels_list,
                    gt_bboxes_3d_list, gt_labels_3d_list, centers2d_list,
                    depths_list, attr_labels_list):
        """Compute regression, classification and centerness targets for
        points in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).
            gt_labels_list (list[Tensor]): Ground truth labels of each box,
                each has shape (num_gt,).
            gt_bboxes_3d_list (list[Tensor]): 3D Ground truth bboxes of each
                image, each has shape (num_gt, bbox_code_size).
            gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
                box, each has shape (num_gt,).
            centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
                each has shape (num_gt, 2).
            depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
                image, each has shape (num_gt, 1).
            attr_labels_list (list[Tensor] | None): Attribute labels of each
                box, each has shape (num_gt,). If None, every box gets
                ``self.attr_background_label``.

        Returns:
            tuple: Four lists with one tensor per level, each concatenated
                over all images:
                concat_lvl_labels_3d (list[Tensor]): 3D labels.
                concat_lvl_bbox_targets_3d (list[Tensor]): 3D bbox targets;
                    the 2D offsets are divided by the level stride when
                    ``norm_on_bbox`` is True.
                concat_lvl_centerness_targets (list[Tensor]): Centerness
                    targets.
                concat_lvl_attr_targets (list[Tensor]): Attribute targets.
        """
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)

        # the number of points per img, per lvl
        num_points = [center.size(0) for center in points]

        if attr_labels_list is None:
            attr_labels_list = [
                gt_labels.new_full(gt_labels.shape, self.attr_background_label)
                for gt_labels in gt_labels_list
            ]

        # get labels and bbox_targets of each image
        _, _, labels_3d_list, bbox_targets_3d_list, centerness_targets_list, \
            attr_targets_list = multi_apply(
                self._get_target_single,
                gt_bboxes_list,
                gt_labels_list,
                gt_bboxes_3d_list,
                gt_labels_3d_list,
                centers2d_list,
                depths_list,
                attr_labels_list,
                points=concat_points,
                regress_ranges=concat_regress_ranges,
                num_points_per_lvl=num_points)

        # split to per img, per level
        labels_3d_list = [
            labels_3d.split(num_points, 0) for labels_3d in labels_3d_list
        ]
        bbox_targets_3d_list = [
            bbox_targets_3d.split(num_points, 0)
            for bbox_targets_3d in bbox_targets_3d_list
        ]
        centerness_targets_list = [
            centerness_targets.split(num_points, 0)
            for centerness_targets in centerness_targets_list
        ]
        attr_targets_list = [
            attr_targets.split(num_points, 0)
            for attr_targets in attr_targets_list
        ]

        # concat per level image
        concat_lvl_labels_3d = []
        concat_lvl_bbox_targets_3d = []
        concat_lvl_centerness_targets = []
        concat_lvl_attr_targets = []
        for i in range(num_levels):
            concat_lvl_labels_3d.append(
                torch.cat([labels[i] for labels in labels_3d_list]))
            concat_lvl_centerness_targets.append(
                torch.cat([
                    centerness_targets[i]
                    for centerness_targets in centerness_targets_list
                ]))
            bbox_targets_3d = torch.cat([
                bbox_targets_3d[i] for bbox_targets_3d in bbox_targets_3d_list
            ])
            concat_lvl_attr_targets.append(
                torch.cat(
                    [attr_targets[i] for attr_targets in attr_targets_list]))
            if self.norm_on_bbox:
                bbox_targets_3d[:, :2] = \
                    bbox_targets_3d[:, :2] / self.strides[i]
            concat_lvl_bbox_targets_3d.append(bbox_targets_3d)
        return concat_lvl_labels_3d, concat_lvl_bbox_targets_3d, \
            concat_lvl_centerness_targets, concat_lvl_attr_targets

    def _get_target_single(self, gt_bboxes, gt_labels, gt_bboxes_3d,
                           gt_labels_3d, centers2d, depths, attr_labels,
                           points, regress_ranges, num_points_per_lvl):
        """Compute regression and classification targets for a single image.

        Returns:
            tuple: Per-point ``(labels, bbox_targets, labels_3d,
                bbox_targets_3d, centerness_targets, attr_labels)``.
                The input ``gt_bboxes_3d`` is not modified.
        """
        num_points = points.size(0)
        num_gts = gt_labels.size(0)
        if not isinstance(gt_bboxes_3d, torch.Tensor):
            gt_bboxes_3d = gt_bboxes_3d.tensor.to(gt_bboxes.device)
        if num_gts == 0:
            return (gt_labels.new_full((num_points, ), self.background_label),
                    gt_bboxes.new_zeros((num_points, 4)),
                    gt_labels_3d.new_full((num_points, ),
                                          self.background_label),
                    gt_bboxes_3d.new_zeros(
                        (num_points, self.bbox_code_size)),
                    gt_bboxes_3d.new_zeros((num_points, )),
                    attr_labels.new_full((num_points, ),
                                         self.attr_background_label))

        # Work on a copy: the yaw conversion below is written in place, and
        # `.to()` above returns the very same tensor when it is already on
        # the target device, so without the clone the caller's ground truth
        # (or the box object's tensor) would be modified.
        gt_bboxes_3d = gt_bboxes_3d.clone()
        # change orientation to local yaw: subtract the viewing-ray angle
        # atan2(x, z) from the yaw
        gt_bboxes_3d[..., 6] = -torch.atan2(
            gt_bboxes_3d[..., 0], gt_bboxes_3d[..., 2]) + gt_bboxes_3d[..., 6]

        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
            gt_bboxes[:, 3] - gt_bboxes[:, 1])
        areas = areas[None].repeat(num_points, 1)
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        centers2d = centers2d[None].expand(num_points, num_gts, 2)
        gt_bboxes_3d = gt_bboxes_3d[None].expand(num_points, num_gts,
                                                 self.bbox_code_size)
        depths = depths[None, :, None].expand(num_points, num_gts, 1)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)

        # 3D targets: offset from projected center to point, depth, and the
        # remaining box parameters (size, yaw, ...)
        delta_xs = (xs - centers2d[..., 0])[..., None]
        delta_ys = (ys - centers2d[..., 1])[..., None]
        bbox_targets_3d = torch.cat(
            (delta_xs, delta_ys, depths, gt_bboxes_3d[..., 3:]), dim=-1)

        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        assert self.center_sampling is True, 'Setting center_sampling to '\
            'False has not been implemented for FCOS3D.'
        # condition1: inside a `center bbox`
        radius = self.center_sample_radius
        center_xs = centers2d[..., 0]
        center_ys = centers2d[..., 1]
        center_gts = torch.zeros_like(gt_bboxes)
        stride = center_xs.new_zeros(center_xs.shape)

        # project the points on current lvl back to the `original` sizes
        lvl_begin = 0
        for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
            lvl_end = lvl_begin + num_points_lvl
            stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
            lvl_begin = lvl_end

        center_gts[..., 0] = center_xs - stride
        center_gts[..., 1] = center_ys - stride
        center_gts[..., 2] = center_xs + stride
        center_gts[..., 3] = center_ys + stride

        cb_dist_left = xs - center_gts[..., 0]
        cb_dist_right = center_gts[..., 2] - xs
        cb_dist_top = ys - center_gts[..., 1]
        cb_dist_bottom = center_gts[..., 3] - ys
        center_bbox = torch.stack(
            (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
        inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0

        # condition2: limit the regression range for each location
        max_regress_distance = bbox_targets.max(-1)[0]
        inside_regress_range = (
            (max_regress_distance >= regress_ranges[..., 0])
            & (max_regress_distance <= regress_ranges[..., 1]))

        # center-based criterion to deal with ambiguity
        dists = torch.sqrt(torch.sum(bbox_targets_3d[..., :2]**2, dim=-1))
        dists[inside_gt_bbox_mask == 0] = INF
        dists[inside_regress_range == 0] = INF
        min_dist, min_dist_inds = dists.min(dim=1)

        labels = gt_labels[min_dist_inds]
        labels_3d = gt_labels_3d[min_dist_inds]
        attr_labels = attr_labels[min_dist_inds]
        labels[min_dist == INF] = self.background_label  # set as BG
        labels_3d[min_dist == INF] = self.background_label  # set as BG
        attr_labels[min_dist == INF] = self.attr_background_label

        bbox_targets = bbox_targets[range(num_points), min_dist_inds]
        bbox_targets_3d = bbox_targets_3d[range(num_points), min_dist_inds]
        # distance to the assigned center, relative to the sampling radius
        # (stride already includes center_sample_radius)
        relative_dists = torch.sqrt(
            torch.sum(bbox_targets_3d[..., :2]**2,
                      dim=-1)) / (1.414 * stride[:, 0])
        # [N, 1] / [N, 1]
        centerness_targets = torch.exp(-self.centerness_alpha * relative_dists)

        return labels, bbox_targets, labels_3d, bbox_targets_3d, \
            centerness_targets, attr_labels