From 23b7cfc59d09e4736eda115af47161a5a698dc99 Mon Sep 17 00:00:00 2001
From: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Date: Thu, 14 Apr 2022 11:26:02 +0800
Subject: [PATCH] [Fix] Support single channel `pred` for Binary Cross Entropy
 Loss (#1454)

* [Fix] Fix the bug that binary cross entropy loss does not support single-channel input

* increase coverage

* modify implementation

* increase coverage

* add assert

* modify implementation

* simplify the nested condition check

* fix
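
As background, here is a minimal sketch of the two behaviours this patch
addresses (the variable names and shapes below are illustrative only, not
taken verbatim from the diff):

    import torch

    # (1) `_expand_onehot_labels` builds its weight tensor with
    # Tensor.expand(), which returns a view in which several elements
    # share a single memory location. PyTorch rejects in-place writes
    # through such a view, so the old `bin_label_weights *= valid_mask`
    # raised a RuntimeError; an out-of-place multiply allocates a fresh
    # tensor instead.
    label_weights = torch.rand(2, 8, 8)
    valid_mask = torch.ones(2, 21, 8, 8)
    bin_label_weights = label_weights.unsqueeze(1).expand(2, 21, 8, 8)
    # bin_label_weights *= valid_mask  # RuntimeError: elements share memory
    bin_label_weights = bin_label_weights * valid_mask  # new tensor, OK

    # (2) Binary-class segmentation can emit single-channel logits of
    # shape [N, 1, H, W] against an [N, H, W] label. Squeezing dim 1 (and
    # only dim 1, so a batch with N == 1 keeps its batch dim) aligns the
    # two shapes before the sigmoid-based loss is applied.
    pred = torch.rand(3, 1, 10, 10)
    label = torch.randint(0, 2, (3, 10, 10))
    assert pred.squeeze(1).shape == label.shape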
---
 mmseg/models/losses/cross_entropy_loss.py     | 11 ++-
 tests/test_models/test_losses/test_ce_loss.py | 68 +++++++++++++++++++
 2 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py
index 7c2158f8..e01ddea9 100644
--- a/mmseg/models/losses/cross_entropy_loss.py
+++ b/mmseg/models/losses/cross_entropy_loss.py
@@ -81,7 +81,7 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
         bin_label_weights = valid_mask
     else:
         bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
-        bin_label_weights *= valid_mask
+        bin_label_weights = bin_label_weights * valid_mask
 
     return bin_labels, bin_label_weights, valid_mask
 
@@ -115,6 +115,13 @@ def binary_cross_entropy(pred,
     Returns:
         torch.Tensor: The calculated loss
     """
+    if pred.size(1) == 1:
+        # For binary class segmentation, the shape of pred is
+        # [N, 1, H, W] and that of label is [N, H, W].
+        assert label[label != ignore_index].max() <= 1, \
+            'For pred with shape [N, 1, H, W], its label must have at ' \
+            'most 2 classes'
+        pred = pred.squeeze(1)
     if pred.dim() != label.dim():
         assert (pred.dim() == 2 and label.dim() == 1) or (
                 pred.dim() == 4 and label.dim() == 3), \
@@ -128,7 +135,7 @@ def binary_cross_entropy(pred,
         # should mask out the ignored elements
         valid_mask = ((label >= 0) & (label != ignore_index)).float()
         if weight is not None:
-            weight *= valid_mask
+            weight = weight * valid_mask
         else:
             weight = valid_mask
     # average loss over non-ignored and valid elements
diff --git a/tests/test_models/test_losses/test_ce_loss.py b/tests/test_models/test_losses/test_ce_loss.py
index 6fd8d25a..afa57064 100644
--- a/tests/test_models/test_losses/test_ce_loss.py
+++ b/tests/test_models/test_losses/test_ce_loss.py
@@ -85,6 +85,34 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
                 ignore_index=255) / fake_label.numel()
     assert torch.allclose(loss, torch_loss)
 
+    if use_sigmoid:
+        # test the loss with a complicated case for ce/bce: when
+        # avg_non_ignore is False, `avg_factor` is not calculated
+        fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+        fake_label = torch.ones(2, 8, 8).long()
+        fake_label[:, 0, 0] = 255
+        fake_weight = torch.rand(2, 8, 8)
+
+        loss_cls = build_loss(loss_cls_cfg)
+        loss = loss_cls(
+            fake_pred, fake_label, weight=fake_weight, ignore_index=255)
+        fake_label, weight, valid_mask = _expand_onehot_labels(
+            labels=fake_label,
+            label_weights=None,
+            target_shape=fake_pred.shape,
+            ignore_index=255)
+        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            fake_pred,
+            fake_label.float(),
+            reduction='none',
+            weight=fake_weight.unsqueeze(1).expand(fake_pred.shape))
+        if avg_non_ignore:
+            avg_factor = valid_mask.sum().item()
+            torch_loss = (torch_loss * weight).sum() / avg_factor
+        else:
+            torch_loss = (torch_loss * weight).mean()
+        assert torch.allclose(loss, torch_loss)
+
     # test loss with class weights from file
     fake_pred = torch.Tensor([[100, -100]])
     fake_label = torch.Tensor([1]).long()
@@ -223,3 +252,43 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
                 reduction='sum',
                 weight=class_weight) / fake_label.numel()
     assert torch.allclose(loss, torch_loss)
+
+
+@pytest.mark.parametrize('avg_non_ignore', [True, False])
+@pytest.mark.parametrize('with_weight', [True, False])
+def test_binary_class_ce_loss(avg_non_ignore, with_weight):
+    from mmseg.models import build_loss
+
+    fake_pred = torch.rand(3, 1, 10, 10)
+    fake_label = torch.randint(0, 2, (3, 10, 10))
+    fake_weight = torch.rand(3, 10, 10)
+    valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
+    weight = valid_mask
+
+    torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+        fake_pred,
+        fake_label.unsqueeze(1).float(),
+        reduction='none',
+        weight=fake_weight.unsqueeze(1).float() if with_weight else None)
+    if avg_non_ignore:
+        eps = torch.finfo(torch.float32).eps
+        avg_factor = valid_mask.sum().item()
+        torch_loss = (torch_loss * weight.unsqueeze(1)).sum() / (
+            avg_factor + eps)
+    else:
+        torch_loss = (torch_loss * weight.unsqueeze(1)).mean()
+
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=True,
+        loss_weight=1.0,
+        avg_non_ignore=avg_non_ignore,
+        reduction='mean',
+        loss_name='loss_ce')
+    loss_cls = build_loss(loss_cls_cfg)
+    loss = loss_cls(
+        fake_pred,
+        fake_label,
+        weight=fake_weight if with_weight else None,
+        ignore_index=255)
+    assert torch.allclose(loss, torch_loss)
-- 
GitLab