From 578d4d0c42422e916caff82f6ecb2912c8c06bf1 Mon Sep 17 00:00:00 2001 From: wangbin <40346249+Dawn-bin@users.noreply.github.com> Date: Fri, 29 Apr 2022 19:32:12 +0800 Subject: [PATCH] [Fix] Fix the bug in binary_cross_entropy (#1527) * [Fix] Fix the bug in binary_cross_entropy Fix the bug in binary_cross_entropy 'label.max() <= 1' should mask out ignore_index, since the ignore_index often set as 255. * [Fix] Fix the bug in binary_cross_entropy, add comments As the ignore_index often set as 255, so the binary class label check should mask out ignore_index. Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> * [Fix] Fix the bug in binary_cross_entropy As the ignore_index often set as 255, so the binary class label check should mask out ignore_index. Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Co-authored-by: MeowZheng <meowzheng@outlook.com> --- mmseg/models/losses/cross_entropy_loss.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py index e01ddea9..623fd58d 100644 --- a/mmseg/models/losses/cross_entropy_loss.py +++ b/mmseg/models/losses/cross_entropy_loss.py @@ -118,7 +118,10 @@ def binary_cross_entropy(pred, if pred.size(1) == 1: # For binary class segmentation, the shape of pred is # [N, 1, H, W] and that of label is [N, H, W]. - assert label.max() <= 1, \ + # As the ignore_index often set as 255, so the + # binary class label check should mask out + # ignore_index + assert label[label != ignore_index].max() <= 1, \ 'For pred with shape [N, 1, H, W], its label must have at ' \ 'most 2 classes' pred = pred.squeeze() -- GitLab