From 23b7cfc59d09e4736eda115af47161a5a698dc99 Mon Sep 17 00:00:00 2001
From: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Date: Thu, 14 Apr 2022 11:26:02 +0800
Subject: [PATCH] [Fix] Support single channel `pred` for Binary Cross Entropy
 Loss (#1454)

* [Fix] Fix the bug that binary cross entropy loss doesn't support single
  channel input

* increase coverage

* modify implementation

* increase coverage

* add assert

* modify implementation

* simplify condition judgment

* fix
---
 mmseg/models/losses/cross_entropy_loss.py     | 11 ++-
 tests/test_models/test_losses/test_ce_loss.py | 69 +++++++++++++++++++
 2 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py
index 7c2158f8..e01ddea9 100644
--- a/mmseg/models/losses/cross_entropy_loss.py
+++ b/mmseg/models/losses/cross_entropy_loss.py
@@ -81,7 +81,7 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
         bin_label_weights = valid_mask
     else:
         bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
-        bin_label_weights *= valid_mask
+        bin_label_weights = bin_label_weights * valid_mask
 
     return bin_labels, bin_label_weights, valid_mask
 
@@ -115,6 +115,13 @@ def binary_cross_entropy(pred,
     Returns:
         torch.Tensor: The calculated loss
     """
+    if pred.size(1) == 1:
+        # For binary class segmentation, the shape of pred is
+        # [N, 1, H, W] and that of label is [N, H, W]; mask out
+        # ignored pixels before checking the label values.
+        assert label[label != ignore_index].max() <= 1, \
+            'For pred with shape [N, 1, H, W], its label must have at ' \
+            'most 2 classes'
+        pred = pred.squeeze(1)
     if pred.dim() != label.dim():
         assert (pred.dim() == 2 and label.dim() == 1) or (
                 pred.dim() == 4 and label.dim() == 3), \
@@ -128,7 +135,7 @@ def binary_cross_entropy(pred,
     # should mask out the ignored elements
     valid_mask = ((label >= 0) & (label != ignore_index)).float()
     if weight is not None:
-        weight *= valid_mask
+        weight = weight * valid_mask
     else:
         weight = valid_mask
     # average loss over non-ignored and valid elements
diff --git a/tests/test_models/test_losses/test_ce_loss.py b/tests/test_models/test_losses/test_ce_loss.py
index 6fd8d25a..afa57064 100644
--- a/tests/test_models/test_losses/test_ce_loss.py
+++ b/tests/test_models/test_losses/test_ce_loss.py
@@ -85,6 +85,35 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
             ignore_index=255) / fake_label.numel()
         assert torch.allclose(loss, torch_loss)
 
+    if use_sigmoid:
+        # test loss with a complicated case for ce/bce:
+        # when avg_non_ignore is False, `avg_factor` is not calculated
+        fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+        fake_label = torch.ones(2, 8, 8).long()
+        fake_label[:, 0, 0] = 255
+        fake_weight = torch.rand(2, 8, 8)
+
+        loss_cls = build_loss(loss_cls_cfg)
+        loss = loss_cls(
+            fake_pred, fake_label, weight=fake_weight, ignore_index=255)
+        if use_sigmoid:
+            fake_label, weight, valid_mask = _expand_onehot_labels(
+                labels=fake_label,
+                label_weights=None,
+                target_shape=fake_pred.shape,
+                ignore_index=255)
+            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                fake_pred,
+                fake_label.float(),
+                reduction='none',
+                weight=fake_weight.unsqueeze(1).expand(fake_pred.shape))
+            if avg_non_ignore:
+                avg_factor = valid_mask.sum().item()
+                torch_loss = (torch_loss * weight).sum() / avg_factor
+            else:
+                torch_loss = (torch_loss * weight).mean()
+        assert torch.allclose(loss, torch_loss)
+
     # test loss with class weights from file
     fake_pred = torch.Tensor([[100, -100]])
     fake_label = torch.Tensor([1]).long()
@@ -223,3 +252,43 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
             reduction='sum',
             weight=class_weight) / fake_label.numel()
         assert torch.allclose(loss, torch_loss)
+
+
+@pytest.mark.parametrize('avg_non_ignore', [True, False])
+@pytest.mark.parametrize('with_weight', [True, False])
+def test_binary_class_ce_loss(avg_non_ignore, with_weight):
+    from mmseg.models import build_loss
+
+    fake_pred = torch.rand(3, 1, 10, 10)
+    fake_label = torch.randint(0, 2, (3, 10, 10))
+    fake_weight = torch.rand(3, 10, 10)
+    valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
+    weight = valid_mask
+
+    torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+        fake_pred,
+        fake_label.unsqueeze(1).float(),
+        reduction='none',
+        weight=fake_weight.unsqueeze(1).float() if with_weight else None)
+    if avg_non_ignore:
+        eps = torch.finfo(torch.float32).eps
+        avg_factor = valid_mask.sum().item()
+        torch_loss = (torch_loss * weight.unsqueeze(1)).sum() / (
+            avg_factor + eps)
+    else:
+        torch_loss = (torch_loss * weight.unsqueeze(1)).mean()
+
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=True,
+        loss_weight=1.0,
+        avg_non_ignore=avg_non_ignore,
+        reduction='mean',
+        loss_name='loss_ce')
+    loss_cls = build_loss(loss_cls_cfg)
+    loss = loss_cls(
+        fake_pred,
+        fake_label,
+        weight=fake_weight if with_weight else None,
+        ignore_index=255)
+    assert torch.allclose(loss, torch_loss)
--
GitLab
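
Usage sketch (not part of the patch): with this fix applied, `CrossEntropyLoss` configured with `use_sigmoid=True` accepts a single-channel logit map directly. This is a minimal sketch assuming the mmseg version this patch targets and the same `build_loss` registry helper used in the tests above; the tensor shapes and values are illustrative, not taken from the patch.

import torch
from mmseg.models import build_loss

# Binary segmentation: single-channel logits [N, 1, H, W] and an integer
# label map [N, H, W] with values in {0, 1}; 255 marks ignored pixels.
loss_cfg = dict(
    type='CrossEntropyLoss',
    use_sigmoid=True,     # dispatch to the binary_cross_entropy branch
    loss_weight=1.0,
    avg_non_ignore=True)  # average the loss over non-ignored pixels only
loss_fn = build_loss(loss_cfg)

pred = torch.randn(4, 1, 32, 32)          # raw logits, shape [N, 1, H, W]
label = torch.randint(0, 2, (4, 32, 32))  # binary labels, shape [N, H, W]
label[:, 0, 0] = 255                      # a few ignored pixels
loss = loss_fn(pred, label, ignore_index=255)  # scalar loss tensor

Before this patch, the single-channel case required manually expanding `pred` to two channels; the squeeze in `binary_cross_entropy` now handles the [N, 1, H, W] layout directly.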