Skip to content
Snippets Groups Projects
Commit b2abe157 authored by zhengmiao's avatar zhengmiao
Browse files

Merge branch 'zhengmiao/tests_bp' into 'refactor_dev'

[Refactory] Clean UTs

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!2
parents 24cc27dd 35c695bb
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmseg.models.losses import Accuracy, reduce_loss, weight_reduce_loss
def test_weight_reduce_loss():
loss = torch.rand(1, 3, 4, 4)
weight = torch.zeros(1, 3, 4, 4)
weight[:, :, :2, :2] = 1
# test reduce_loss()
reduced = reduce_loss(loss, 'none')
assert reduced is loss
reduced = reduce_loss(loss, 'mean')
np.testing.assert_almost_equal(reduced.numpy(), loss.mean())
reduced = reduce_loss(loss, 'sum')
np.testing.assert_almost_equal(reduced.numpy(), loss.sum())
# test weight_reduce_loss()
reduced = weight_reduce_loss(loss, weight=None, reduction='none')
assert reduced is loss
reduced = weight_reduce_loss(loss, weight=weight, reduction='mean')
target = (loss * weight).mean()
np.testing.assert_almost_equal(reduced.numpy(), target)
reduced = weight_reduce_loss(loss, weight=weight, reduction='sum')
np.testing.assert_almost_equal(reduced.numpy(), (loss * weight).sum())
with pytest.raises(AssertionError):
weight_wrong = weight[0, 0, ...]
weight_reduce_loss(loss, weight=weight_wrong, reduction='mean')
with pytest.raises(AssertionError):
weight_wrong = weight[:, 0:2, ...]
weight_reduce_loss(loss, weight=weight_wrong, reduction='mean')
def test_accuracy():
# test for empty pred
pred = torch.empty(0, 4)
label = torch.empty(0)
accuracy = Accuracy(topk=1)
acc = accuracy(pred, label)
assert acc.item() == 0
pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
[0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
[0.0, 0.0, 0.99, 0]])
# test for ignore_index
true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
accuracy = Accuracy(topk=1, ignore_index=None)
acc = accuracy(pred, true_label)
assert torch.allclose(acc, torch.tensor(100.0))
# test for ignore_index with a wrong prediction of that index
true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
accuracy = Accuracy(topk=1, ignore_index=1)
acc = accuracy(pred, true_label)
assert torch.allclose(acc, torch.tensor(100.0))
# test for ignore_index 1 with a wrong prediction of other index
true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
accuracy = Accuracy(topk=1, ignore_index=1)
acc = accuracy(pred, true_label)
assert torch.allclose(acc, torch.tensor(75.0))
# test for ignore_index 4 with a wrong prediction of other index
true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
accuracy = Accuracy(topk=1, ignore_index=4)
acc = accuracy(pred, true_label)
assert torch.allclose(acc, torch.tensor(80.0))
# test for ignoring all the pixels
true_label = torch.Tensor([2, 2, 2, 2, 2]).long()
accuracy = Accuracy(topk=1, ignore_index=2)
acc = accuracy(pred, true_label)
assert torch.allclose(acc, torch.tensor(100.0))
# test for top1
true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
accuracy = Accuracy(topk=1)
acc = accuracy(pred, true_label)
assert torch.allclose(acc, torch.tensor(100.0))
# test for top1 with score thresh=0.8
true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
accuracy = Accuracy(topk=1, thresh=0.8)
acc = accuracy(pred, true_label)
assert torch.allclose(acc, torch.tensor(40.0))
# test for top2
accuracy = Accuracy(topk=2)
label = torch.Tensor([3, 2, 0, 0, 2]).long()
acc = accuracy(pred, label)
assert torch.allclose(acc, torch.tensor(100.0))
# test for both top1 and top2
accuracy = Accuracy(topk=(1, 2))
true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
acc = accuracy(pred, true_label)
for a in acc:
assert torch.allclose(a, torch.tensor(100.0))
# topk is larger than pred class number
with pytest.raises(AssertionError):
accuracy = Accuracy(topk=5)
accuracy(pred, true_label)
# wrong topk type
with pytest.raises(AssertionError):
accuracy = Accuracy(topk='wrong type')
accuracy(pred, true_label)
# label size is larger than required
with pytest.raises(AssertionError):
label = torch.Tensor([2, 3, 0, 1, 2, 0]).long() # size mismatch
accuracy = Accuracy()
accuracy(pred, label)
# wrong pred dimension
with pytest.raises(AssertionError):
accuracy = Accuracy()
accuracy(pred[:, :, None], true_label)
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv import ConfigDict
from mmseg.models import build_segmentor
from .utils import _segmentor_forward_train_test
def test_cascade_encoder_decoder():
# test 1 decode head, w.o. aux head
cfg = ConfigDict(
type='CascadeEncoderDecoder',
num_stages=2,
backbone=dict(type='ExampleBackbone'),
decode_head=[
dict(type='ExampleDecodeHead'),
dict(type='ExampleCascadeDecodeHead')
])
cfg.test_cfg = ConfigDict(mode='whole')
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
# test slide mode
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
# test 1 decode head, 1 aux head
cfg = ConfigDict(
type='CascadeEncoderDecoder',
num_stages=2,
backbone=dict(type='ExampleBackbone'),
decode_head=[
dict(type='ExampleDecodeHead'),
dict(type='ExampleCascadeDecodeHead')
],
auxiliary_head=dict(type='ExampleDecodeHead'))
cfg.test_cfg = ConfigDict(mode='whole')
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
# test 1 decode head, 2 aux head
cfg = ConfigDict(
type='CascadeEncoderDecoder',
num_stages=2,
backbone=dict(type='ExampleBackbone'),
decode_head=[
dict(type='ExampleDecodeHead'),
dict(type='ExampleCascadeDecodeHead')
],
auxiliary_head=[
dict(type='ExampleDecodeHead'),
dict(type='ExampleDecodeHead')
])
cfg.test_cfg = ConfigDict(mode='whole')
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv import ConfigDict
from mmseg.models import build_segmentor
from .utils import _segmentor_forward_train_test
def test_encoder_decoder():
# test 1 decode head, w.o. aux head
cfg = ConfigDict(
type='EncoderDecoder',
backbone=dict(type='ExampleBackbone'),
decode_head=dict(type='ExampleDecodeHead'),
train_cfg=None,
test_cfg=dict(mode='whole'))
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
# test slide mode
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
# test 1 decode head, 1 aux head
cfg = ConfigDict(
type='EncoderDecoder',
backbone=dict(type='ExampleBackbone'),
decode_head=dict(type='ExampleDecodeHead'),
auxiliary_head=dict(type='ExampleDecodeHead'))
cfg.test_cfg = ConfigDict(mode='whole')
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
# test 1 decode head, 2 aux head
cfg = ConfigDict(
type='EncoderDecoder',
backbone=dict(type='ExampleBackbone'),
decode_head=dict(type='ExampleDecodeHead'),
auxiliary_head=[
dict(type='ExampleDecodeHead'),
dict(type='ExampleDecodeHead')
])
cfg.test_cfg = ConfigDict(mode='whole')
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch import nn
from mmseg.models import BACKBONES, HEADS
from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
"""Create a superset of inputs needed to run test or train batches.
Args:
input_shape (tuple):
input batch dimensions
num_classes (int):
number of semantic classes
"""
(N, C, H, W) = input_shape
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
segs = rng.randint(
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
img_metas = [{
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': '<demo>.png',
'scale_factor': 1.0,
'flip': False,
'flip_direction': 'horizontal'
} for _ in range(N)]
mm_inputs = {
'imgs': torch.FloatTensor(imgs),
'img_metas': img_metas,
'gt_semantic_seg': torch.LongTensor(segs)
}
return mm_inputs
@BACKBONES.register_module()
class ExampleBackbone(nn.Module):
def __init__(self):
super(ExampleBackbone, self).__init__()
self.conv = nn.Conv2d(3, 3, 3)
def init_weights(self, pretrained=None):
pass
def forward(self, x):
return [self.conv(x)]
@HEADS.register_module()
class ExampleDecodeHead(BaseDecodeHead):
def __init__(self):
super(ExampleDecodeHead, self).__init__(3, 3, num_classes=19)
def forward(self, inputs):
return self.cls_seg(inputs[0])
@HEADS.register_module()
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead):
def __init__(self):
super(ExampleCascadeDecodeHead, self).__init__(3, 3, num_classes=19)
def forward(self, inputs, prev_out):
return self.cls_seg(inputs[0])
def _segmentor_forward_train_test(segmentor):
if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes
else:
num_classes = segmentor.decode_head.num_classes
# batch_size=2 for BatchNorm
mm_inputs = _demo_mm_inputs(num_classes=num_classes)
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
gt_semantic_seg = mm_inputs['gt_semantic_seg']
# convert to cuda Tensor if applicable
if torch.cuda.is_available():
segmentor = segmentor.cuda()
imgs = imgs.cuda()
gt_semantic_seg = gt_semantic_seg.cuda()
# Test forward train
losses = segmentor.forward(
imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
assert isinstance(losses, dict)
# Test train_step
data_batch = dict(
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
outputs = segmentor.train_step(data_batch, None)
assert isinstance(outputs, dict)
assert 'loss' in outputs
assert 'log_vars' in outputs
assert 'num_samples' in outputs
# Test val_step
with torch.no_grad():
segmentor.eval()
data_batch = dict(
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
outputs = segmentor.val_step(data_batch, None)
assert isinstance(outputs, dict)
assert 'loss' in outputs
assert 'log_vars' in outputs
assert 'num_samples' in outputs
# Test forward simple test
with torch.no_grad():
segmentor.eval()
# pack into lists
img_list = [img[None, :] for img in imgs]
img_meta_list = [[img_meta] for img_meta in img_metas]
segmentor.forward(img_list, img_meta_list, return_loss=False)
# Test forward aug test
with torch.no_grad():
segmentor.eval()
# pack into lists
img_list = [img[None, :] for img in imgs]
img_list = img_list + img_list
img_meta_list = [[img_meta] for img_meta in img_metas]
img_meta_list = img_meta_list + img_meta_list
segmentor.forward(img_list, img_meta_list, return_loss=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment