# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import tempfile
import warnings
from os import path as osp
from torch.utils.data import Dataset
from mmdet.datasets import DATASETS
from ..core.bbox import get_box_type
from .pipelines import Compose
from .utils import extract_result_dict, get_loading_pipeline
@DATASETS.register_module()
class Custom3DDataset(Dataset):
    """Customized 3D dataset.

    This is the base dataset of SUNRGB-D, ScanNet, nuScenes, and KITTI
    dataset.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        classes (tuple[str], optional): Classes used in the dataset.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        box_type_3d (str, optional): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'LiDAR'. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
    """

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 modality=None,
                 box_type_3d='LiDAR',
                 filter_empty_gt=True,
                 test_mode=False):
        super().__init__()
        self.data_root = data_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality
        self.filter_empty_gt = filter_empty_gt
        self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)

        self.CLASSES = self.get_classes(classes)
        self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
        self.data_infos = self.load_annotations(self.ann_file)

        # NOTE: ``self.pipeline`` is intentionally left unset when no
        # pipeline is given; ``_get_pipeline`` checks for that with hasattr.
        if pipeline is not None:
            self.pipeline = Compose(pipeline)

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()

    def load_annotations(self, ann_file):
        """Load annotations from ann_file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: List of annotations.
        """
        return mmcv.load(ann_file)

    def get_data_info(self, index):
        """Get data info according to the given index.

        Args:
            index (int): Index of the sample data to get.

        Returns:
            dict | None: Data information that will be passed to the data
            preprocessing pipelines, or None when the dataset is not in
            test mode, ``filter_empty_gt`` is set and no label of the sample
            differs from -1. The dict includes the following keys:

            - sample_idx (str): Sample index.
            - pts_filename (str): Filename of point clouds.
            - file_name (str): Filename of point clouds.
            - ann_info (dict): Annotation info (only outside test mode).
        """
        info = self.data_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        pts_filename = osp.join(self.data_root, info['pts_path'])

        input_dict = dict(
            pts_filename=pts_filename,
            sample_idx=sample_idx,
            file_name=pts_filename)

        if not self.test_mode:
            # ``get_ann_info`` is not defined in this class; it is expected
            # to be provided by the concrete dataset subclass.
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos
            # Use logical ``not`` rather than bitwise ``~``: both agree on
            # NumPy/torch boolean scalars, but ``~`` applied to a plain
            # Python bool yields -1/-2, which is always truthy.
            if self.filter_empty_gt and \
                    not (annos['gt_labels_3d'] != -1).any():
                return None
        return input_dict

    def pre_pipeline(self, results):
        """Initialization before data preparation.

        The following keys are written into ``results`` in place.

        Args:
            results (dict): Dict before data preprocessing.

                - img_fields (list): Image fields.
                - bbox3d_fields (list): 3D bounding boxes fields.
                - pts_mask_fields (list): Mask fields of points.
                - pts_seg_fields (list): Mask fields of point segments.
                - bbox_fields (list): Fields of bounding boxes.
                - mask_fields (list): Fields of masks.
                - seg_fields (list): Segment fields.
                - box_type_3d (str): 3D box type.
                - box_mode_3d (str): 3D box mode.
        """
        results['img_fields'] = []
        results['bbox3d_fields'] = []
        results['pts_mask_fields'] = []
        results['pts_seg_fields'] = []
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []
        results['box_type_3d'] = self.box_type_3d
        results['box_mode_3d'] = self.box_mode_3d

    def prepare_train_data(self, index):
        """Training data preparation.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict | None: Training data dict of the corresponding index, or
            None if the sample was filtered out (no data info, or with
            ``filter_empty_gt`` set, no pipeline output / no label
            different from -1).
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        # Logical ``not`` instead of bitwise ``~``; see ``get_data_info``.
        if self.filter_empty_gt and (
                example is None or
                not (example['gt_labels_3d']._data != -1).any()):
            return None
        return example

    def prepare_test_data(self, index):
        """Prepare data for testing.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Testing data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        return example

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Returns:
            list[str]: A list of class names.

        Raises:
            ValueError: If ``classes`` is not None, a str, a tuple or a list.
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

    def evaluate(self,
                 results,
                 metric=None,
                 iou_thr=(0.25, 0.5),
                 logger=None,
                 show=False,
                 out_dir=None,
                 pipeline=None):
        """Evaluate.

        Evaluation in indoor protocol.

        Args:
            results (list[dict]): List of results.
            metric (str | list[str]): Metrics to be evaluated. Not used by
                this implementation.
            iou_thr (list[float]): AP IoU thresholds.
            logger: Logger forwarded to ``indoor_eval``. Default: None.
            show (bool): Whether to visualize.
                Default: False.
            out_dir (str): Path to save the visualization results.
                Default: None.
            pipeline (list[dict], optional): raw data loading for showing.
                Default: None.

        Returns:
            dict: Evaluation results.
        """
        from mmdet3d.core.evaluation import indoor_eval
        assert isinstance(
            results, list), f'Expect results to be list, got {type(results)}.'
        assert len(results) > 0, 'Expect length of results > 0.'
        assert len(results) == len(self.data_infos)
        assert isinstance(
            results[0], dict
        ), f'Expect elements in results to be dict, got {type(results[0])}.'
        gt_annos = [info['annos'] for info in self.data_infos]
        # map label index -> class name
        label2cat = dict(enumerate(self.CLASSES))
        ret_dict = indoor_eval(
            gt_annos,
            results,
            iou_thr,
            label2cat,
            logger=logger,
            box_type_3d=self.box_type_3d,
            box_mode_3d=self.box_mode_3d)
        if show:
            # ``show`` is not defined in this class; it is expected to be
            # provided by the concrete dataset subclass.
            self.show(results, out_dir, pipeline=pipeline)

        return ret_dict

    def _build_default_pipeline(self):
        """Build the default pipeline for this dataset.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('_build_default_pipeline is not implemented '
                                  f'for dataset {self.__class__.__name__}')

    def _get_pipeline(self, pipeline):
        """Get data loading pipeline in self.show/evaluate function.

        Args:
            pipeline (list[dict] | None): Input pipeline. If None is given,
                get from self.pipeline.

        Returns:
            The composed input pipeline; or, when ``pipeline`` is None, the
            loading part of ``self.pipeline`` composed, falling back to
            ``_build_default_pipeline()`` if ``self.pipeline`` is missing
            or None.
        """
        if pipeline is None:
            if not hasattr(self, 'pipeline') or self.pipeline is None:
                warnings.warn(
                    'Use default pipeline for data loading, this may cause '
                    'errors when data is on ceph')
                return self._build_default_pipeline()
            loading_pipeline = get_loading_pipeline(self.pipeline.transforms)
            return Compose(loading_pipeline)
        return Compose(pipeline)

    def _extract_data(self, index, pipeline, key, load_annos=False):
        """Load data using input pipeline and extract data according to key.

        Args:
            index (int): Index for accessing the target data.
            pipeline (:obj:`Compose`): Composed data loading pipeline.
            key (str | list[str]): One single or a list of data key.
            load_annos (bool): Whether to load data annotations.
                If True, ``self.test_mode`` is temporarily set to False
                while loading and restored afterwards.

        Returns:
            np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor]:
                A single or a list of loaded data.
        """
        assert pipeline is not None, 'data loading pipeline is not provided'
        # when we want to load ground-truth via pipeline (e.g. bbox, seg mask)
        # we need to set self.test_mode as False so that we have 'annos'
        if load_annos:
            original_test_mode = self.test_mode
            self.test_mode = False
        try:
            input_dict = self.get_data_info(index)
            self.pre_pipeline(input_dict)
            example = pipeline(input_dict)

            # extract data items according to keys
            if isinstance(key, str):
                data = extract_result_dict(example, key)
            else:
                data = [extract_result_dict(example, k) for k in key]
        finally:
            # Always restore the mode, even if loading raised; otherwise the
            # dataset would silently stay in train mode.
            if load_annos:
                self.test_mode = original_test_mode

        return data

    def __len__(self):
        """Return the length of data infos.

        Returns:
            int: Length of data infos.
        """
        return len(self.data_infos)

    def _rand_another(self, idx):
        """Randomly get another item with the same flag.

        Args:
            idx (int): Index whose group flag is used to build the pool.

        Returns:
            int: Another index of item with the same flag.
        """
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Get item from infos according to the given index.

        In train mode, samples for which ``prepare_train_data`` returns None
        are replaced by a randomly chosen sample of the same group, retrying
        until one yields data.

        Returns:
            dict: Data dictionary of the corresponding index.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        while True:
            data = self.prepare_train_data(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0. In 3D datasets, they are all the same, thus are all
        zeros.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)