# Source code for mmdet3d.datasets.utils

import mmcv

# yapf: disable
from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
                                        LoadAnnotations3D,
                                        LoadImageFromFileMono3D,
                                        LoadMultiViewImageFromFiles,
                                        LoadPointsFromFile,
                                        LoadPointsFromMultiSweeps,
                                        MultiScaleFlipAug3D,
                                        PointSegClassMapping)
# yapf: enable
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadImageFromFile


def is_loading_function(transform):
    """Judge whether a transform function is a loading function.

    Note: `MultiScaleFlipAug3D` is a wrapper for multiple pipeline functions,
    so we need to search if its inner transforms contain any loading function.

    Args:
        transform (dict | :obj:`Pipeline`): A transform config or a function.

    Returns:
        bool | None: Whether it is a loading function. None means can't judge.
            When transform is `MultiScaleFlipAug3D`, we return None.
    """
    # TODO: use more elegant way to distinguish loading modules
    loading_functions = (LoadImageFromFile, LoadPointsFromFile,
                         LoadAnnotations3D, LoadMultiViewImageFromFiles,
                         LoadPointsFromMultiSweeps, DefaultFormatBundle3D,
                         Collect3D, LoadImageFromFileMono3D,
                         PointSegClassMapping)
    # Wrappers whose inner transforms must be inspected by the caller.
    wrapper_functions = (MultiScaleFlipAug3D, )

    if isinstance(transform, dict):
        # A config dict: resolve the registered class from the registry.
        registered_cls = PIPELINES.get(transform['type'])
        if registered_cls is None:
            return False
        if registered_cls in loading_functions:
            return True
        return None if registered_cls in wrapper_functions else False

    if callable(transform):
        # An instantiated pipeline function.
        if isinstance(transform, loading_functions):
            return True
        return None if isinstance(transform, wrapper_functions) else False

    return False


def get_loading_pipeline(pipeline):
    """Only keep loading image, points and annotations related configuration.

    Args:
        pipeline (list[dict] | list[:obj:`Pipeline`]): Data pipeline configs
            or list of pipeline functions.

    Returns:
        list[dict] | list[:obj:`Pipeline`]: The new pipeline list with only
            keep loading image, points and annotations related configuration.

    Examples:
        >>> pipelines = [
        ...     dict(type='LoadPointsFromFile',
        ...          coord_type='LIDAR', load_dim=4, use_dim=4),
        ...     dict(type='LoadImageFromFile'),
        ...     dict(type='LoadAnnotations3D',
        ...          with_bbox=True, with_label_3d=True),
        ...     dict(type='Resize',
        ...          img_scale=[(640, 192), (2560, 768)], keep_ratio=True),
        ...     dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
        ...     dict(type='PointsRangeFilter',
        ...          point_cloud_range=point_cloud_range),
        ...     dict(type='ObjectRangeFilter',
        ...          point_cloud_range=point_cloud_range),
        ...     dict(type='PointShuffle'),
        ...     dict(type='Normalize', **img_norm_cfg),
        ...     dict(type='Pad', size_divisor=32),
        ...     dict(type='DefaultFormatBundle3D', class_names=class_names),
        ...     dict(type='Collect3D',
        ...          keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
        ...     ]
        >>> expected_pipelines = [
        ...     dict(type='LoadPointsFromFile',
        ...          coord_type='LIDAR', load_dim=4, use_dim=4),
        ...     dict(type='LoadImageFromFile'),
        ...     dict(type='LoadAnnotations3D',
        ...          with_bbox=True, with_label_3d=True),
        ...     dict(type='DefaultFormatBundle3D', class_names=class_names),
        ...     dict(type='Collect3D',
        ...          keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
        ...     ]
        >>> assert expected_pipelines ==\
        ...        get_loading_pipeline(pipelines)
    """
    loading_pipeline = []
    for transform in pipeline:
        is_loading = is_loading_function(transform)
        if is_loading is None:
            # MultiScaleFlipAug3D is a wrapper: recurse into its inner
            # pipeline to find loading transforms.
            if isinstance(transform, dict):
                inner_pipeline = transform.get('transforms', [])
            else:
                inner_pipeline = transform.transforms.transforms
            loading_pipeline.extend(get_loading_pipeline(inner_pipeline))
        elif is_loading:
            loading_pipeline.append(transform)
    assert len(loading_pipeline) > 0, \
        'The data pipeline in your config file must include ' \
        'loading step.'
    return loading_pipeline
def extract_result_dict(results, key):
    """Extract and return the data corresponding to key in result dict.

    ``results`` is a dict output from `pipeline(input_dict)`, which is the
    loaded data from ``Dataset`` class. The data terms inside may be wrapped
    in list, tuple and DataContainer, so this function essentially extracts
    data from these wrappers.

    Args:
        results (dict): Data loaded using pipeline.
        key (str): Key of the desired data.

    Returns:
        np.ndarray | torch.Tensor | None: Data term.
    """
    # Idiomatic membership test instead of `results.keys()`.
    if key not in results:
        return None
    # results[key] may be data or list[data] or tuple[data];
    # data may additionally be wrapped inside a DataContainer.
    data = results[key]
    if isinstance(data, (list, tuple)):
        data = data[0]
    if isinstance(data, mmcv.parallel.DataContainer):
        data = data._data
    return data