From b5d8c7dc6c0141ac9908761dce93811a8f46e8cf Mon Sep 17 00:00:00 2001 From: Wang Xinjiang <swanxinjiang@gmail.com> Date: Thu, 21 Apr 2022 10:18:37 +0800 Subject: [PATCH] Fix potential bugs in accuracy.py (#1496) --- mmseg/models/losses/accuracy.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mmseg/models/losses/accuracy.py b/mmseg/models/losses/accuracy.py index 28d55c4e..1d9e2d77 100644 --- a/mmseg/models/losses/accuracy.py +++ b/mmseg/models/losses/accuracy.py @@ -45,14 +45,18 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): if thresh is not None: # Only prediction values larger than thresh are counted as correct correct = correct & (pred_value > thresh).t() - correct = correct[:, target != ignore_index] + if ignore_index is not None: + correct = correct[:, target != ignore_index] res = [] eps = torch.finfo(torch.float32).eps for k in topk: # Avoid causing ZeroDivisionError when all pixels # of an image are ignored correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps - total_num = target[target != ignore_index].numel() + eps + if ignore_index is not None: + total_num = target[target != ignore_index].numel() + eps + else: + total_num = target.numel() + eps res.append(correct_k.mul_(100.0 / total_num)) return res[0] if return_single else res -- GitLab