diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py index e01ddea9c7f040b74895378ad58ff6039a553d1c..623fd58dbc7d909962f00d85517720ec732c6ff2 100644 --- a/mmseg/models/losses/cross_entropy_loss.py +++ b/mmseg/models/losses/cross_entropy_loss.py @@ -118,7 +118,10 @@ def binary_cross_entropy(pred, if pred.size(1) == 1: # For binary class segmentation, the shape of pred is # [N, 1, H, W] and that of label is [N, H, W]. - assert label.max() <= 1, \ + # As the ignore_index often set as 255, so the + # binary class label check should mask out + # ignore_index + assert label[label != ignore_index].max() <= 1, \ 'For pred with shape [N, 1, H, W], its label must have at ' \ 'most 2 classes' pred = pred.squeeze()