# Source code for mmdet3d.models.detectors.imvotenet
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.models.utils import MLP
from .. import builder
from ..builder import DETECTORS
from .base import Base3DDetector
def sample_valid_seeds(mask, num_sampled_seed=1024):
    r"""Randomly sample seed indices from all imvotes.

    Modified from `<https://github.com/facebookresearch/imvotenet/blob/a8856345146bacf29a57266a2f0b874406fd8823/models/imvotenet.py#L26>`_

    Args:
        mask (torch.Tensor): Bool tensor in shape (
            seed_num*max_imvote_per_pixel), indicates
            whether this imvote corresponds to a 2D bbox.
        num_sampled_seed (int): How many to sample from all imvotes.

    Returns:
        torch.Tensor: Indices with shape (num_sampled_seed).
    """  # noqa: E501
    device = mask.device
    batch_size = mask.shape[0]
    sample_inds = mask.new_zeros((batch_size, num_sampled_seed),
                                 dtype=torch.int64)
    for bidx in range(batch_size):
        # indices of the imvotes that actually hit a 2D bbox
        valid_inds = mask[bidx].nonzero().squeeze(-1)
        num_valid = len(valid_inds)
        if num_valid < num_sampled_seed:
            # Not enough valid imvotes: keep all of them and pad with
            # indices from the set difference
            # {0..num_sampled_seed-1} \ (valid_inds % num_sampled_seed).
            t1 = torch.arange(num_sampled_seed, device=device)
            t2 = valid_inds % num_sampled_seed
            combined = torch.cat((t1, t2))
            uniques, counts = combined.unique(return_counts=True)
            difference = uniques[counts == 1]
            rand_inds = torch.randperm(
                len(difference),
                device=device)[:num_sampled_seed - num_valid]
            cur_sample_inds = torch.cat((valid_inds, difference[rand_inds]))
        else:
            # Enough valid imvotes: take a random subset of them.
            rand_inds = torch.randperm(
                num_valid, device=device)[:num_sampled_seed]
            cur_sample_inds = valid_inds[rand_inds]
        sample_inds[bidx, :] = cur_sample_inds
    return sample_inds
