
Source code for mmdet3d.models.detectors.base

# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp

import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16

from mmdet3d.core import Box3DMode, Coord3DMode, show_result
from mmdet.models.detectors import BaseDetector


class Base3DDetector(BaseDetector):
    """Base class for detectors."""
    def forward_test(self, points, img_metas, img=None, **kwargs):
        """
        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and the inner torch.Tensor should have a shape
                NxC, which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            img (list[torch.Tensor], optional): the outer list indicates
                test-time augmentations and the inner torch.Tensor should
                have a shape NxCxHxW, which contains all images in the batch.
                Defaults to None.
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(points), len(img_metas)))

        if num_augs == 1:
            img = [img] if img is None else img
            return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
        else:
            return self.aug_test(points, img_metas, img, **kwargs)
    @auto_fp16(apply_to=('img', 'points'))
    def forward(self, return_loss=True, **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note this setting will change the expected inputs. When
        `return_loss=True`, img and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when `return_loss=False`, img and
        img_metas should be double nested (i.e. list[torch.Tensor],
        list[list[dict]]), with the outer list indicating test time
        augmentations.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)
    def show_results(self, data, result, out_dir, show=False, score_thr=None):
        """Results visualization.

        Args:
            data (list[dict]): Input points and the information of the sample.
            result (list[dict]): Prediction results.
            out_dir (str): Output directory of visualization result.
            show (bool, optional): Determines whether you are going to show
                result by open3d. Defaults to False.
            score_thr (float, optional): Score threshold of bounding boxes.
                Defaults to None.
        """
        for batch_id in range(len(result)):
            if isinstance(data['points'][0], DC):
                points = data['points'][0]._data[0][batch_id].numpy()
            elif mmcv.is_list_of(data['points'][0], torch.Tensor):
                points = data['points'][0][batch_id]
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['points'][0])} "
                    f'for visualization!')

            if isinstance(data['img_metas'][0], DC):
                pts_filename = data['img_metas'][0]._data[0][batch_id][
                    'pts_filename']
                box_mode_3d = data['img_metas'][0]._data[0][batch_id][
                    'box_mode_3d']
            elif mmcv.is_list_of(data['img_metas'][0], dict):
                pts_filename = data['img_metas'][0][batch_id]['pts_filename']
                box_mode_3d = data['img_metas'][0][batch_id]['box_mode_3d']
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['img_metas'][0])} "
                    f'for visualization!')

            file_name = osp.split(pts_filename)[-1].split('.')[0]

            assert out_dir is not None, 'Expect out_dir, got none.'

            pred_bboxes = result[batch_id]['boxes_3d']
            pred_labels = result[batch_id]['labels_3d']

            if score_thr is not None:
                mask = result[batch_id]['scores_3d'] > score_thr
                pred_bboxes = pred_bboxes[mask]
                pred_labels = pred_labels[mask]

            # for now we convert points and bbox into depth mode
            if (box_mode_3d == Box3DMode.CAM) or (box_mode_3d ==
                                                  Box3DMode.LIDAR):
                points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR,
                                                   Coord3DMode.DEPTH)
                pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
                                                Box3DMode.DEPTH)
            elif box_mode_3d != Box3DMode.DEPTH:
                raise ValueError(
                    f'Unsupported box_mode_3d {box_mode_3d} for conversion!')

            pred_bboxes = pred_bboxes.tensor.cpu().numpy()
            show_result(
                points,
                None,
                pred_bboxes,
                out_dir,
                file_name,
                show=show,
                pred_labels=pred_labels)
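
The forward method above switches between forward_train and forward_test based on return_loss, and the two modes expect differently nested inputs. The snippet below is a minimal sketch of that contract, not part of mmdet3d: ToyDetector, its dummy outputs, and the tensor shapes are invented for illustration; a real detector is built from a config via the model registry and fed by the data pipeline, which also determines the exact inner structure of points and img_metas.

# Hypothetical, minimal subclass used only to illustrate the input nesting
# expected by Base3DDetector.forward. Nothing below exists in mmdet3d itself.
import torch

from mmdet3d.models.detectors import Base3DDetector


class ToyDetector(Base3DDetector):
    """Toy detector that returns dummy losses and empty predictions."""

    def extract_feat(self, imgs):
        # No real feature extraction in this sketch.
        return imgs

    def forward_train(self, points, img_metas, **kwargs):
        # A real detector returns a dict of named loss tensors.
        return dict(loss_dummy=points[0].sum() * 0)

    def simple_test(self, points, img_metas, img=None, **kwargs):
        # One result dict per sample in the batch.
        return [dict() for _ in img_metas]

    def aug_test(self, points, img_metas, imgs=None, **kwargs):
        return [dict()]


model = ToyDetector()
pts = [torch.rand(1000, 4)]  # one NxC point cloud per sample (made-up shape)
metas = [dict()]  # one meta dict per sample; real pipelines fill these in

# Training: single-nested inputs (one list entry per sample), loss dict out.
losses = model(return_loss=True, points=pts, img_metas=metas)

# Testing: an extra outer list, one entry per test-time augmentation,
# returns one result dict per sample.
results = model(return_loss=False, points=[pts], img_metas=[metas])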