From 578d4d0c42422e916caff82f6ecb2912c8c06bf1 Mon Sep 17 00:00:00 2001
From: wangbin <40346249+Dawn-bin@users.noreply.github.com>
Date: Fri, 29 Apr 2022 19:32:12 +0800
Subject: [PATCH] [Fix] Fix the bug in binary_cross_entropy (#1527)

* [Fix] Fix the bug in binary_cross_entropy

 Fix the bug in binary_cross_entropy
'label.max() <= 1' should mask out ignore_index, since the ignore_index often set as 255.

* [Fix] Fix the bug in binary_cross_entropy, add comments

As the ignore_index often set as 255, so the binary class label check should mask out ignore_index.

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* [Fix] Fix the bug in binary_cross_entropy

As the ignore_index often set as 255, so the binary class label check should mask out ignore_index.

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: MeowZheng <meowzheng@outlook.com>
---
 mmseg/models/losses/cross_entropy_loss.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py
index e01ddea9..623fd58d 100644
--- a/mmseg/models/losses/cross_entropy_loss.py
+++ b/mmseg/models/losses/cross_entropy_loss.py
@@ -118,7 +118,10 @@ def binary_cross_entropy(pred,
     if pred.size(1) == 1:
         # For binary class segmentation, the shape of pred is
         # [N, 1, H, W] and that of label is [N, H, W].
-        assert label.max() <= 1, \
+        # As the ignore_index often set as 255, so the
+        # binary class label check should mask out
+        # ignore_index
+        assert label[label != ignore_index].max() <= 1, \
             'For pred with shape [N, 1, H, W], its label must have at ' \
             'most 2 classes'
         pred = pred.squeeze()
-- 
GitLab