@DETECTORS.register_module()
class ImVoteNet(Base3DDetector):
r"""`ImVoteNet <https://arxiv.org/abs/2001.10692>`_ for 3D detection."""
def __init__(self,
pts_backbone=None,
pts_bbox_heads=None,
pts_neck=None,
img_backbone=None,
img_neck=None,
img_roi_head=None,
img_rpn_head=None,
img_mlp=None,
freeze_img_branch=False,
fusion_layer=None,
num_sampled_seed=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(ImVoteNet, self).__init__(init_cfg=init_cfg)
# point branch
if pts_backbone is not None:
self.pts_backbone = builder.build_backbone(pts_backbone)
if pts_neck is not None:
self.pts_neck = builder.build_neck(pts_neck)
if pts_bbox_heads is not None:
pts_bbox_head_common = pts_bbox_heads.common
pts_bbox_head_common.update(
train_cfg=train_cfg.pts if train_cfg is not None else None)
pts_bbox_head_common.update(test_cfg=test_cfg.pts)
pts_bbox_head_joint = pts_bbox_head_common.copy()
pts_bbox_head_joint.update(pts_bbox_heads.joint)
pts_bbox_head_pts = pts_bbox_head_common.copy()
pts_bbox_head_pts.update(pts_bbox_heads.pts)
pts_bbox_head_img = pts_bbox_head_common.copy()
pts_bbox_head_img.update(pts_bbox_heads.img)
self.pts_bbox_head_joint = builder.build_head(pts_bbox_head_joint)
self.pts_bbox_head_pts = builder.build_head(pts_bbox_head_pts)
self.pts_bbox_head_img = builder.build_head(pts_bbox_head_img)
self.pts_bbox_heads = [
self.pts_bbox_head_joint, self.pts_bbox_head_pts,
self.pts_bbox_head_img
]
self.loss_weights = pts_bbox_heads.loss_weights
# image branch
if img_backbone:
self.img_backbone = builder.build_backbone(img_backbone)
if img_neck is not None:
self.img_neck = builder.build_neck(img_neck)
if img_rpn_head is not None:
rpn_train_cfg = train_cfg.img_rpn if train_cfg \
is not None else None
img_rpn_head_ = img_rpn_head.copy()
img_rpn_head_.update(
train_cfg=rpn_train_cfg, test_cfg=test_cfg.img_rpn)
self.img_rpn_head = builder.build_head(img_rpn_head_)
if img_roi_head is not None:
rcnn_train_cfg = train_cfg.img_rcnn if train_cfg \
is not None else None
img_roi_head.update(
train_cfg=rcnn_train_cfg, test_cfg=test_cfg.img_rcnn)
self.img_roi_head = builder.build_head(img_roi_head)
# fusion
if fusion_layer is not None:
self.fusion_layer = builder.build_fusion_layer(fusion_layer)
self.max_imvote_per_pixel = fusion_layer.max_imvote_per_pixel
self.freeze_img_branch = freeze_img_branch
if freeze_img_branch:
self.freeze_img_branch_params()
if img_mlp is not None:
self.img_mlp = MLP(**img_mlp)
self.num_sampled_seed = num_sampled_seed
self.train_cfg = train_cfg
self.test_cfg = test_cfg
if pretrained is None:
img_pretrained = None
pts_pretrained = None
elif isinstance(pretrained, dict):
img_pretrained = pretrained.get('img', None)
pts_pretrained = pretrained.get('pts', None)
else:
raise ValueError(
f'pretrained should be a dict, got {type(pretrained)}')
if self.with_img_backbone:
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated '
'key, please consider using init_cfg.')
self.img_backbone.init_cfg = dict(
type='Pretrained', checkpoint=img_pretrained)
if self.with_img_roi_head:
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated '
'key, please consider using init_cfg.')
self.img_roi_head.init_cfg = dict(
type='Pretrained', checkpoint=img_pretrained)
if self.with_pts_backbone:
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated '
'key, please consider using init_cfg.')
self.pts_backbone.init_cfg = dict(
type='Pretrained', checkpoint=pts_pretrained)
[docs] def freeze_img_branch_params(self):
"""Freeze all image branch parameters."""
if self.with_img_bbox_head:
for param in self.img_bbox_head.parameters():
param.requires_grad = False
if self.with_img_backbone:
for param in self.img_backbone.parameters():
param.requires_grad = False
if self.with_img_neck:
for param in self.img_neck.parameters():
param.requires_grad = False
if self.with_img_rpn:
for param in self.img_rpn_head.parameters():
param.requires_grad = False
if self.with_img_roi_head:
for param in self.img_roi_head.parameters():
param.requires_grad = False
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Overload in order to load img network ckpts into img branch."""
module_names = ['backbone', 'neck', 'roi_head', 'rpn_head']
for key in list(state_dict):
for module_name in module_names:
if key.startswith(module_name) and ('img_' +
key) not in state_dict:
state_dict['img_' + key] = state_dict.pop(key)
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
[docs] def train(self, mode=True):
"""Overload in order to keep image branch modules in eval mode."""
super(ImVoteNet, self).train(mode)
if self.freeze_img_branch:
if self.with_img_bbox_head:
self.img_bbox_head.eval()
if self.with_img_backbone:
self.img_backbone.eval()
if self.with_img_neck:
self.img_neck.eval()
if self.with_img_rpn:
self.img_rpn_head.eval()
if self.with_img_roi_head:
self.img_roi_head.eval()
@property
def with_img_bbox(self):
"""bool: Whether the detector has a 2D image box head."""
return ((hasattr(self, 'img_roi_head') and self.img_roi_head.with_bbox)
or (hasattr(self, 'img_bbox_head')
and self.img_bbox_head is not None))
@property
def with_img_bbox_head(self):
"""bool: Whether the detector has a 2D image box head (not roi)."""
return hasattr(self,
'img_bbox_head') and self.img_bbox_head is not None
@property
def with_img_backbone(self):
"""bool: Whether the detector has a 2D image backbone."""
return hasattr(self, 'img_backbone') and self.img_backbone is not None
@property
def with_img_neck(self):
"""bool: Whether the detector has a neck in image branch."""
return hasattr(self, 'img_neck') and self.img_neck is not None
@property
def with_img_rpn(self):
"""bool: Whether the detector has a 2D RPN in image detector branch."""
return hasattr(self, 'img_rpn_head') and self.img_rpn_head is not None
@property
def with_img_roi_head(self):
"""bool: Whether the detector has a RoI Head in image branch."""
return hasattr(self, 'img_roi_head') and self.img_roi_head is not None
@property
def with_pts_bbox(self):
"""bool: Whether the detector has a 3D box head."""
return hasattr(self,
'pts_bbox_head') and self.pts_bbox_head is not None
@property
def with_pts_backbone(self):
"""bool: Whether the detector has a 3D backbone."""
return hasattr(self, 'pts_backbone') and self.pts_backbone is not None
@property
def with_pts_neck(self):
"""bool: Whether the detector has a neck in 3D detector branch."""
return hasattr(self, 'pts_neck') and self.pts_neck is not None
[docs] def extract_img_feat(self, img):
"""Directly extract features from the img backbone+neck."""
x = self.img_backbone(img)
if self.with_img_neck:
x = self.img_neck(x)
return x
[docs] def extract_img_feats(self, imgs):
"""Extract features from multiple images.
Args:
imgs (list[torch.Tensor]): A list of images. The images are
augmented from the same image but in different ways.
Returns:
list[torch.Tensor]: Features of different images
"""
assert isinstance(imgs, list)
return [self.extract_img_feat(img) for img in imgs]
[docs] def extract_pts_feat(self, pts):
"""Extract features of points."""
x = self.pts_backbone(pts)
if self.with_pts_neck:
x = self.pts_neck(x)
seed_points = x['fp_xyz'][-1]
seed_features = x['fp_features'][-1]
seed_indices = x['fp_indices'][-1]
return (seed_points, seed_features, seed_indices)
[docs] def extract_pts_feats(self, pts):
"""Extract features of points from multiple samples."""
assert isinstance(pts, list)
return [self.extract_pts_feat(pt) for pt in pts]
[docs] @torch.no_grad()
def extract_bboxes_2d(self,
img,
img_metas,
train=True,
bboxes_2d=None,
**kwargs):
"""Extract bounding boxes from 2d detector.
Args:
img (torch.Tensor): of shape (N, C, H, W) encoding input images.
Typically these should be mean centered and std scaled.
img_metas (list[dict]): Image meta info.
train (bool): train-time or not.
bboxes_2d (list[torch.Tensor]): provided 2d bboxes,
not supported yet.
Return:
list[torch.Tensor]: a list of processed 2d bounding boxes.
"""
if bboxes_2d is None:
x = self.extract_img_feat(img)
proposal_list = self.img_rpn_head.simple_test_rpn(x, img_metas)
rets = self.img_roi_head.simple_test(
x, proposal_list, img_metas, rescale=False)
rets_processed = []
for ret in rets:
tmp = np.concatenate(ret, axis=0)
sem_class = img.new_zeros((len(tmp)))
start = 0
for i, bboxes in enumerate(ret):
sem_class[start:start + len(bboxes)] = i
start += len(bboxes)
ret = img.new_tensor(tmp)
# append class index
ret = torch.cat([ret, sem_class[:, None]], dim=-1)
inds = torch.argsort(ret[:, 4], descending=True)
ret = ret.index_select(0, inds)
# drop half bboxes during training for better generalization
if train:
rand_drop = torch.randperm(len(ret))[:(len(ret) + 1) // 2]
rand_drop = torch.sort(rand_drop)[0]
ret = ret[rand_drop]
rets_processed.append(ret.float())
return rets_processed
else:
rets_processed = []
for ret in bboxes_2d:
if len(ret) > 0 and train:
rand_drop = torch.randperm(len(ret))[:(len(ret) + 1) // 2]
rand_drop = torch.sort(rand_drop)[0]
ret = ret[rand_drop]
rets_processed.append(ret.float())
return rets_processed
    def forward_train(self,
                      points=None,
                      img=None,
                      img_metas=None,
                      gt_bboxes=None,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      bboxes_2d=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      pts_semantic_mask=None,
                      pts_instance_mask=None,
                      **kwargs):
        """Forwarding of train for image branch pretrain or stage 2 train.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            img (torch.Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image and point cloud meta info
                dict. For example, keys include 'ori_shape', 'img_norm_cfg',
                and 'transformation_3d_flow'. For details on the values of
                the keys see `mmdet/datasets/pipelines/formatting.py:Collect`.
            gt_bboxes (list[torch.Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[torch.Tensor]): class indices for each
                2d bounding box.
            gt_bboxes_ignore (list[torch.Tensor]): specify which
                2d bounding boxes can be ignored when computing the loss.
            gt_masks (torch.Tensor): true segmentation masks for each
                2d bbox, used if the architecture supports a segmentation
                task.
            proposals: override rpn proposals (2d) with custom proposals.
                Use when `with_rpn` is False.
            bboxes_2d (list[torch.Tensor]): provided 2d bboxes,
                not supported yet.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): 3d gt bboxes.
            gt_labels_3d (list[torch.Tensor]): gt class labels for 3d bboxes.
            pts_semantic_mask (list[torch.Tensor]): point-wise semantic
                label of each batch.
            pts_instance_mask (list[torch.Tensor]): point-wise instance
                label of each batch.

        Returns:
            dict[str, torch.Tensor]: a dictionary of loss components.
        """
        if points is None:
            # Image branch pretraining: standard two-stage 2D detector
            # training (RPN + RoI head).
            x = self.extract_img_feat(img)
            losses = dict()

            # RPN forward and loss
            if self.with_img_rpn:
                proposal_cfg = self.train_cfg.get('img_rpn_proposal',
                                                  self.test_cfg.img_rpn)
                rpn_losses, proposal_list = self.img_rpn_head.forward_train(
                    x,
                    img_metas,
                    gt_bboxes,
                    gt_labels=None,
                    gt_bboxes_ignore=gt_bboxes_ignore,
                    proposal_cfg=proposal_cfg)
                losses.update(rpn_losses)
            else:
                proposal_list = proposals

            roi_losses = self.img_roi_head.forward_train(
                x, img_metas, proposal_list, gt_bboxes, gt_labels,
                gt_bboxes_ignore, gt_masks, **kwargs)
            losses.update(roi_losses)
            return losses
        else:
            # Stage-2 joint training: detect 2D boxes (image branch frozen),
            # fuse the resulting image votes with point-cloud seeds, and
            # train the three 3D head towers (joint / pts-only / img-only).
            bboxes_2d = self.extract_bboxes_2d(
                img, img_metas, bboxes_2d=bboxes_2d, **kwargs)

            points = torch.stack(points)
            seeds_3d, seed_3d_features, seed_indices = \
                self.extract_pts_feat(points)

            # `masks` flags which imvotes correspond to a real 2D bbox.
            img_features, masks = self.fusion_layer(img, bboxes_2d, seeds_3d,
                                                    img_metas)

            # Sample a fixed number of valid imvotes per sample and gather
            # the corresponding image features / seeds.
            inds = sample_valid_seeds(masks, self.num_sampled_seed)
            batch_size, img_feat_size = img_features.shape[:2]
            pts_feat_size = seed_3d_features.shape[1]
            inds_img = inds.view(batch_size, 1,
                                 -1).expand(-1, img_feat_size, -1)
            img_features = img_features.gather(-1, inds_img)
            # Map flattened imvote indices back to seed indices.
            # NOTE(review): assumes the seed number equals num_sampled_seed
            # (= inds.shape[1]) -- confirm against the fusion layer config.
            inds = inds % inds.shape[1]
            inds_seed_xyz = inds.view(batch_size, -1, 1).expand(-1, -1, 3)
            seeds_3d = seeds_3d.gather(1, inds_seed_xyz)
            inds_seed_feats = inds.view(batch_size, 1,
                                        -1).expand(-1, pts_feat_size, -1)
            seed_3d_features = seed_3d_features.gather(-1, inds_seed_feats)
            seed_indices = seed_indices.gather(1, inds)

            # Project image vote features and fuse with point features.
            img_features = self.img_mlp(img_features)
            fused_features = torch.cat([seed_3d_features, img_features],
                                       dim=1)

            # One feature dict per head tower.
            feat_dict_joint = dict(
                seed_points=seeds_3d,
                seed_features=fused_features,
                seed_indices=seed_indices)
            feat_dict_pts = dict(
                seed_points=seeds_3d,
                seed_features=seed_3d_features,
                seed_indices=seed_indices)
            feat_dict_img = dict(
                seed_points=seeds_3d,
                seed_features=img_features,
                seed_indices=seed_indices)

            loss_inputs = (points, gt_bboxes_3d, gt_labels_3d,
                           pts_semantic_mask, pts_instance_mask, img_metas)
            bbox_preds_joints = self.pts_bbox_head_joint(
                feat_dict_joint, self.train_cfg.pts.sample_mod)
            bbox_preds_pts = self.pts_bbox_head_pts(
                feat_dict_pts, self.train_cfg.pts.sample_mod)
            bbox_preds_img = self.pts_bbox_head_img(
                feat_dict_img, self.train_cfg.pts.sample_mod)
            losses_towers = []
            losses_joint = self.pts_bbox_head_joint.loss(
                bbox_preds_joints,
                *loss_inputs,
                gt_bboxes_ignore=gt_bboxes_ignore)
            losses_pts = self.pts_bbox_head_pts.loss(
                bbox_preds_pts,
                *loss_inputs,
                gt_bboxes_ignore=gt_bboxes_ignore)
            losses_img = self.pts_bbox_head_img.loss(
                bbox_preds_img,
                *loss_inputs,
                gt_bboxes_ignore=gt_bboxes_ignore)
            losses_towers.append(losses_joint)
            losses_towers.append(losses_pts)
            losses_towers.append(losses_img)
            combined_losses = dict()

            # Weighted sum of the three towers' loss terms; metric (non-loss)
            # terms are taken from the joint head only.
            for loss_term in losses_joint:
                if 'loss' in loss_term:
                    combined_losses[loss_term] = 0
                    for i in range(len(losses_towers)):
                        combined_losses[loss_term] += \
                            losses_towers[i][loss_term] * \
                            self.loss_weights[i]
                else:
                    # only save the metric of the joint head
                    # if it is not a loss
                    combined_losses[loss_term] = \
                        losses_towers[0][loss_term]

            return combined_losses
[docs] def forward_test(self,
points=None,
img_metas=None,
img=None,
bboxes_2d=None,
**kwargs):
"""Forwarding of test for image branch pretrain or stage 2 train.
Args:
points (list[list[torch.Tensor]], optional): the outer
list indicates test-time augmentations and the inner
list contains all points in the batch, where each Tensor
should have a shape NxC. Defaults to None.
img_metas (list[list[dict]], optional): the outer list
indicates test-time augs (multiscale, flip, etc.)
and the inner list indicates images in a batch.
Defaults to None.
img (list[list[torch.Tensor]], optional): the outer
list indicates test-time augmentations and inner Tensor
should have a shape NxCxHxW, which contains all images
in the batch. Defaults to None. Defaults to None.
bboxes_2d (list[list[torch.Tensor]], optional):
Provided 2d bboxes, not supported yet. Defaults to None.
Returns:
list[list[torch.Tensor]]|list[dict]: Predicted 2d or 3d boxes.
"""
if points is None:
for var, name in [(img, 'img'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError(
f'{name} must be a list, but got {type(var)}')
num_augs = len(img)
if num_augs != len(img_metas):
raise ValueError(f'num of augmentations ({len(img)}) '
f'!= num of image meta ({len(img_metas)})')
if num_augs == 1:
# proposals (List[List[Tensor]]): the outer list indicates
# test-time augs (multiscale, flip, etc.) and the inner list
# indicates images in a batch.
# The Tensor should have a shape Px4, where P is the number of
# proposals.
if 'proposals' in kwargs:
kwargs['proposals'] = kwargs['proposals'][0]
return self.simple_test_img_only(
img=img[0], img_metas=img_metas[0], **kwargs)
else:
assert img[0].size(0) == 1, 'aug test does not support ' \
'inference with batch size ' \
f'{img[0].size(0)}'
# TODO: support test augmentation for predefined proposals
assert 'proposals' not in kwargs
return self.aug_test_img_only(
img=img, img_metas=img_metas, **kwargs)
else:
for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(points)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.
format(len(points), len(img_metas)))
if num_augs == 1:
return self.simple_test(
points[0],
img_metas[0],
img[0],
bboxes_2d=bboxes_2d[0] if bboxes_2d is not None else None,
**kwargs)
else:
return self.aug_test(points, img_metas, img, bboxes_2d,
**kwargs)
[docs] def simple_test_img_only(self,
img,
img_metas,
proposals=None,
rescale=False):
r"""Test without augmentation, image network pretrain. May refer to
`<https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py>`_.
Args:
img (torch.Tensor): Should have a shape NxCxHxW, which contains
all images in the batch.
img_metas (list[dict]):
proposals (list[Tensor], optional): override rpn proposals
with custom proposals. Defaults to None.
rescale (bool, optional): Whether or not rescale bboxes to the
original shape of input image. Defaults to False.
Returns:
list[list[torch.Tensor]]: Predicted 2d boxes.
""" # noqa: E501
assert self.with_img_bbox, 'Img bbox head must be implemented.'
assert self.with_img_backbone, 'Img backbone must be implemented.'
assert self.with_img_rpn, 'Img rpn must be implemented.'
assert self.with_img_roi_head, 'Img roi head must be implemented.'
x = self.extract_img_feat(img)
if proposals is None:
proposal_list = self.img_rpn_head.simple_test_rpn(x, img_metas)
else:
proposal_list = proposals
ret = self.img_roi_head.simple_test(
x, proposal_list, img_metas, rescale=rescale)
return ret
    def simple_test(self,
                    points=None,
                    img_metas=None,
                    img=None,
                    bboxes_2d=None,
                    rescale=False,
                    **kwargs):
        """Test without augmentation, stage 2.

        Args:
            points (list[torch.Tensor], optional): Elements in the list
                should have a shape NxC, the list indicates all point-clouds
                in the batch. Defaults to None.
            img_metas (list[dict], optional): List indicates
                images in a batch. Defaults to None.
            img (torch.Tensor, optional): Should have a shape NxCxHxW,
                which contains all images in the batch. Defaults to None.
            bboxes_2d (list[torch.Tensor], optional):
                Provided 2d bboxes, not supported yet. Defaults to None.
            rescale (bool, optional): Whether or not rescale bboxes.
                Defaults to False.

        Returns:
            list[dict]: Predicted 3d boxes.
        """
        # Get (or reuse) the 2D detections that generate image votes.
        bboxes_2d = self.extract_bboxes_2d(
            img, img_metas, train=False, bboxes_2d=bboxes_2d, **kwargs)

        points = torch.stack(points)
        seeds_3d, seed_3d_features, seed_indices = \
            self.extract_pts_feat(points)

        # Fuse the 2D boxes with the 3D seeds; `masks` flags which imvotes
        # correspond to a real 2D bbox.
        img_features, masks = self.fusion_layer(img, bboxes_2d, seeds_3d,
                                                img_metas)

        # Sample a fixed number of valid imvotes per sample and gather the
        # matching image features / seeds.
        inds = sample_valid_seeds(masks, self.num_sampled_seed)
        batch_size, img_feat_size = img_features.shape[:2]
        pts_feat_size = seed_3d_features.shape[1]
        inds_img = inds.view(batch_size, 1, -1).expand(-1, img_feat_size, -1)
        img_features = img_features.gather(-1, inds_img)
        # Map flattened imvote indices back to seed indices.
        # NOTE(review): assumes the seed number equals num_sampled_seed
        # (= inds.shape[1]) -- confirm against the fusion layer config.
        inds = inds % inds.shape[1]
        inds_seed_xyz = inds.view(batch_size, -1, 1).expand(-1, -1, 3)
        seeds_3d = seeds_3d.gather(1, inds_seed_xyz)
        inds_seed_feats = inds.view(batch_size, 1,
                                    -1).expand(-1, pts_feat_size, -1)
        seed_3d_features = seed_3d_features.gather(-1, inds_seed_feats)
        seed_indices = seed_indices.gather(1, inds)

        # Project image vote features and fuse with point features.
        img_features = self.img_mlp(img_features)

        fused_features = torch.cat([seed_3d_features, img_features], dim=1)

        feat_dict = dict(
            seed_points=seeds_3d,
            seed_features=fused_features,
            seed_indices=seed_indices)
        # Only the joint head is used at test time.
        bbox_preds = self.pts_bbox_head_joint(feat_dict,
                                              self.test_cfg.pts.sample_mod)
        bbox_list = self.pts_bbox_head_joint.get_bboxes(
            points, bbox_preds, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results
[docs] def aug_test_img_only(self, img, img_metas, rescale=False):
r"""Test function with augmentation, image network pretrain. May refer
to `<https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py>`_.
Args:
img (list[list[torch.Tensor]], optional): the outer
list indicates test-time augmentations and inner Tensor
should have a shape NxCxHxW, which contains all images
in the batch. Defaults to None. Defaults to None.
img_metas (list[list[dict]], optional): the outer list
indicates test-time augs (multiscale, flip, etc.)
and the inner list indicates images in a batch.
Defaults to None.
rescale (bool, optional): Whether or not rescale bboxes to the
original shape of input image. If rescale is False, then
returned bboxes and masks will fit the scale of imgs[0].
Defaults to None.
Returns:
list[list[torch.Tensor]]: Predicted 2d boxes.
""" # noqa: E501
assert self.with_img_bbox, 'Img bbox head must be implemented.'
assert self.with_img_backbone, 'Img backbone must be implemented.'
assert self.with_img_rpn, 'Img rpn must be implemented.'
assert self.with_img_roi_head, 'Img roi head must be implemented.'
x = self.extract_img_feats(img)
proposal_list = self.img_rpn_head.aug_test_rpn(x, img_metas)
return self.img_roi_head.aug_test(
x, proposal_list, img_metas, rescale=rescale)
[docs] def aug_test(self,
points=None,
img_metas=None,
imgs=None,
bboxes_2d=None,
rescale=False,
**kwargs):
"""Test function with augmentation, stage 2.
Args:
points (list[list[torch.Tensor]], optional): the outer
list indicates test-time augmentations and the inner
list contains all points in the batch, where each Tensor
should have a shape NxC. Defaults to None.
img_metas (list[list[dict]], optional): the outer list
indicates test-time augs (multiscale, flip, etc.)
and the inner list indicates images in a batch.
Defaults to None.
imgs (list[list[torch.Tensor]], optional): the outer
list indicates test-time augmentations and inner Tensor
should have a shape NxCxHxW, which contains all images
in the batch. Defaults to None. Defaults to None.
bboxes_2d (list[list[torch.Tensor]], optional):
Provided 2d bboxes, not supported yet. Defaults to None.
rescale (bool, optional): Whether or not rescale bboxes.
Defaults to False.
Returns:
list[dict]: Predicted 3d boxes.
"""
points_cat = [torch.stack(pts) for pts in points]
feats = self.extract_pts_feats(points_cat, img_metas)
# only support aug_test for one sample
aug_bboxes = []
for x, pts_cat, img_meta, bbox_2d, img in zip(feats, points_cat,
img_metas, bboxes_2d,
imgs):
bbox_2d = self.extract_bboxes_2d(
img, img_metas, train=False, bboxes_2d=bbox_2d, **kwargs)
seeds_3d, seed_3d_features, seed_indices = x
img_features, masks = self.fusion_layer(img, bbox_2d, seeds_3d,
img_metas)
inds = sample_valid_seeds(masks, self.num_sampled_seed)
batch_size, img_feat_size = img_features.shape[:2]
pts_feat_size = seed_3d_features.shape[1]
inds_img = inds.view(batch_size, 1,
-1).expand(-1, img_feat_size, -1)
img_features = img_features.gather(-1, inds_img)
inds = inds % inds.shape[1]
inds_seed_xyz = inds.view(batch_size, -1, 1).expand(-1, -1, 3)
seeds_3d = seeds_3d.gather(1, inds_seed_xyz)
inds_seed_feats = inds.view(batch_size, 1,
-1).expand(-1, pts_feat_size, -1)
seed_3d_features = seed_3d_features.gather(-1, inds_seed_feats)
seed_indices = seed_indices.gather(1, inds)
img_features = self.img_mlp(img_features)
fused_features = torch.cat([seed_3d_features, img_features], dim=1)
feat_dict = dict(
seed_points=seeds_3d,
seed_features=fused_features,
seed_indices=seed_indices)
bbox_preds = self.pts_bbox_head_joint(feat_dict,
self.test_cfg.pts.sample_mod)
bbox_list = self.pts_bbox_head_joint.get_bboxes(
pts_cat, bbox_preds, img_metas, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return [merged_bboxes]