Commit b2abe157 authored by zhengmiao

Merge branch 'zhengmiao/tests_bp' into 'refactor_dev'

[Refactory] Clean UTs

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!2
parents 24cc27dd 35c695bb
Showing changed files with 0 additions and 3974 deletions
# Copyright (c) OpenMMLab. All rights reserved.
import shutil
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mmseg.apis import single_gpu_test
class ExampleDataset(Dataset):
def __getitem__(self, idx):
results = dict(img=torch.tensor([1]), img_metas=dict())
return results
def __len__(self):
return 1
class ExampleModel(nn.Module):
def __init__(self):
super(ExampleModel, self).__init__()
self.test_cfg = None
self.conv = nn.Conv2d(3, 3, 3)
def forward(self, img, img_metas, return_loss=False, **kwargs):
return img
def test_single_gpu():
test_dataset = ExampleDataset()
data_loader = DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False,
)
model = ExampleModel()
# Test efficient test compatibility (will be deprecated)
results = single_gpu_test(model, data_loader, efficient_test=True)
assert len(results) == 1
pred = np.load(results[0])
assert isinstance(pred, np.ndarray)
assert pred.shape == (1, )
assert pred[0] == 1
shutil.rmtree('.efficient_test')
# Test pre_eval
test_dataset.pre_eval = MagicMock(return_value=['success'])
results = single_gpu_test(model, data_loader, pre_eval=True)
assert results == ['success']
# Test format_only
test_dataset.format_results = MagicMock(return_value=['success'])
results = single_gpu_test(model, data_loader, format_only=True)
assert results == ['success']
# efficient_test, pre_eval and format_only are mutually exclusive
with pytest.raises(AssertionError):
single_gpu_test(
model,
data_loader,
efficient_test=True,
format_only=True,
pre_eval=True)
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp
import pytest
from torch.utils.data import (DistributedSampler, RandomSampler,
SequentialSampler)
from mmseg.datasets import (DATASETS, ConcatDataset, MultiImageMixDataset,
build_dataloader, build_dataset)
@DATASETS.register_module()
class ToyDataset(object):
def __init__(self, cnt=0):
self.cnt = cnt
def __getitem__(self, idx):
return idx
def __len__(self):
return 100
def test_build_dataset():
cfg = dict(type='ToyDataset')
dataset = build_dataset(cfg)
assert isinstance(dataset, ToyDataset)
assert dataset.cnt == 0
dataset = build_dataset(cfg, default_args=dict(cnt=1))
assert isinstance(dataset, ToyDataset)
assert dataset.cnt == 1
data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
img_dir = 'imgs/'
ann_dir = 'gts/'
# We use the same dir twice for simplicity
# with ann_dir
cfg = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
img_dir=[img_dir, img_dir],
ann_dir=[ann_dir, ann_dir])
dataset = build_dataset(cfg)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 10
cfg = dict(type='MultiImageMixDataset', dataset=cfg, pipeline=[])
dataset = build_dataset(cfg)
assert isinstance(dataset, MultiImageMixDataset)
assert len(dataset) == 10
# with ann_dir, split
cfg = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
img_dir=img_dir,
ann_dir=ann_dir,
split=['splits/train.txt', 'splits/val.txt'])
dataset = build_dataset(cfg)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 5
# with ann_dir, split
cfg = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
img_dir=img_dir,
ann_dir=[ann_dir, ann_dir],
split=['splits/train.txt', 'splits/val.txt'])
dataset = build_dataset(cfg)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 5
# test mode
cfg = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
img_dir=[img_dir, img_dir],
test_mode=True,
classes=('pseudo_class', ))
dataset = build_dataset(cfg)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 10
# test mode with splits
cfg = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
img_dir=[img_dir, img_dir],
split=['splits/val.txt', 'splits/val.txt'],
test_mode=True,
classes=('pseudo_class', ))
dataset = build_dataset(cfg)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 2
# len(ann_dir) should be zero or equal to len(img_dir) when len(img_dir) > 1
with pytest.raises(AssertionError):
cfg = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
img_dir=[img_dir, img_dir],
ann_dir=[ann_dir, ann_dir, ann_dir])
build_dataset(cfg)
# len(splits) should be zero or equal to len(img_dir) when len(img_dir) > 1
with pytest.raises(AssertionError):
cfg = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
img_dir=[img_dir, img_dir],
split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt'])
build_dataset(cfg)
# len(splits) should equal len(ann_dir) when len(img_dir) == 1
# and len(ann_dir) > 1
with pytest.raises(AssertionError):
cfg = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
img_dir=img_dir,
ann_dir=[ann_dir, ann_dir],
split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt'])
build_dataset(cfg)
def test_build_dataloader():
dataset = ToyDataset()
samples_per_gpu = 3
# dist=True, shuffle=True, 1GPU
dataloader = build_dataloader(
dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2)
assert dataloader.batch_size == samples_per_gpu
assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
assert isinstance(dataloader.sampler, DistributedSampler)
assert dataloader.sampler.shuffle
# dist=True, shuffle=False, 1GPU
dataloader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=2,
shuffle=False)
assert dataloader.batch_size == samples_per_gpu
assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
assert isinstance(dataloader.sampler, DistributedSampler)
assert not dataloader.sampler.shuffle
# dist=True, shuffle=True, 8GPU
dataloader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=2,
num_gpus=8)
assert dataloader.batch_size == samples_per_gpu
assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
assert dataloader.num_workers == 2
# dist=False, shuffle=True, 1GPU
dataloader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=2,
dist=False)
assert dataloader.batch_size == samples_per_gpu
assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
assert isinstance(dataloader.sampler, RandomSampler)
assert dataloader.num_workers == 2
# dist=False, shuffle=False, 1GPU
dataloader = build_dataloader(
dataset,
samples_per_gpu=3,
workers_per_gpu=2,
shuffle=False,
dist=False)
assert dataloader.batch_size == samples_per_gpu
assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
assert isinstance(dataloader.sampler, SequentialSampler)
assert dataloader.num_workers == 2
# dist=False, shuffle=True, 8GPU
dataloader = build_dataloader(
dataset, samples_per_gpu=3, workers_per_gpu=2, num_gpus=8, dist=False)
assert dataloader.batch_size == samples_per_gpu * 8
assert len(dataloader) == int(
math.ceil(len(dataset) / samples_per_gpu / 8))
assert isinstance(dataloader.sampler, RandomSampler)
assert dataloader.num_workers == 16
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import tempfile
import mmcv
import numpy as np
from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile
class TestLoading(object):
@classmethod
def setup_class(cls):
cls.data_prefix = osp.join(osp.dirname(__file__), '../data')
def test_load_img(self):
results = dict(
img_prefix=self.data_prefix, img_info=dict(filename='color.jpg'))
transform = LoadImageFromFile()
results = transform(copy.deepcopy(results))
assert results['filename'] == osp.join(self.data_prefix, 'color.jpg')
assert results['ori_filename'] == 'color.jpg'
assert results['img'].shape == (288, 512, 3)
assert results['img'].dtype == np.uint8
assert results['img_shape'] == (288, 512, 3)
assert results['ori_shape'] == (288, 512, 3)
assert results['pad_shape'] == (288, 512, 3)
assert results['scale_factor'] == 1.0
np.testing.assert_equal(results['img_norm_cfg']['mean'],
np.zeros(3, dtype=np.float32))
assert repr(transform) == transform.__class__.__name__ + \
"(to_float32=False,color_type='color',imdecode_backend='cv2')"
# no img_prefix
results = dict(
img_prefix=None, img_info=dict(filename='tests/data/color.jpg'))
transform = LoadImageFromFile()
results = transform(copy.deepcopy(results))
assert results['filename'] == 'tests/data/color.jpg'
assert results['ori_filename'] == 'tests/data/color.jpg'
assert results['img'].shape == (288, 512, 3)
# to_float32
transform = LoadImageFromFile(to_float32=True)
results = transform(copy.deepcopy(results))
assert results['img'].dtype == np.float32
# gray image
results = dict(
img_prefix=self.data_prefix, img_info=dict(filename='gray.jpg'))
transform = LoadImageFromFile()
results = transform(copy.deepcopy(results))
assert results['img'].shape == (288, 512, 3)
assert results['img'].dtype == np.uint8
transform = LoadImageFromFile(color_type='unchanged')
results = transform(copy.deepcopy(results))
assert results['img'].shape == (288, 512)
assert results['img'].dtype == np.uint8
np.testing.assert_equal(results['img_norm_cfg']['mean'],
np.zeros(1, dtype=np.float32))
def test_load_seg(self):
results = dict(
seg_prefix=self.data_prefix,
ann_info=dict(seg_map='seg.png'),
seg_fields=[])
transform = LoadAnnotations()
results = transform(copy.deepcopy(results))
assert results['seg_fields'] == ['gt_semantic_seg']
assert results['gt_semantic_seg'].shape == (288, 512)
assert results['gt_semantic_seg'].dtype == np.uint8
assert repr(transform) == transform.__class__.__name__ + \
"(reduce_zero_label=False,imdecode_backend='pillow')"
# no img_prefix
results = dict(
seg_prefix=None,
ann_info=dict(seg_map='tests/data/seg.png'),
seg_fields=[])
transform = LoadAnnotations()
results = transform(copy.deepcopy(results))
assert results['gt_semantic_seg'].shape == (288, 512)
assert results['gt_semantic_seg'].dtype == np.uint8
# reduce_zero_label
transform = LoadAnnotations(reduce_zero_label=True)
results = transform(copy.deepcopy(results))
assert results['gt_semantic_seg'].shape == (288, 512)
assert results['gt_semantic_seg'].dtype == np.uint8
# pillow backend
results = dict(
seg_prefix=self.data_prefix,
ann_info=dict(seg_map='seg.png'),
seg_fields=[])
transform = LoadAnnotations(imdecode_backend='pillow')
results = transform(copy.deepcopy(results))
# this image is saved by PIL
assert results['gt_semantic_seg'].shape == (288, 512)
assert results['gt_semantic_seg'].dtype == np.uint8
def test_load_seg_custom_classes(self):
test_img = np.random.rand(10, 10)
test_gt = np.zeros_like(test_img)
test_gt[2:4, 2:4] = 1
test_gt[2:4, 6:8] = 2
test_gt[6:8, 2:4] = 3
test_gt[6:8, 6:8] = 4
tmp_dir = tempfile.TemporaryDirectory()
img_path = osp.join(tmp_dir.name, 'img.jpg')
gt_path = osp.join(tmp_dir.name, 'gt.png')
mmcv.imwrite(test_img, img_path)
mmcv.imwrite(test_gt, gt_path)
# test only train with label with id 3
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
label_map={
0: 0,
1: 0,
2: 0,
3: 1,
4: 0
},
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
true_mask = np.zeros_like(gt_array)
true_mask[6:8, 2:4] = 1
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, true_mask)
# test only train with label with id 4 and 3
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
label_map={
0: 0,
1: 0,
2: 0,
3: 2,
4: 1
},
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
true_mask = np.zeros_like(gt_array)
true_mask[6:8, 2:4] = 2
true_mask[6:8, 6:8] = 1
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, true_mask)
# test no custom classes
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, test_gt)
tmp_dir.cleanup()
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import pytest
from mmcv.utils import build_from_cfg
from mmseg.datasets.builder import PIPELINES
def test_multi_scale_flip_aug():
# test assertion if img_scale=None, img_ratios=1 (not float).
with pytest.raises(AssertionError):
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=None,
img_ratios=1,
transforms=[dict(type='Resize', keep_ratio=False)],
)
build_from_cfg(tta_transform, PIPELINES)
# test assertion if img_scale=None, img_ratios=None.
with pytest.raises(AssertionError):
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=None,
img_ratios=None,
transforms=[dict(type='Resize', keep_ratio=False)],
)
build_from_cfg(tta_transform, PIPELINES)
# test assertion if img_scale=(512, 512), img_ratios=1 (not float).
with pytest.raises(AssertionError):
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
img_ratios=1,
transforms=[dict(type='Resize', keep_ratio=False)],
)
build_from_cfg(tta_transform, PIPELINES)
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
img_ratios=[0.5, 1.0, 2.0],
flip=False,
transforms=[dict(type='Resize', keep_ratio=False)],
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
results = dict()
# (288, 512, 3)
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(256, 256), (512, 512), (1024, 1024)]
assert tta_results['flip'] == [False, False, False]
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
img_ratios=[0.5, 1.0, 2.0],
flip=True,
transforms=[dict(type='Resize', keep_ratio=False)],
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
(512, 512), (1024, 1024), (1024, 1024)]
assert tta_results['flip'] == [False, True, False, True, False, True]
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
img_ratios=1.0,
flip=False,
transforms=[dict(type='Resize', keep_ratio=False)],
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(512, 512)]
assert tta_results['flip'] == [False]
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
img_ratios=1.0,
flip=True,
transforms=[dict(type='Resize', keep_ratio=False)],
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(512, 512), (512, 512)]
assert tta_results['flip'] == [False, True]
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=None,
img_ratios=[0.5, 1.0, 2.0],
flip=False,
transforms=[dict(type='Resize', keep_ratio=False)],
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(256, 144), (512, 288), (1024, 576)]
assert tta_results['flip'] == [False, False, False]
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=None,
img_ratios=[0.5, 1.0, 2.0],
flip=True,
transforms=[dict(type='Resize', keep_ratio=False)],
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(256, 144), (256, 144), (512, 288),
(512, 288), (1024, 576), (1024, 576)]
assert tta_results['flip'] == [False, True, False, True, False, True]
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=[(256, 256), (512, 512), (1024, 1024)],
img_ratios=None,
flip=False,
transforms=[dict(type='Resize', keep_ratio=False)],
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(256, 256), (512, 512), (1024, 1024)]
assert tta_results['flip'] == [False, False, False]
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=[(256, 256), (512, 512), (1024, 1024)],
img_ratios=None,
flip=True,
transforms=[dict(type='Resize', keep_ratio=False)],
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
(512, 512), (1024, 1024), (1024, 1024)]
assert tta_results['flip'] == [False, True, False, True, False, True]
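# A minimal sketch (assumed, mirroring the asserts above) of how
# MultiScaleFlipAug expands its config into the augmentations tested here:
# every scale (taken from img_scale, or derived from the original image size
# times each ratio when img_scale is None) is paired with flip=False and,
# when flip=True, additionally with flip=True.
#
#     def expand(scales, flip):
#         flips = [False, True] if flip else [False]
#         return [(s, f) for s in scales for f in flips]
#
#     # expand([(256, 256), (512, 512)], flip=True) ->
#     # [((256, 256), False), ((256, 256), True),
#     #  ((512, 512), False), ((512, 512), True)]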
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest.mock import MagicMock, patch
import mmcv.runner
import pytest
import torch
import torch.nn as nn
from mmcv.runner import obj_from_dict
from torch.utils.data import DataLoader, Dataset
from mmseg.apis import single_gpu_test
from mmseg.core import DistEvalHook, EvalHook
class ExampleDataset(Dataset):
def __getitem__(self, idx):
results = dict(img=torch.tensor([1]), img_metas=dict())
return results
def __len__(self):
return 1
class ExampleModel(nn.Module):
def __init__(self):
super(ExampleModel, self).__init__()
self.test_cfg = None
self.conv = nn.Conv2d(3, 3, 3)
def forward(self, img, img_metas, test_mode=False, **kwargs):
return img
def train_step(self, data_batch, optimizer):
loss = self.forward(**data_batch)
return dict(loss=loss)
def test_iter_eval_hook():
with pytest.raises(TypeError):
test_dataset = ExampleModel()
data_loader = [
DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False)
]
EvalHook(data_loader)
test_dataset = ExampleDataset()
test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()
data_loader = DataLoader(
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer = obj_from_dict(optim_cfg, torch.optim,
dict(params=model.parameters()))
# test EvalHook
with tempfile.TemporaryDirectory() as tmpdir:
eval_hook = EvalHook(data_loader, by_epoch=False, efficient_test=True)
runner = mmcv.runner.IterBasedRunner(
model=model,
optimizer=optimizer,
work_dir=tmpdir,
logger=logging.getLogger())
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 1)
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
logger=runner.logger)
def test_epoch_eval_hook():
with pytest.raises(TypeError):
test_dataset = ExampleModel()
data_loader = [
DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False)
]
EvalHook(data_loader, by_epoch=True)
test_dataset = ExampleDataset()
test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()
data_loader = DataLoader(
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer = obj_from_dict(optim_cfg, torch.optim,
dict(params=model.parameters()))
# test EvalHook with interval
with tempfile.TemporaryDirectory() as tmpdir:
eval_hook = EvalHook(data_loader, by_epoch=True, interval=2)
runner = mmcv.runner.EpochBasedRunner(
model=model,
optimizer=optimizer,
work_dir=tmpdir,
logger=logging.getLogger())
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 2)
test_dataset.evaluate.assert_called_once_with([torch.tensor([1])],
logger=runner.logger)
def multi_gpu_test(model,
data_loader,
tmpdir=None,
gpu_collect=False,
pre_eval=False):
# Pre eval is set by default when training.
results = single_gpu_test(model, data_loader, pre_eval=True)
return results
@patch('mmseg.apis.multi_gpu_test', multi_gpu_test)
def test_dist_eval_hook():
with pytest.raises(TypeError):
test_dataset = ExampleModel()
data_loader = [
DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False)
]
DistEvalHook(data_loader)
test_dataset = ExampleDataset()
test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()
data_loader = DataLoader(
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer = obj_from_dict(optim_cfg, torch.optim,
dict(params=model.parameters()))
# test DistEvalHook
with tempfile.TemporaryDirectory() as tmpdir:
eval_hook = DistEvalHook(
data_loader, by_epoch=False, efficient_test=True)
runner = mmcv.runner.IterBasedRunner(
model=model,
optimizer=optimizer,
work_dir=tmpdir,
logger=logging.getLogger())
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 1)
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
logger=runner.logger)
@patch('mmseg.apis.multi_gpu_test', multi_gpu_test)
def test_dist_eval_hook_epoch():
with pytest.raises(TypeError):
test_dataset = ExampleModel()
data_loader = [
DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False)
]
DistEvalHook(data_loader)
test_dataset = ExampleDataset()
test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()
data_loader = DataLoader(
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer = obj_from_dict(optim_cfg, torch.optim,
dict(params=model.parameters()))
# test DistEvalHook
with tempfile.TemporaryDirectory() as tmpdir:
eval_hook = DistEvalHook(data_loader, by_epoch=True, interval=2)
runner = mmcv.runner.EpochBasedRunner(
model=model,
optimizer=optimizer,
work_dir=tmpdir,
logger=logging.getLogger())
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 2)
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
logger=runner.logger)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
from mmseg.apis import inference_segmentor, init_segmentor
def test_test_time_augmentation_on_cpu():
config_file = 'configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py'
config = mmcv.Config.fromfile(config_file)
# Remove pretrain model download for testing
config.model.pretrained = None
# Replace SyncBN with BN to inference on CPU
norm_cfg = dict(type='BN', requires_grad=True)
config.model.backbone.norm_cfg = norm_cfg
config.model.decode_head.norm_cfg = norm_cfg
config.model.auxiliary_head.norm_cfg = norm_cfg
# Enable test time augmentation
config.data.test.pipeline[1].flip = True
checkpoint_file = None
model = init_segmentor(config, checkpoint_file, device='cpu')
img = mmcv.imread(
osp.join(osp.dirname(__file__), 'data/color.jpg'), 'color')
result = inference_segmentor(model, img)
assert result[0].shape == (288, 512)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmseg.core.evaluation import (eval_metrics, mean_dice, mean_fscore,
mean_iou)
from mmseg.core.evaluation.metrics import f_score
def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
"""Intersection over Union
Args:
pred_label (np.ndarray): 2D predict map
label (np.ndarray): label 2D label map
num_classes (int): number of categories
ignore_index (int): index ignore in evaluation
"""
mask = (label != ignore_index)
pred_label = pred_label[mask]
label = label[mask]
n = num_classes
inds = n * label + pred_label
mat = np.bincount(inds, minlength=n**2).reshape(n, n)
return mat
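# A small worked sketch (values assumed, not taken from the tests) of the
# bincount index trick used above: with n classes, `inds = n * label +
# pred_label` maps each (label, pred) pair to a unique cell of the flattened
# n x n confusion matrix, so a single np.bincount accumulates the matrix.
#
#     n = 3
#     label = np.array([0, 1, 2, 2])
#     pred_label = np.array([0, 2, 2, 1])
#     inds = n * label + pred_label                       # [0, 5, 8, 7]
#     mat = np.bincount(inds, minlength=n**2).reshape(n, n)
#     # mat == [[1, 0, 0],
#     #         [0, 0, 1],
#     #         [0, 1, 1]]   (rows: ground truth, cols: prediction)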
# This func is deprecated since it's not memory efficient
def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index):
num_imgs = len(results)
assert len(gt_seg_maps) == num_imgs
total_mat = np.zeros((num_classes, num_classes), dtype=float)
for i in range(num_imgs):
mat = get_confusion_matrix(
results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
total_mat += mat
all_acc = np.diag(total_mat).sum() / total_mat.sum()
acc = np.diag(total_mat) / total_mat.sum(axis=1)
iou = np.diag(total_mat) / (
total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat))
return all_acc, acc, iou
# This func is deprecated since it's not memory efficient
def legacy_mean_dice(results, gt_seg_maps, num_classes, ignore_index):
num_imgs = len(results)
assert len(gt_seg_maps) == num_imgs
total_mat = np.zeros((num_classes, num_classes), dtype=float)
for i in range(num_imgs):
mat = get_confusion_matrix(
results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
total_mat += mat
all_acc = np.diag(total_mat).sum() / total_mat.sum()
acc = np.diag(total_mat) / total_mat.sum(axis=1)
dice = 2 * np.diag(total_mat) / (
total_mat.sum(axis=1) + total_mat.sum(axis=0))
return all_acc, acc, dice
# This func is deprecated since it's not memory efficient
def legacy_mean_fscore(results,
gt_seg_maps,
num_classes,
ignore_index,
beta=1):
num_imgs = len(results)
assert len(gt_seg_maps) == num_imgs
total_mat = np.zeros((num_classes, num_classes), dtype=float)
for i in range(num_imgs):
mat = get_confusion_matrix(
results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
total_mat += mat
all_acc = np.diag(total_mat).sum() / total_mat.sum()
recall = np.diag(total_mat) / total_mat.sum(axis=1)
precision = np.diag(total_mat) / total_mat.sum(axis=0)
fv = np.vectorize(f_score)
fscore = fv(precision, recall, beta=beta)
return all_acc, recall, precision, fscore
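# For reference, with the confusion matrix C above (rows: ground truth,
# cols: prediction), the per-class quantities computed by the legacy helpers
# reduce to the following (a sketch; f_score is assumed to be the standard
# F-beta score):
#     aAcc        = trace(C) / sum(C)
#     acc_i       = C[i, i] / row_sum_i        (identical to recall_i)
#     precision_i = C[i, i] / col_sum_i
#     iou_i       = C[i, i] / (row_sum_i + col_sum_i - C[i, i])
#     dice_i      = 2 * C[i, i] / (row_sum_i + col_sum_i)
#     fscore_i    = (1 + beta**2) * precision_i * recall_i
#                   / (beta**2 * precision_i + recall_i)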
def test_metrics():
pred_size = (10, 30, 30)
num_classes = 19
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
label = np.random.randint(0, num_classes, size=pred_size)
# Test the availability of arg: ignore_index.
label[:, 2, 5:10] = ignore_index
# Test the correctness of the implementation of mIoU calculation.
ret_metrics = eval_metrics(
results, label, num_classes, ignore_index, metrics='mIoU')
all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
'IoU']
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)
# Test the correctness of the implementation of mDice calculation.
ret_metrics = eval_metrics(
results, label, num_classes, ignore_index, metrics='mDice')
all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
'Dice']
all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(dice, dice_l)
# Test the correctness of the implementation of mFscore calculation.
ret_metrics = eval_metrics(
results, label, num_classes, ignore_index, metrics='mFscore')
all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
results, label, num_classes, ignore_index)
assert all_acc == all_acc_l
assert np.allclose(recall, recall_l)
assert np.allclose(precision, precision_l)
assert np.allclose(fscore, fscore_l)
# Test the correctness of the implementation of joint calculation.
ret_metrics = eval_metrics(
results,
label,
num_classes,
ignore_index,
metrics=['mIoU', 'mDice', 'mFscore'])
all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
'Dice'], ret_metrics['Precision'], ret_metrics[
'Recall'], ret_metrics['Fscore']
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)
assert np.allclose(dice, dice_l)
assert np.allclose(precision, precision_l)
assert np.allclose(recall, recall_l)
assert np.allclose(fscore, fscore_l)
# Test the correctness of calculation when arg: num_classes is larger
# than the maximum value of input maps.
results = np.random.randint(0, 5, size=pred_size)
label = np.random.randint(0, 4, size=pred_size)
ret_metrics = eval_metrics(
results,
label,
num_classes,
ignore_index=255,
metrics='mIoU',
nan_to_num=-1)
all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
'IoU']
assert acc[-1] == -1
assert iou[-1] == -1
ret_metrics = eval_metrics(
results,
label,
num_classes,
ignore_index=255,
metrics='mDice',
nan_to_num=-1)
all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
'Dice']
assert acc[-1] == -1
assert dice[-1] == -1
ret_metrics = eval_metrics(
results,
label,
num_classes,
ignore_index=255,
metrics='mFscore',
nan_to_num=-1)
all_acc, precision, recall, fscore = ret_metrics['aAcc'], ret_metrics[
'Precision'], ret_metrics['Recall'], ret_metrics['Fscore']
assert precision[-1] == -1
assert recall[-1] == -1
assert fscore[-1] == -1
ret_metrics = eval_metrics(
results,
label,
num_classes,
ignore_index=255,
metrics=['mDice', 'mIoU', 'mFscore'],
nan_to_num=-1)
all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
'Dice'], ret_metrics['Precision'], ret_metrics[
'Recall'], ret_metrics['Fscore']
assert acc[-1] == -1
assert dice[-1] == -1
assert iou[-1] == -1
assert precision[-1] == -1
assert recall[-1] == -1
assert fscore[-1] == -1
# Test the bug caused by torch.histc.
# torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
# When the arg `bins` is set to the same value as the arg `max`,
# some channels of mIoU may be nan.
results = np.array([np.repeat(31, 59)])
label = np.array([np.arange(59)])
num_classes = 59
ret_metrics = eval_metrics(
results, label, num_classes, ignore_index=255, metrics='mIoU')
all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
'IoU']
assert not np.any(np.isnan(iou))
def test_mean_iou():
pred_size = (10, 30, 30)
num_classes = 19
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
label = np.random.randint(0, num_classes, size=pred_size)
label[:, 2, 5:10] = ignore_index
ret_metrics = mean_iou(results, label, num_classes, ignore_index)
all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
'IoU']
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)
results = np.random.randint(0, 5, size=pred_size)
label = np.random.randint(0, 4, size=pred_size)
ret_metrics = mean_iou(
results, label, num_classes, ignore_index=255, nan_to_num=-1)
all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
'IoU']
assert acc[-1] == -1
assert iou[-1] == -1
def test_mean_dice():
pred_size = (10, 30, 30)
num_classes = 19
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
label = np.random.randint(0, num_classes, size=pred_size)
label[:, 2, 5:10] = ignore_index
ret_metrics = mean_dice(results, label, num_classes, ignore_index)
all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
'Dice']
all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(dice, dice_l)
results = np.random.randint(0, 5, size=pred_size)
label = np.random.randint(0, 4, size=pred_size)
ret_metrics = mean_dice(
results, label, num_classes, ignore_index=255, nan_to_num=-1)
all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
'Dice']
assert acc[-1] == -1
assert dice[-1] == -1
def test_mean_fscore():
pred_size = (10, 30, 30)
num_classes = 19
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
label = np.random.randint(0, num_classes, size=pred_size)
label[:, 2, 5:10] = ignore_index
ret_metrics = mean_fscore(results, label, num_classes, ignore_index)
all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
results, label, num_classes, ignore_index)
assert all_acc == all_acc_l
assert np.allclose(recall, recall_l)
assert np.allclose(precision, precision_l)
assert np.allclose(fscore, fscore_l)
ret_metrics = mean_fscore(
results, label, num_classes, ignore_index, beta=2)
all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
results, label, num_classes, ignore_index, beta=2)
assert all_acc == all_acc_l
assert np.allclose(recall, recall_l)
assert np.allclose(precision, precision_l)
assert np.allclose(fscore, fscore_l)
results = np.random.randint(0, 5, size=pred_size)
label = np.random.randint(0, 4, size=pred_size)
ret_metrics = mean_fscore(
results, label, num_classes, ignore_index=255, nan_to_num=-1)
all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
assert recall[-1] == -1
assert precision[-1] == -1
assert fscore[-1] == -1
def test_filename_inputs():
import tempfile
import cv2
def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
filenames = []
SUFFIX = '.png' if is_image else '.npy'
for idx, arr in enumerate(input_arrays):
filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX)
if is_image:
cv2.imwrite(filename, arr)
else:
np.save(filename, arr)
filenames.append(filename)
return filenames
pred_size = (10, 30, 30)
num_classes = 19
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
labels = np.random.randint(0, num_classes, size=pred_size)
labels[:, 2, 5:10] = ignore_index
with tempfile.TemporaryDirectory() as temp_dir:
result_files = save_arr(results, 'pred', False, temp_dir)
label_files = save_arr(labels, 'label', True, temp_dir)
ret_metrics = eval_metrics(
result_files,
label_files,
num_classes,
ignore_index,
metrics='mIoU')
all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics[
'Acc'], ret_metrics['IoU']
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads import DAHead
from .utils import to_cuda
def test_da_head():
inputs = [torch.randn(1, 16, 23, 23)]
head = DAHead(in_channels=16, channels=8, num_classes=19, pam_channels=8)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert isinstance(outputs, tuple) and len(outputs) == 3
for output in outputs:
assert output.shape == (1, head.num_classes, 23, 23)
test_output = head.forward_test(inputs, None, None)
assert test_output.shape == (1, head.num_classes, 23, 23)
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch
import pytest
import torch
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from .utils import to_cuda
@patch.multiple(BaseDecodeHead, __abstractmethods__=set())
def test_decode_head():
with pytest.raises(AssertionError):
# default input_transform doesn't accept multiple inputs
BaseDecodeHead([32, 16], 16, num_classes=19)
with pytest.raises(AssertionError):
# default input_transform doesn't accept multiple inputs
BaseDecodeHead(32, 16, num_classes=19, in_index=[-1, -2])
with pytest.raises(AssertionError):
# supported mode is resize_concat only
BaseDecodeHead(32, 16, num_classes=19, input_transform='concat')
with pytest.raises(AssertionError):
# in_channels should be list|tuple
BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat')
with pytest.raises(AssertionError):
# in_index should be list|tuple
BaseDecodeHead([32],
16,
in_index=-1,
num_classes=19,
input_transform='resize_concat')
with pytest.raises(AssertionError):
# len(in_index) should equal len(in_channels)
BaseDecodeHead([32, 16],
16,
num_classes=19,
in_index=[-1],
input_transform='resize_concat')
# test default dropout
head = BaseDecodeHead(32, 16, num_classes=19)
assert hasattr(head, 'dropout') and head.dropout.p == 0.1
# test set dropout
head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2)
assert hasattr(head, 'dropout') and head.dropout.p == 0.2
# test no input_transform
inputs = [torch.randn(1, 32, 45, 45)]
head = BaseDecodeHead(32, 16, num_classes=19)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
assert head.in_channels == 32
assert head.input_transform is None
transformed_inputs = head._transform_inputs(inputs)
assert transformed_inputs.shape == (1, 32, 45, 45)
# test input_transform = resize_concat
inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)]
head = BaseDecodeHead([32, 16],
16,
num_classes=19,
in_index=[0, 1],
input_transform='resize_concat')
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
assert head.in_channels == 48
assert head.input_transform == 'resize_concat'
transformed_inputs = head._transform_inputs(inputs)
assert transformed_inputs.shape == (1, 48, 45, 45)
# test multi-loss, loss_decode is dict
with pytest.raises(TypeError):
# loss_decode must be a dict or sequence of dict.
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss
# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss
# 'loss_decode' must be a dict or sequence of dict
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)
# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2'),
dict(type='CrossEntropyLoss', loss_name='loss_3')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss
assert 'loss_3' in loss
# test multi-loss, loss_decode is a list of dicts with identical loss names
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss_3 = head.losses(seg_logit=inputs, seg_label=target)
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss
assert 'loss_ce' in loss_3
assert loss_3['loss_ce'] == 3 * loss['loss_ce']
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads import EncHead
from .utils import to_cuda
def test_enc_head():
# with se_loss, w.o. lateral
inputs = [torch.randn(1, 8, 21, 21)]
head = EncHead(in_channels=[8], channels=4, num_classes=19, in_index=[-1])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert isinstance(outputs, tuple) and len(outputs) == 2
assert outputs[0].shape == (1, head.num_classes, 21, 21)
assert outputs[1].shape == (1, head.num_classes)
# w.o. se_loss, w.o. lateral
inputs = [torch.randn(1, 8, 21, 21)]
head = EncHead(
in_channels=[8],
channels=4,
use_se_loss=False,
num_classes=19,
in_index=[-1])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 21, 21)
# with se_loss, with lateral
inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 8, 21, 21)]
head = EncHead(
in_channels=[4, 8],
channels=4,
add_lateral=True,
num_classes=19,
in_index=[-2, -1])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert isinstance(outputs, tuple) and len(outputs) == 2
assert outputs[0].shape == (1, head.num_classes, 21, 21)
assert outputs[1].shape == (1, head.num_classes)
test_output = head.forward_test(inputs, None, None)
assert test_output.shape == (1, head.num_classes, 21, 21)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads.knet_head import (IterativeDecodeHead,
KernelUpdateHead)
from .utils import to_cuda
num_stages = 3
conv_kernel_size = 1
kernel_updator_cfg = dict(
type='KernelUpdator',
in_channels=16,
feat_channels=16,
out_channels=16,
gate_norm_act=True,
activate_out=True,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'))
def test_knet_head():
# test init function of kernel update head
kernel_update_head = KernelUpdateHead(
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=1,
feedforward_channels=128,
in_channels=32,
out_channels=32,
dropout=0.0,
conv_kernel_size=conv_kernel_size,
ffn_act_cfg=dict(type='ReLU', inplace=True),
with_ffn=True,
feat_transform_cfg=dict(conv_cfg=dict(type='Conv2d'), act_cfg=None),
kernel_init=True,
kernel_updator_cfg=kernel_updator_cfg)
kernel_update_head.init_weights()
head = IterativeDecodeHead(
num_stages=num_stages,
kernel_update_head=[
dict(
type='KernelUpdateHead',
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=1,
feedforward_channels=128,
in_channels=32,
out_channels=32,
dropout=0.0,
conv_kernel_size=conv_kernel_size,
ffn_act_cfg=dict(type='ReLU', inplace=True),
with_ffn=True,
feat_transform_cfg=dict(
conv_cfg=dict(type='Conv2d'), act_cfg=None),
kernel_init=False,
kernel_updator_cfg=kernel_updator_cfg)
for _ in range(num_stages)
],
kernel_generate_head=dict(
type='FCNHead',
in_channels=128,
in_index=3,
channels=32,
num_convs=2,
concat_input=True,
dropout_ratio=0.1,
num_classes=150,
align_corners=False))
head.init_weights()
inputs = [
torch.randn(1, 16, 27, 32),
torch.randn(1, 32, 27, 16),
torch.randn(1, 64, 27, 16),
torch.randn(1, 128, 27, 16)
]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert outputs[-1].shape == (1, head.num_classes, 27, 16)
# test that only the prediction of the last stage
# is returned during testing
with torch.no_grad():
head.eval()
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 27, 16)
# test K-Net without `feat_transform_cfg`
head = IterativeDecodeHead(
num_stages=num_stages,
kernel_update_head=[
dict(
type='KernelUpdateHead',
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=1,
feedforward_channels=128,
in_channels=32,
out_channels=32,
dropout=0.0,
conv_kernel_size=conv_kernel_size,
ffn_act_cfg=dict(type='ReLU', inplace=True),
with_ffn=True,
feat_transform_cfg=None,
kernel_updator_cfg=kernel_updator_cfg)
for _ in range(num_stages)
],
kernel_generate_head=dict(
type='FCNHead',
in_channels=128,
in_index=3,
channels=32,
num_convs=2,
concat_input=True,
dropout_ratio=0.1,
num_classes=150,
align_corners=False))
head.init_weights()
inputs = [
torch.randn(1, 16, 27, 32),
torch.randn(1, 32, 27, 16),
torch.randn(1, 64, 27, 16),
torch.randn(1, 128, 27, 16)
]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert outputs[-1].shape == (1, head.num_classes, 27, 16)
# test K-Net with
# self.mask_transform_stride == 2 and self.feat_gather_stride == 1
head = IterativeDecodeHead(
num_stages=num_stages,
kernel_update_head=[
dict(
type='KernelUpdateHead',
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=1,
feedforward_channels=128,
in_channels=32,
out_channels=32,
dropout=0.0,
conv_kernel_size=conv_kernel_size,
ffn_act_cfg=dict(type='ReLU', inplace=True),
with_ffn=True,
feat_transform_cfg=dict(
conv_cfg=dict(type='Conv2d'), act_cfg=None),
kernel_init=False,
mask_transform_stride=2,
feat_gather_stride=1,
kernel_updator_cfg=kernel_updator_cfg)
for _ in range(num_stages)
],
kernel_generate_head=dict(
type='FCNHead',
in_channels=128,
in_index=3,
channels=32,
num_convs=2,
concat_input=True,
dropout_ratio=0.1,
num_classes=150,
align_corners=False))
head.init_weights()
inputs = [
torch.randn(1, 16, 27, 32),
torch.randn(1, 32, 27, 16),
torch.randn(1, 64, 27, 16),
torch.randn(1, 128, 27, 16)
]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert outputs[-1].shape == (1, head.num_classes, 26, 16)
# test loss function in K-Net
fake_label = torch.ones_like(
outputs[-1][:, 0:1, :, :], dtype=torch.int16).long()
loss = head.losses(seg_logit=outputs, seg_label=fake_label)
assert loss['loss_ce.s0'] != torch.zeros_like(loss['loss_ce.s0'])
assert loss['loss_ce.s1'] != torch.zeros_like(loss['loss_ce.s1'])
assert loss['loss_ce.s2'] != torch.zeros_like(loss['loss_ce.s2'])
assert loss['loss_ce.s3'] != torch.zeros_like(loss['loss_ce.s3'])
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.utils import ConfigDict
from mmseg.models.decode_heads import FCNHead, PointHead
from .utils import to_cuda
def test_point_head():
inputs = [torch.randn(1, 32, 45, 45)]
point_head = PointHead(
in_channels=[32], in_index=[0], channels=16, num_classes=19)
assert len(point_head.fcs) == 3
fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19)
if torch.cuda.is_available():
head, inputs = to_cuda(point_head, inputs)
head, inputs = to_cuda(fcn_head, inputs)
prev_output = fcn_head(inputs)
test_cfg = ConfigDict(
subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
output = point_head.forward_test(inputs, prev_output, None, test_cfg)
assert output.shape == (1, point_head.num_classes, 180, 180)
# test multiple losses case
inputs = [torch.randn(1, 32, 45, 45)]
point_head_multiple_losses = PointHead(
in_channels=[32],
in_index=[0],
channels=16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
assert len(point_head_multiple_losses.fcs) == 3
fcn_head_multiple_losses = FCNHead(
in_channels=32,
channels=16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(point_head_multiple_losses, inputs)
head, inputs = to_cuda(fcn_head_multiple_losses, inputs)
prev_output = fcn_head_multiple_losses(inputs)
test_cfg = ConfigDict(
subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
output = point_head_multiple_losses.forward_test(inputs, prev_output, None,
test_cfg)
assert output.shape == (1, point_head.num_classes, 180, 180)
fake_label = torch.ones([1, 180, 180], dtype=torch.long)
if torch.cuda.is_available():
fake_label = fake_label.cuda()
loss = point_head_multiple_losses.losses(output, fake_label)
assert 'pointloss_1' in loss
assert 'pointloss_2' in loss
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads import STDCHead
from .utils import to_cuda
def test_stdc_head():
inputs = [torch.randn(1, 32, 21, 21)]
head = STDCHead(
in_channels=32,
channels=8,
num_convs=1,
num_classes=2,
in_index=-1,
loss_decode=[
dict(
type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert isinstance(outputs, torch.Tensor) and len(outputs) == 1
assert outputs.shape == torch.Size([1, head.num_classes, 21, 21])
fake_label = torch.ones_like(
outputs[:, 0:1, :, :], dtype=torch.int16).long()
loss = head.losses(seg_logit=outputs, seg_label=fake_label)
assert loss['loss_ce'] != torch.zeros_like(loss['loss_ce'])
assert loss['loss_dice'] != torch.zeros_like(loss['loss_dice'])
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.losses.cross_entropy_loss import _expand_onehot_labels
@pytest.mark.parametrize('use_sigmoid', [True, False])
@pytest.mark.parametrize('reduction', ('mean', 'sum', 'none'))
@pytest.mark.parametrize('avg_non_ignore', [True, False])
@pytest.mark.parametrize('bce_input_same_dim', [True, False])
def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
from mmseg.models import build_loss
# use_mask and use_sigmoid cannot be true at the same time
with pytest.raises(AssertionError):
loss_cfg = dict(
type='CrossEntropyLoss',
use_mask=True,
use_sigmoid=True,
loss_weight=1.0)
build_loss(loss_cfg)
# test loss with simple case for ce/bce
fake_pred = torch.Tensor([[100, -100]])
fake_label = torch.Tensor([1]).long()
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=use_sigmoid,
loss_weight=1.0,
avg_non_ignore=avg_non_ignore,
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg)
if use_sigmoid:
assert torch.allclose(
loss_cls(fake_pred, fake_label), torch.tensor(100.))
else:
assert torch.allclose(
loss_cls(fake_pred, fake_label), torch.tensor(200.))
# test loss with complicated case for ce/bce
# when avg_non_ignore is False, `avg_factor` would not be calculated
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
fake_label = torch.ones(2, 8, 8).long()
fake_label[:, 0, 0] = 255
fake_weight = None
# extra test bce loss when pred.shape == label.shape
if use_sigmoid and bce_input_same_dim:
fake_pred = torch.randn(2, 10).float()
fake_label = torch.rand(2, 10).float()
fake_weight = torch.rand(2, 10) # set weight in forward function
fake_label[0, [1, 2, 5, 7]] = 255 # set ignore_index
fake_label[1, [0, 5, 8, 9]] = 255
loss_cls = build_loss(loss_cls_cfg)
loss = loss_cls(
fake_pred, fake_label, weight=fake_weight, ignore_index=255)
if use_sigmoid:
if fake_pred.dim() != fake_label.dim():
fake_label, weight, valid_mask = _expand_onehot_labels(
labels=fake_label,
label_weights=None,
target_shape=fake_pred.shape,
ignore_index=255)
else:
# should mask out the ignored elements
valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
weight = valid_mask
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred,
fake_label.float(),
reduction='none',
weight=fake_weight)
if avg_non_ignore:
avg_factor = valid_mask.sum().item()
torch_loss = (torch_loss * weight).sum() / avg_factor
else:
torch_loss = (torch_loss * weight).mean()
else:
if avg_non_ignore:
torch_loss = torch.nn.functional.cross_entropy(
fake_pred, fake_label, reduction='mean', ignore_index=255)
else:
torch_loss = torch.nn.functional.cross_entropy(
fake_pred, fake_label, reduction='sum',
ignore_index=255) / fake_label.numel()
assert torch.allclose(loss, torch_loss)
if use_sigmoid:
# test loss with complicated case for ce/bce
# when avg_non_ignore is False, `avg_factor` would not be calculated
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
fake_label = torch.ones(2, 8, 8).long()
fake_label[:, 0, 0] = 255
fake_weight = torch.rand(2, 8, 8)
loss_cls = build_loss(loss_cls_cfg)
loss = loss_cls(
fake_pred, fake_label, weight=fake_weight, ignore_index=255)
if use_sigmoid:
fake_label, weight, valid_mask = _expand_onehot_labels(
labels=fake_label,
label_weights=None,
target_shape=fake_pred.shape,
ignore_index=255)
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred,
fake_label.float(),
reduction='none',
weight=fake_weight.unsqueeze(1).expand(fake_pred.shape))
if avg_non_ignore:
avg_factor = valid_mask.sum().item()
torch_loss = (torch_loss * weight).sum() / avg_factor
else:
torch_loss = (torch_loss * weight).mean()
assert torch.allclose(loss, torch_loss)
# test loss with class weights from file
fake_pred = torch.Tensor([[100, -100]])
fake_label = torch.Tensor([1]).long()
import os
import tempfile
import mmcv
import numpy as np
tmp_file = tempfile.NamedTemporaryFile()
mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0,
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2])) # from npy file
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=f'{tmp_file.name}.npy',
loss_weight=1.0,
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')
loss_cls_cfg = dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))
# test `avg_non_ignore` without ignore index would not affect ce/bce loss
# when reduction='sum'/'none'/'mean'
loss_cls_cfg1 = dict(
type='CrossEntropyLoss',
use_sigmoid=use_sigmoid,
reduction=reduction,
loss_weight=1.0,
avg_non_ignore=True)
loss_cls1 = build_loss(loss_cls_cfg1)
loss_cls_cfg2 = dict(
type='CrossEntropyLoss',
use_sigmoid=use_sigmoid,
reduction=reduction,
loss_weight=1.0,
avg_non_ignore=False)
loss_cls2 = build_loss(loss_cls_cfg2)
assert torch.allclose(
loss_cls1(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
loss_cls2(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
atol=1e-4)
# test ce/bce loss with ignore index and class weight
# in 5-way classification
if use_sigmoid:
# test bce loss when pred.shape == or != label.shape
if bce_input_same_dim:
fake_pred = torch.randn(2, 10).float()
fake_label = torch.rand(2, 10).float()
class_weight = torch.rand(2, 10)
else:
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
fake_label = torch.ones(2, 8, 8).long()
class_weight = torch.randn(2, 21, 8, 8)
fake_label, weight, valid_mask = _expand_onehot_labels(
labels=fake_label,
label_weights=None,
target_shape=fake_pred.shape,
ignore_index=-100)
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred,
fake_label.float(),
reduction='mean',
pos_weight=class_weight)
else:
fake_pred = torch.randn(2, 5, 10).float() # 5-way classification
fake_label = torch.randint(0, 5, (2, 10)).long()
class_weight = torch.rand(5)
class_weight /= class_weight.sum()
torch_loss = torch.nn.functional.cross_entropy(
fake_pred, fake_label, reduction='sum',
weight=class_weight) / fake_label.numel()
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=use_sigmoid,
reduction='mean',
class_weight=class_weight,
loss_weight=1.0,
avg_non_ignore=avg_non_ignore)
loss_cls = build_loss(loss_cls_cfg)
# test cross entropy loss has name `loss_ce`
assert loss_cls.loss_name == 'loss_ce'
# test avg_non_ignore is in extra_repr
assert loss_cls.extra_repr() == f'avg_non_ignore={avg_non_ignore}'
loss = loss_cls(fake_pred, fake_label)
assert torch.allclose(loss, torch_loss)
fake_label[0, [1, 2, 5, 7]] = 10 # set ignore_index
fake_label[1, [0, 5, 8, 9]] = 10
loss = loss_cls(fake_pred, fake_label, ignore_index=10)
if use_sigmoid:
if avg_non_ignore:
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred[fake_label != 10],
fake_label[fake_label != 10].float(),
pos_weight=class_weight[fake_label != 10],
reduction='mean')
else:
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred[fake_label != 10],
fake_label[fake_label != 10].float(),
pos_weight=class_weight[fake_label != 10],
reduction='sum') / fake_label.numel()
else:
if avg_non_ignore:
torch_loss = torch.nn.functional.cross_entropy(
fake_pred,
fake_label,
ignore_index=10,
reduction='sum',
weight=class_weight) / fake_label[fake_label != 10].numel()
else:
torch_loss = torch.nn.functional.cross_entropy(
fake_pred,
fake_label,
ignore_index=10,
reduction='sum',
weight=class_weight) / fake_label.numel()
assert torch.allclose(loss, torch_loss)
@pytest.mark.parametrize('avg_non_ignore', [True, False])
@pytest.mark.parametrize('with_weight', [True, False])
def test_binary_class_ce_loss(avg_non_ignore, with_weight):
from mmseg.models import build_loss
fake_pred = torch.rand(3, 1, 10, 10)
fake_label = torch.randint(0, 2, (3, 10, 10))
fake_weight = torch.rand(3, 10, 10)
valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
weight = valid_mask
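# All labels are 0 or 1 here, so every pixel is valid; the mask simply mirrors
# the ignore-index weighting applied inside the loss.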
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred,
fake_label.unsqueeze(1).float(),
reduction='none',
weight=fake_weight.unsqueeze(1).float() if with_weight else None)
if avg_non_ignore:
eps = torch.finfo(torch.float32).eps
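# A tiny eps keeps the reference finite even if no pixel were valid, matching
# how the averaging inside the loss is expected to guard the division.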
avg_factor = valid_mask.sum().item()
torch_loss = (torch_loss * weight.unsqueeze(1)).sum() / (
avg_factor + eps)
else:
torch_loss = (torch_loss * weight.unsqueeze(1)).mean()
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0,
avg_non_ignore=avg_non_ignore,
reduction='mean',
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg)
loss = loss_cls(
fake_pred,
fake_label,
weight=fake_weight if with_weight else None,
ignore_index=255)
assert torch.allclose(loss, torch_loss)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def test_dice_loss():
from mmseg.models import build_loss
# test dice loss with loss_type = 'multi_class'
loss_cfg = dict(
type='DiceLoss',
reduction='none',
class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0,
ignore_index=1,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 3, 4, 4)
labels = (torch.rand(8, 4, 4) * 3).long()
dice_loss(logits, labels)
# test loss with class weights from file
import os
import tempfile
import mmcv
import numpy as np
tmp_file = tempfile.NamedTemporaryFile()
mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
loss_cfg = dict(
type='DiceLoss',
reduction='none',
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0,
ignore_index=1,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None)
np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
loss_cfg = dict(
type='DiceLoss',
reduction='none',
class_weight=f'{tmp_file.name}.npy',
loss_weight=1.0,
ignore_index=1,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None)
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')
# test dice loss with loss_type = 'binary'
loss_cfg = dict(
type='DiceLoss',
smooth=2,
exponent=3,
reduction='sum',
loss_weight=1.0,
ignore_index=0,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 2, 4, 4)
labels = (torch.rand(8, 4, 4) * 2).long()
dice_loss(logits, labels)
# test dice loss has name `loss_dice`
loss_cfg = dict(
type='DiceLoss',
smooth=2,
exponent=3,
reduction='sum',
loss_weight=1.0,
ignore_index=0,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg)
assert dice_loss.loss_name == 'loss_dice'
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F
from mmseg.models import build_loss
# test that focal loss only supports use_sigmoid=True
def test_use_sigmoid():
# can't init with use_sigmoid=False
with pytest.raises(AssertionError):
loss_cfg = dict(type='FocalLoss', use_sigmoid=False)
build_loss(loss_cfg)
# can't forward with use_sigmoid=False
with pytest.raises(NotImplementedError):
loss_cfg = dict(type='FocalLoss', use_sigmoid=True)
focal_loss = build_loss(loss_cfg)
focal_loss.use_sigmoid = False
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
focal_loss(fake_pred, fake_target)
# reduction type must be 'none', 'mean' or 'sum'
def test_wrong_reduction_type():
# can't init with wrong reduction
with pytest.raises(AssertionError):
loss_cfg = dict(type='FocalLoss', reduction='test')
build_loss(loss_cfg)
# can't forward with wrong reduction override
with pytest.raises(AssertionError):
loss_cfg = dict(type='FocalLoss')
focal_loss = build_loss(loss_cfg)
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
focal_loss(fake_pred, fake_target, reduction_override='test')
# test that focal loss rejects input parameters of
# unacceptable types
def test_unacceptable_parameters():
with pytest.raises(AssertionError):
loss_cfg = dict(type='FocalLoss', gamma='test')
build_loss(loss_cfg)
with pytest.raises(AssertionError):
loss_cfg = dict(type='FocalLoss', alpha='test')
build_loss(loss_cfg)
with pytest.raises(AssertionError):
loss_cfg = dict(type='FocalLoss', class_weight='test')
build_loss(loss_cfg)
with pytest.raises(AssertionError):
loss_cfg = dict(type='FocalLoss', loss_weight='test')
build_loss(loss_cfg)
with pytest.raises(AssertionError):
loss_cfg = dict(type='FocalLoss', loss_name=123)
build_loss(loss_cfg)
# test that focal loss is correctly initialized
def test_init_focal_loss():
loss_cfg = dict(
type='FocalLoss',
use_sigmoid=True,
gamma=3.0,
alpha=3.0,
class_weight=[1, 2, 3, 4],
reduction='sum')
focal_loss = build_loss(loss_cfg)
assert focal_loss.use_sigmoid is True
assert focal_loss.gamma == 3.0
assert focal_loss.alpha == 3.0
assert focal_loss.reduction == 'sum'
assert focal_loss.class_weight == [1, 2, 3, 4]
assert focal_loss.loss_weight == 1.0
assert focal_loss.loss_name == 'loss_focal'
# test reduction override
def test_reduction_override():
loss_cfg = dict(type='FocalLoss', reduction='mean')
focal_loss = build_loss(loss_cfg)
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
loss = focal_loss(fake_pred, fake_target, reduction_override='none')
assert loss.shape == fake_pred.shape
# test wrong pred and target shape
def test_wrong_pred_and_target_shape():
loss_cfg = dict(type='FocalLoss')
focal_loss = build_loss(loss_cfg)
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 2, 2))
fake_target = F.one_hot(fake_target, num_classes=4)
fake_target = fake_target.permute(0, 3, 1, 2)
with pytest.raises(AssertionError):
focal_loss(fake_pred, fake_target)
# test forward with different shape of target
def test_forward_with_different_shape_of_target():
loss_cfg = dict(type='FocalLoss')
focal_loss = build_loss(loss_cfg)
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
loss1 = focal_loss(fake_pred, fake_target)
fake_target = F.one_hot(fake_target, num_classes=4)
fake_target = fake_target.permute(0, 3, 1, 2)
loss2 = focal_loss(fake_pred, fake_target)
assert loss1 == loss2
# test forward with weight
def test_forward_with_weight():
loss_cfg = dict(type='FocalLoss')
focal_loss = build_loss(loss_cfg)
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
weight = torch.rand(3 * 5 * 6, 1)
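# The same per-pixel weight is passed as shape (N*H*W, 1), as a flat vector,
# and expanded to (N*H*W, C); all three layouts should yield an identical loss.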
loss1 = focal_loss(fake_pred, fake_target, weight=weight)
weight2 = weight.view(-1)
loss2 = focal_loss(fake_pred, fake_target, weight=weight2)
weight3 = weight.expand(3 * 5 * 6, 4)
loss3 = focal_loss(fake_pred, fake_target, weight=weight3)
assert loss1 == loss2 == loss3
# test none reduction type
def test_none_reduction_type():
loss_cfg = dict(type='FocalLoss', reduction='none')
focal_loss = build_loss(loss_cfg)
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
loss = focal_loss(fake_pred, fake_target)
assert loss.shape == fake_pred.shape
# test the usage of class weight
def test_class_weight():
loss_cfg_cw = dict(
type='FocalLoss', reduction='none', class_weight=[1.0, 2.0, 3.0, 4.0])
loss_cfg = dict(type='FocalLoss', reduction='none')
focal_loss_cw = build_loss(loss_cfg_cw)
focal_loss = build_loss(loss_cfg)
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
loss_cw = focal_loss_cw(fake_pred, fake_target)
loss = focal_loss(fake_pred, fake_target)
weight = torch.tensor([1, 2, 3, 4]).view(1, 4, 1, 1)
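# With reduction='none', multiplying the unweighted loss map by the per-class
# weight (broadcast over N, H, W) must reproduce the class-weighted loss.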
assert (loss * weight == loss_cw).all()
# test ignore index
def test_ignore_index():
loss_cfg = dict(type='FocalLoss', reduction='none')
# ignore_index within C classes
focal_loss = build_loss(loss_cfg)
fake_pred = torch.rand(3, 5, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
dim1 = torch.randint(0, 3, (4, ))
dim2 = torch.randint(0, 5, (4, ))
dim3 = torch.randint(0, 6, (4, ))
fake_target[dim1, dim2, dim3] = 4
loss1 = focal_loss(fake_pred, fake_target, ignore_index=4)
one_hot_target = F.one_hot(fake_target, num_classes=5)
one_hot_target = one_hot_target.permute(0, 3, 1, 2)
loss2 = focal_loss(fake_pred, one_hot_target, ignore_index=4)
assert (loss1 == loss2).all()
assert (loss1[dim1, :, dim2, dim3] == 0).all()
assert (loss2[dim1, :, dim2, dim3] == 0).all()
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
loss1 = focal_loss(fake_pred, fake_target, ignore_index=2)
one_hot_target = F.one_hot(fake_target, num_classes=4)
one_hot_target = one_hot_target.permute(0, 3, 1, 2)
loss2 = focal_loss(fake_pred, one_hot_target, ignore_index=2)
ignore_mask = one_hot_target == 2
assert (loss1 == loss2).all()
assert torch.sum(loss1 * ignore_mask) == 0
assert torch.sum(loss2 * ignore_mask) == 0
# ignore index is not in prediction's classes
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
dim1 = torch.randint(0, 3, (4, ))
dim2 = torch.randint(0, 5, (4, ))
dim3 = torch.randint(0, 6, (4, ))
fake_target[dim1, dim2, dim3] = 255
loss1 = focal_loss(fake_pred, fake_target, ignore_index=255)
assert (loss1[dim1, :, dim2, dim3] == 0).all()
# test that a scalar alpha and an equivalent list alpha give the same loss
def test_alpha():
loss_cfg = dict(type='FocalLoss')
focal_loss = build_loss(loss_cfg)
alpha_float = 0.4
alpha = [0.4, 0.4, 0.4, 0.4]
alpha2 = [0.1, 0.3, 0.2, 0.1]
fake_pred = torch.rand(3, 4, 5, 6)
fake_target = torch.randint(0, 4, (3, 5, 6))
focal_loss.alpha = alpha_float
loss1 = focal_loss(fake_pred, fake_target)
focal_loss.alpha = alpha
loss2 = focal_loss(fake_pred, fake_target)
assert loss1 == loss2
focal_loss.alpha = alpha2
focal_loss(fake_pred, fake_target)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
def test_lovasz_loss():
from mmseg.models import build_loss
# loss_type should be 'binary' or 'multi_class'
with pytest.raises(AssertionError):
loss_cfg = dict(
type='LovaszLoss',
loss_type='Binary',
reduction='none',
loss_weight=1.0,
loss_name='loss_lovasz')
build_loss(loss_cfg)
# reduction should be 'none' when per_image is False.
with pytest.raises(AssertionError):
loss_cfg = dict(
type='LovaszLoss',
loss_type='multi_class',
loss_name='loss_lovasz')
build_loss(loss_cfg)
# test lovasz loss with loss_type = 'multi_class' and per_image = False
loss_cfg = dict(
type='LovaszLoss',
reduction='none',
loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg)
logits = torch.rand(1, 3, 4, 4)
labels = (torch.rand(1, 4, 4) * 2).long()
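# logits are (N, C, H, W) class scores and labels are (N, H, W) indices
# (0 or 1 here, which is valid for the 3-class setting).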
lovasz_loss(logits, labels)
# test lovasz loss with loss_type = 'multi_class' and per_image = True
loss_cfg = dict(
type='LovaszLoss',
per_image=True,
reduction='mean',
class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg)
logits = torch.rand(1, 3, 4, 4)
labels = (torch.rand(1, 4, 4) * 2).long()
lovasz_loss(logits, labels, ignore_index=None)
# test loss with class weights from file
import os
import tempfile
import mmcv
import numpy as np
tmp_file = tempfile.NamedTemporaryFile()
mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
loss_cfg = dict(
type='LovaszLoss',
per_image=True,
reduction='mean',
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg)
lovasz_loss(logits, labels, ignore_index=None)
np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
loss_cfg = dict(
type='LovaszLoss',
per_image=True,
reduction='mean',
class_weight=f'{tmp_file.name}.npy',
loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg)
lovasz_loss(logits, labels, ignore_index=None)
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')
# test lovasz loss with loss_type = 'binary' and per_image = False
loss_cfg = dict(
type='LovaszLoss',
loss_type='binary',
reduction='none',
loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg)
logits = torch.rand(2, 4, 4)
labels = (torch.rand(2, 4, 4)).long()
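# torch.rand lies in [0, 1), so .long() yields all-zero labels; this only
# smoke-tests the binary code path and its shape handling.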
lovasz_loss(logits, labels)
# test lovasz loss with loss_type = 'binary' and per_image = True
loss_cfg = dict(
type='LovaszLoss',
loss_type='binary',
per_image=True,
reduction='mean',
loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg)
logits = torch.rand(2, 4, 4)
labels = (torch.rand(2, 4, 4)).long()
lovasz_loss(logits, labels, ignore_index=None)
# test lovasz loss has name `loss_lovasz`
loss_cfg = dict(
type='LovaszLoss',
loss_type='binary',
per_image=True,
reduction='mean',
loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg)
assert lovasz_loss.loss_name == 'loss_lovasz'