From b5d8c7dc6c0141ac9908761dce93811a8f46e8cf Mon Sep 17 00:00:00 2001
From: Wang Xinjiang <swanxinjiang@gmail.com>
Date: Thu, 21 Apr 2022 10:18:37 +0800
Subject: [PATCH] Fix potential bugs in accuracy.py (#1496)

---
 mmseg/models/losses/accuracy.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mmseg/models/losses/accuracy.py b/mmseg/models/losses/accuracy.py
index 28d55c4e..1d9e2d77 100644
--- a/mmseg/models/losses/accuracy.py
+++ b/mmseg/models/losses/accuracy.py
@@ -45,14 +45,18 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
     if thresh is not None:
         # Only prediction values larger than thresh are counted as correct
         correct = correct & (pred_value > thresh).t()
-    correct = correct[:, target != ignore_index]
+    if ignore_index is not None:
+        correct = correct[:, target != ignore_index]
     res = []
     eps = torch.finfo(torch.float32).eps
     for k in topk:
         # Avoid causing ZeroDivisionError when all pixels
         # of an image are ignored
         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
-        total_num = target[target != ignore_index].numel() + eps
+        if ignore_index is not None:
+            total_num = target[target != ignore_index].numel() + eps
+        else:
+            total_num = target.numel() + eps
         res.append(correct_k.mul_(100.0 / total_num))
     return res[0] if return_single else res
 
-- 
GitLab