From cac4138f99790401c362847d8da7fb8f5a30cc24 Mon Sep 17 00:00:00 2001 From: sshuair <sshuair@gmail.com> Date: Tue, 22 Sep 2020 01:04:46 +0800 Subject: [PATCH] fix acc and iou compute nan problem (#116) * fix acc and iou compute nan problem * fix acc and iou compute nan problem * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * Update mmseg/core/evaluation/mean_iou.py * Update mean_iou.py * Update mean_iou.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> --- mmseg/core/evaluation/mean_iou.py | 8 ++++++-- tests/test_mean_iou.py | 7 +++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/mmseg/core/evaluation/mean_iou.py b/mmseg/core/evaluation/mean_iou.py index f0b4234f..301cfd04 100644 --- a/mmseg/core/evaluation/mean_iou.py +++ b/mmseg/core/evaluation/mean_iou.py @@ -34,7 +34,7 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index): return area_intersect, area_union, area_pred_label, area_label -def mean_iou(results, gt_seg_maps, num_classes, ignore_index): +def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None): """Calculate Intersection and Union (IoU) Args: @@ -42,6 +42,8 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index): gt_seg_maps (list[ndarray]): list of ground truth segmentation maps num_classes (int): Number of categories ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. Returns: float: Overall accuracy on all images. @@ -66,5 +68,7 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index): all_acc = total_area_intersect.sum() / total_area_label.sum() acc = total_area_intersect / total_area_label iou = total_area_intersect / total_area_union - + if nan_to_num is not None: + return all_acc, np.nan_to_num(acc, nan=nan_to_num), \ + np.nan_to_num(iou, nan=nan_to_num) return all_acc, acc, iou diff --git a/tests/test_mean_iou.py b/tests/test_mean_iou.py index 48a3df8e..74a2b786 100644 --- a/tests/test_mean_iou.py +++ b/tests/test_mean_iou.py @@ -54,3 +54,10 @@ def test_mean_iou(): assert all_acc == all_acc_l assert np.allclose(acc, acc_l) assert np.allclose(iou, iou_l) + + results = np.random.randint(0, 5, size=pred_size) + label = np.random.randint(0, 4, size=pred_size) + all_acc, acc, iou = mean_iou( + results, label, num_classes, ignore_index=255, nan_to_num=-1) + assert acc[-1] == -1 + assert iou[-1] == -1 -- GitLab