Shortcuts

Source code for mmdet3d.core.evaluation.instance_seg_eval

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable

from .scannet_utils.evaluate_semantic_instance import scannet_eval


def aggregate_predictions(masks, labels, scores, valid_class_ids):
    """Maps predictions to ScanNet evaluator format.

    Args:
        masks (list[torch.Tensor]): Per scene predicted instance masks.
        labels (list[torch.Tensor]): Per scene predicted instance labels.
        scores (list[torch.Tensor]): Per scene predicted instance scores.
        valid_class_ids (tuple[int]): Ids of valid categories.

    Returns:
        list[dict]: Per scene aggregated predictions.
    """
    infos = []
    for id, (mask, label, score) in enumerate(zip(masks, labels, scores)):
        mask = mask.clone().numpy()
        label = label.clone().numpy()
        score = score.clone().numpy()
        info = dict()
        n_instances = mask.max() + 1
        for i in range(n_instances):
            # match pred_instance['filename'] from assign_instances_for_scan
            file_name = f'{id}_{i}'
            info[file_name] = dict()
            info[file_name]['mask'] = (mask == i).astype(np.int)
            info[file_name]['label_id'] = valid_class_ids[label[i]]
            info[file_name]['conf'] = score[i]
        infos.append(info)
    return infos


def rename_gt(gt_semantic_masks, gt_instance_masks, valid_class_ids):
    """Maps gt instance and semantic masks to instance masks for ScanNet
    evaluator.

    Args:
        gt_semantic_masks (list[torch.Tensor]): Per scene gt semantic masks.
        gt_instance_masks (list[torch.Tensor]): Per scene gt instance masks.
        valid_class_ids (tuple[int]): Ids of valid categories.

    Returns:
        list[np.array]: Per scene instance masks.
    """
    renamed_instance_masks = []
    for semantic_mask, instance_mask in zip(gt_semantic_masks,
                                            gt_instance_masks):
        semantic_mask = semantic_mask.clone().numpy()
        instance_mask = instance_mask.clone().numpy()
        unique = np.unique(instance_mask)
        assert len(unique) < 1000
        for i in unique:
            semantic_instance = semantic_mask[instance_mask == i]
            semantic_unique = np.unique(semantic_instance)
            assert len(semantic_unique) == 1
            if semantic_unique[0] < len(valid_class_ids):
                instance_mask[
                    instance_mask ==
                    i] = 1000 * valid_class_ids[semantic_unique[0]] + i
        renamed_instance_masks.append(instance_mask)
    return renamed_instance_masks


[docs]def instance_seg_eval(gt_semantic_masks, gt_instance_masks, pred_instance_masks, pred_instance_labels, pred_instance_scores, valid_class_ids, class_labels, options=None, logger=None): """Instance Segmentation Evaluation. Evaluate the result of the instance segmentation. Args: gt_semantic_masks (list[torch.Tensor]): Ground truth semantic masks. gt_instance_masks (list[torch.Tensor]): Ground truth instance masks. pred_instance_masks (list[torch.Tensor]): Predicted instance masks. pred_instance_labels (list[torch.Tensor]): Predicted instance labels. pred_instance_scores (list[torch.Tensor]): Predicted instance labels. valid_class_ids (tuple[int]): Ids of valid categories. class_labels (tuple[str]): Names of valid categories. options (dict, optional): Additional options. Keys may contain: `overlaps`, `min_region_sizes`, `distance_threshes`, `distance_confs`. Default: None. logger (logging.Logger | str, optional): The way to print the mAP summary. See `mmdet.utils.print_log()` for details. Default: None. Returns: dict[str, float]: Dict of results. """ assert len(valid_class_ids) == len(class_labels) id_to_label = { valid_class_ids[i]: class_labels[i] for i in range(len(valid_class_ids)) } preds = aggregate_predictions( masks=pred_instance_masks, labels=pred_instance_labels, scores=pred_instance_scores, valid_class_ids=valid_class_ids) gts = rename_gt(gt_semantic_masks, gt_instance_masks, valid_class_ids) metrics = scannet_eval( preds=preds, gts=gts, options=options, valid_class_ids=valid_class_ids, class_labels=class_labels, id_to_label=id_to_label) header = ['classes', 'AP_0.25', 'AP_0.50', 'AP'] rows = [] for label, data in metrics['classes'].items(): aps = [data['ap25%'], data['ap50%'], data['ap']] rows.append([label] + [f'{ap:.4f}' for ap in aps]) aps = metrics['all_ap_25%'], metrics['all_ap_50%'], metrics['all_ap'] footer = ['Overall'] + [f'{ap:.4f}' for ap in aps] table = AsciiTable([header] + rows + [footer]) table.inner_footing_row_border = True print_log('\n' + table.table, logger=logger) return metrics