From cac4138f99790401c362847d8da7fb8f5a30cc24 Mon Sep 17 00:00:00 2001
From: sshuair <sshuair@gmail.com>
Date: Tue, 22 Sep 2020 01:04:46 +0800
Subject: [PATCH] fix acc and iou compute nan problem (#116)

* fix acc and iou compute nan problem

* fix acc and iou compute nan problem

* add nan_to_num args for mean_iou

* add nan_to_num args for mean_iou

* add nan_to_num args for mean_iou

* add nan_to_num args for mean_iou

* add nan_to_num args for mean_iou

* Update mmseg/core/evaluation/mean_iou.py

* Update mean_iou.py

* Update mean_iou.py

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
---
 mmseg/core/evaluation/mean_iou.py | 8 ++++++--
 tests/test_mean_iou.py            | 7 +++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/mmseg/core/evaluation/mean_iou.py b/mmseg/core/evaluation/mean_iou.py
index f0b4234f..301cfd04 100644
--- a/mmseg/core/evaluation/mean_iou.py
+++ b/mmseg/core/evaluation/mean_iou.py
@@ -34,7 +34,7 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index):
     return area_intersect, area_union, area_pred_label, area_label
 
 
-def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
+def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
     """Calculate Intersection and Union (IoU)
 
     Args:
@@ -42,6 +42,8 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
         gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
         num_classes (int): Number of categories
         ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
 
      Returns:
          float: Overall accuracy on all images.
@@ -66,5 +68,7 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
     all_acc = total_area_intersect.sum() / total_area_label.sum()
     acc = total_area_intersect / total_area_label
     iou = total_area_intersect / total_area_union
-
+    if nan_to_num is not None:
+        return all_acc, np.nan_to_num(acc, nan=nan_to_num), \
+            np.nan_to_num(iou, nan=nan_to_num)
     return all_acc, acc, iou
diff --git a/tests/test_mean_iou.py b/tests/test_mean_iou.py
index 48a3df8e..74a2b786 100644
--- a/tests/test_mean_iou.py
+++ b/tests/test_mean_iou.py
@@ -54,3 +54,10 @@ def test_mean_iou():
     assert all_acc == all_acc_l
     assert np.allclose(acc, acc_l)
     assert np.allclose(iou, iou_l)
+
+    results = np.random.randint(0, 5, size=pred_size)
+    label = np.random.randint(0, 4, size=pred_size)
+    all_acc, acc, iou = mean_iou(
+        results, label, num_classes, ignore_index=255, nan_to_num=-1)
+    assert acc[-1] == -1
+    assert iou[-1] == -1
-- 
GitLab