import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable
def fast_hist(preds, labels, num_classes):
"""Compute the confusion matrix for every batch.
Args:
preds (np.ndarray): Prediction labels of points with shape of
(num_points, ).
labels (np.ndarray): Ground truth labels of points with shape of
(num_points, ).
num_classes (int): number of classes
Returns:
np.ndarray: Calculated confusion matrix.
"""
k = (labels >= 0) & (labels < num_classes)
bin_count = np.bincount(
num_classes * labels[k].astype(int) + preds[k],
minlength=num_classes**2)
return bin_count[:num_classes**2].reshape(num_classes, num_classes)
def per_class_iou(hist):
"""Compute the per class iou.
Args:
hist(np.ndarray): Overall confusion martix
(num_classes, num_classes ).
Returns:
np.ndarray: Calculated per class iou
"""
return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
def get_acc(hist):
"""Compute the overall accuracy.
Args:
hist(np.ndarray): Overall confusion martix
(num_classes, num_classes ).
Returns:
float: Calculated overall acc
"""
return np.diag(hist).sum() / hist.sum()
def get_acc_cls(hist):
"""Compute the class average accuracy.
Args:
hist(np.ndarray): Overall confusion martix
(num_classes, num_classes ).
Returns:
float: Calculated class average acc
"""
return np.nanmean(np.diag(hist) / hist.sum(axis=1))
[docs]def seg_eval(gt_labels, seg_preds, label2cat, ignore_index, logger=None):
"""Semantic Segmentation Evaluation.
Evaluate the result of the Semantic Segmentation.
Args:
gt_labels (list[torch.Tensor]): Ground truth labels.
seg_preds (list[torch.Tensor]): Predictions.
label2cat (dict): Map from label to category name.
ignore_index (int): Index that will be ignored in evaluation.
logger (logging.Logger | str | None): The way to print the mAP
summary. See `mmdet.utils.print_log()` for details. Default: None.
Returns:
dict[str, float]: Dict of results.
"""
assert len(seg_preds) == len(gt_labels)
num_classes = len(label2cat)
hist_list = []
for i in range(len(gt_labels)):
gt_seg = gt_labels[i].clone().numpy().astype(np.int)
pred_seg = seg_preds[i].clone().numpy().astype(np.int)
# filter out ignored points
pred_seg[gt_seg == ignore_index] = -1
gt_seg[gt_seg == ignore_index] = -1
# calculate one instance result
hist_list.append(fast_hist(pred_seg, gt_seg, num_classes))
iou = per_class_iou(sum(hist_list))
miou = np.nanmean(iou)
acc = get_acc(sum(hist_list))
acc_cls = get_acc_cls(sum(hist_list))
header = ['classes']
for i in range(len(label2cat)):
header.append(label2cat[i])
header.extend(['miou', 'acc', 'acc_cls'])
ret_dict = dict()
table_columns = [['results']]
for i in range(len(label2cat)):
ret_dict[label2cat[i]] = float(iou[i])
table_columns.append([f'{iou[i]:.4f}'])
ret_dict['miou'] = float(miou)
ret_dict['acc'] = float(acc)
ret_dict['acc_cls'] = float(acc_cls)
table_columns.append([f'{miou:.4f}'])
table_columns.append([f'{acc:.4f}'])
table_columns.append([f'{acc_cls:.4f}'])
table_data = [header]
table_rows = list(zip(*table_columns))
table_data += table_rows
table = AsciiTable(table_data)
table.inner_footing_row_border = True
print_log('\n' + table.table, logger=logger)
return ret_dict