Source code for mmdet3d.core.evaluation.seg_eval

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable


def fast_hist(preds, labels, num_classes):
    """Compute the confusion matrix for every batch.

    Args:
        preds (np.ndarray): Prediction labels of points with shape of
            (num_points, ).
        labels (np.ndarray): Ground truth labels of points with shape of
            (num_points, ).
        num_classes (int): Number of classes.

    Returns:
        np.ndarray: Calculated confusion matrix.
    """

    k = (labels >= 0) & (labels < num_classes)
    bin_count = np.bincount(
        num_classes * labels[k].astype(int) + preds[k],
        minlength=num_classes**2)
    return bin_count[:num_classes**2].reshape(num_classes, num_classes)

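As a quick, illustrative check of the indexing convention (rows are ground-truth classes, columns are predictions), the following sketch runs fast_hist on made-up arrays; it assumes the function and the numpy import above are in scope:

preds = np.array([0, 1, 1, 2, 2])
labels = np.array([0, 1, 2, 2, 0])
hist = fast_hist(preds, labels, num_classes=3)
# hist[i, j] counts points whose ground-truth class is i and prediction is j:
# array([[1, 0, 1],
#        [0, 1, 0],
#        [0, 1, 1]])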

def per_class_iou(hist):
    """Compute the per class iou.

    Args:
        hist(np.ndarray):  Overall confusion martix
        (num_classes, num_classes ).

    Returns:
        np.ndarray: Calculated per class iou
    """

    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))

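Continuing the toy example above, per-class IoU is diag(hist) / (row sum + column sum - diag), i.e. TP / (TP + FP + FN). A hedged sketch of the resulting values:

iou = per_class_iou(hist)
# diag = [1, 1, 1]; row_sum + col_sum - diag = [2+1-1, 1+2-1, 2+2-1] = [2, 2, 3]
# iou -> array([0.5, 0.5, 0.33333333])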

def get_acc(hist):
    """Compute the overall accuracy.

    Args:
        hist (np.ndarray): Overall confusion matrix with shape of
            (num_classes, num_classes).

    Returns:
        float: Calculated overall accuracy.
    """

    return np.diag(hist).sum() / hist.sum()


def get_acc_cls(hist):
    """Compute the class average accuracy.

    Args:
        hist (np.ndarray): Overall confusion matrix with shape of
            (num_classes, num_classes).

    Returns:
        float: Calculated class average accuracy.
    """

    return np.nanmean(np.diag(hist) / hist.sum(axis=1))

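With the same toy matrix, overall accuracy is trace(hist) / hist.sum(), while class-average accuracy averages per-class recall (np.nanmean skips classes that never appear in the ground truth). Illustrative values:

acc = get_acc(hist)          # 3 / 5 = 0.6
acc_cls = get_acc_cls(hist)  # mean of [1/2, 1/1, 1/2] = 0.6666...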

def seg_eval(gt_labels, seg_preds, label2cat, ignore_index, logger=None):
    """Semantic Segmentation Evaluation.

    Evaluate the result of the Semantic Segmentation.

    Args:
        gt_labels (list[torch.Tensor]): Ground truth labels.
        seg_preds (list[torch.Tensor]): Predictions.
        label2cat (dict): Map from label to category name.
        ignore_index (int): Index that will be ignored in evaluation.
        logger (logging.Logger | str, optional): The way to print the mAP
            summary. See `mmdet.utils.print_log()` for details. Default: None.

    Returns:
        dict[str, float]: Dict of results.
    """
    assert len(seg_preds) == len(gt_labels)
    num_classes = len(label2cat)

    hist_list = []
    for i in range(len(gt_labels)):
        gt_seg = gt_labels[i].clone().numpy().astype(np.int)
        pred_seg = seg_preds[i].clone().numpy().astype(np.int)

        # filter out ignored points
        pred_seg[gt_seg == ignore_index] = -1
        gt_seg[gt_seg == ignore_index] = -1

        # calculate one instance result
        hist_list.append(fast_hist(pred_seg, gt_seg, num_classes))

    iou = per_class_iou(sum(hist_list))
    miou = np.nanmean(iou)
    acc = get_acc(sum(hist_list))
    acc_cls = get_acc_cls(sum(hist_list))

    header = ['classes']
    for i in range(len(label2cat)):
        header.append(label2cat[i])
    header.extend(['miou', 'acc', 'acc_cls'])

    ret_dict = dict()
    table_columns = [['results']]
    for i in range(len(label2cat)):
        ret_dict[label2cat[i]] = float(iou[i])
        table_columns.append([f'{iou[i]:.4f}'])
    ret_dict['miou'] = float(miou)
    ret_dict['acc'] = float(acc)
    ret_dict['acc_cls'] = float(acc_cls)
    table_columns.append([f'{miou:.4f}'])
    table_columns.append([f'{acc:.4f}'])
    table_columns.append([f'{acc_cls:.4f}'])

    table_data = [header]
    table_rows = list(zip(*table_columns))
    table_data += table_rows
    table = AsciiTable(table_data)
    table.inner_footing_row_border = True
    print_log('\n' + table.table, logger=logger)

    return ret_dict
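A minimal end-to-end sketch of calling seg_eval, assuming PyTorch is available; the category names and the ignore index below are placeholders chosen for illustration, not values mandated by the API:

import torch

gt_labels = [torch.tensor([0, 1, 2, 2, 255])]  # 255 marks ignored points here
seg_preds = [torch.tensor([0, 1, 1, 2, 0])]
label2cat = {0: 'floor', 1: 'wall', 2: 'chair'}

results = seg_eval(gt_labels, seg_preds, label2cat, ignore_index=255)
# results is a dict such as
# {'floor': ..., 'wall': ..., 'chair': ..., 'miou': ..., 'acc': ..., 'acc_cls': ...}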