diff --git a/mmseg/core/__init__.py b/mmseg/core/__init__.py index 0f2fcf13c1a2cf30a48412d9efaabbbe7b227e40..39fa62423356dd6d351f5f3e38372cf989ebae74 100644 --- a/mmseg/core/__init__.py +++ b/mmseg/core/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .builder import build_optimizer, build_optimizer_constructor +from .data_structures import * # noqa: F401, F403 from .evaluation import * # noqa: F401, F403 from .optimizers import * # noqa: F401, F403 from .seg import * # noqa: F401, F403 diff --git a/mmseg/core/data_structures/__init__.py b/mmseg/core/data_structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b73e72190a42a8920000c2523483b1d8ed35c55d --- /dev/null +++ b/mmseg/core/data_structures/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .seg_data_sample import SegDataSample + +__all__ = ['SegDataSample'] diff --git a/mmseg/core/data_structures/seg_data_sample.py b/mmseg/core/data_structures/seg_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..23b60d5b6a25fce71728e4c5bf45eb5967af562c --- /dev/null +++ b/mmseg/core/data_structures/seg_data_sample.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.data import BaseDataElement, PixelData + + +class SegDataSample(BaseDataElement): + """A data structure interface of MMSegmentation. They are used as + interfaces between different components. + + The attributes in ``SegDataSample`` are divided into several parts: + + - ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation. + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic segmentation. + + Examples: + >>> import torch + >>> import numpy as np + >>> from mmengine.data import PixelData + >>> from mmseg.core import SegDataSample + + >>> data_sample = SegDataSample() + >>> img_meta = dict(img_shape=(4, 4, 3), + ... pad_shape=(4, 4, 3)) + >>> gt_segmentations = PixelData(metainfo=img_meta) + >>> gt_segmentations.gt_sem_seg = torch.randint(0, 2, (1, 4, 4)) + >>> data_sample.gt_segmentations = gt_segmentations + >>> assert 'img_shape' in data_sample.gt_segmentations.metainfo_keys() + >>> data_sample.gt_segmentations + (4, 4) + >>> print(data_sample) + <SegDataSample( + + META INFORMATION + + DATA FIELDS + gt_segmentations: <PixelData( + + META INFORMATION + img_shape: (4, 4, 3) + pad_shape: (4, 4, 3) + + DATA FIELDS + gt_sem_seg: tensor([[[1, 1, 1, 0], + [1, 0, 1, 1], + [1, 1, 1, 1], + [0, 1, 0, 1]]]) + ) at 0x1c2b4156460> + ) at 0x1c2aae44d60> + + >>> data_sample = SegDataSample() + >>> gt_sem_seg_data = dict(sem_seg=torch.rand(1, 4, 4)) + >>> gt_sem_seg = PixelData(**gt_sem_seg_data) + >>> data_sample.gt_sem_seg = gt_sem_seg + >>> assert 'gt_sem_seg' in data_sample + >>> assert 'sem_seg' in data_sample.gt_sem_seg + """ + + @property + def gt_sem_seg(self) -> PixelData: + return self._gt_sem_seg + + @gt_sem_seg.setter + def gt_sem_seg(self, value: PixelData) -> None: + self.set_field(value, '_gt_sem_seg', dtype=PixelData) + + @gt_sem_seg.deleter + def gt_sem_seg(self) -> None: + del self._gt_sem_seg + + @property + def pred_sem_seg(self) -> PixelData: + return self._pred_sem_seg + + @pred_sem_seg.setter + def pred_sem_seg(self, value: PixelData) -> None: + self.set_field(value, '_pred_sem_seg', dtype=PixelData) + + @pred_sem_seg.deleter + def pred_sem_seg(self) -> None: + del self._pred_sem_seg + + @property + def seg_logits(self) -> PixelData: + return self._seg_logits + + @seg_logits.setter + def seg_logits(self, value: PixelData) -> None: + self.set_field(value, '_seg_logits', dtype=PixelData) + + @seg_logits.deleter + def seg_logits(self) -> None: + del self._seg_logits diff --git a/tests/test_core/test_seg_data_sample.py b/tests/test_core/test_seg_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..c416e9ca679b66a4ed484cc42053f8fc7ef7b7c2 --- /dev/null +++ b/tests/test_core/test_seg_data_sample.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np +import pytest +import torch +from mmengine.data import PixelData + +from mmseg.core import SegDataSample + + +def _equal(a, b): + if isinstance(a, (torch.Tensor, np.ndarray)): + return (a == b).all() + else: + return a == b + + +class TestSegDataSample(TestCase): + + def test_init(self): + meta_info = dict( + img_size=[256, 256], + scale_factor=np.array([1.5, 1.5]), + img_shape=torch.rand(4)) + + seg_data_sample = SegDataSample(metainfo=meta_info) + assert 'img_size' in seg_data_sample + assert seg_data_sample.img_size == [256, 256] + assert seg_data_sample.get('img_size') == [256, 256] + + def test_setter(self): + seg_data_sample = SegDataSample() + + # test gt_sem_seg + gt_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2)) + gt_sem_seg = PixelData(**gt_sem_seg_data) + seg_data_sample.gt_sem_seg = gt_sem_seg + assert 'gt_sem_seg' in seg_data_sample + assert _equal(seg_data_sample.gt_sem_seg.sem_seg, + gt_sem_seg_data['sem_seg']) + + # test pred_sem_seg + pred_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2)) + pred_sem_seg = PixelData(**pred_sem_seg_data) + seg_data_sample.pred_sem_seg = pred_sem_seg + assert 'pred_sem_seg' in seg_data_sample + assert _equal(seg_data_sample.pred_sem_seg.sem_seg, + pred_sem_seg_data['sem_seg']) + + # test seg_logits + seg_logits_data = dict(sem_seg=torch.rand(5, 4, 2)) + seg_logits = PixelData(**seg_logits_data) + seg_data_sample.seg_logits = seg_logits + assert 'seg_logits' in seg_data_sample + assert _equal(seg_data_sample.seg_logits.sem_seg, + seg_logits_data['sem_seg']) + + # test type error + with pytest.raises(AssertionError): + seg_data_sample.gt_sem_seg = torch.rand(2, 4) + + with pytest.raises(AssertionError): + seg_data_sample.pred_sem_seg = torch.rand(2, 4) + + with pytest.raises(AssertionError): + seg_data_sample.seg_logits = torch.rand(2, 4) + + def test_deleter(self): + seg_data_sample = SegDataSample() + + pred_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2)) + pred_sem_seg = PixelData(**pred_sem_seg_data) + seg_data_sample.pred_sem_seg = pred_sem_seg + assert 'pred_sem_seg' in seg_data_sample + del seg_data_sample.pred_sem_seg + assert 'pred_sem_seg' not in seg_data_sample