Skip to content
Snippets Groups Projects
Commit d33af521 authored by xiexinch's avatar xiexinch
Browse files

fix ut

parent d0b35cda
No related branches found
No related tags found
No related merge requests found
......@@ -48,9 +48,6 @@ class SegDataPreProcessor(BaseDataPreprocessor):
rgb_to_bgr (bool): whether to convert image from RGB to BGR.
Defaults to False.
batch_augments (list[dict], optional): Batch-level augmentations
train_cfg (dict, optional): The padding size config in training, if
not specify, will use `size` and `size_divisor` params as default.
Defaults to None, only supports keys `size` or `size_divisor`.
test_cfg (dict, optional): The padding size config in testing, if not
specify, will use `size` and `size_divisor` params as default.
Defaults to None, only supports keys `size` or `size_divisor`.
......@@ -67,7 +64,6 @@ class SegDataPreProcessor(BaseDataPreprocessor):
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
batch_augments: Optional[List[dict]] = None,
train_cfg: dict = None,
test_cfg: dict = None,
):
super().__init__()
......@@ -96,10 +92,8 @@ class SegDataPreProcessor(BaseDataPreprocessor):
# TODO: support batch augmentations.
self.batch_augments = batch_augments
# Support different padding methods in training and testing
default_size_cfg = dict(size=size, size_divisor=size_divisor)
self.train_cfg = train_cfg if train_cfg else default_size_cfg
self.test_cfg = test_cfg if test_cfg else default_size_cfg
# Support different padding methods in testing
self.test_cfg = test_cfg
def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
"""Perform normalization、padding and bgr2rgb conversion based on
......@@ -126,24 +120,31 @@ class SegDataPreProcessor(BaseDataPreprocessor):
if training:
assert data_samples is not None, ('During training, ',
'`data_samples` must be define.')
inputs, data_samples = stack_batch(
inputs=inputs,
data_samples=data_samples,
size=self.size,
size_divisor=self.size_divisor,
pad_val=self.pad_val,
seg_pad_val=self.seg_pad_val)
if self.batch_augments is not None:
inputs, data_samples = self.batch_augments(
inputs, data_samples)
else:
assert len(inputs) == 1, (
'Batch inference is not support currently, '
'as the image size might be different in a batch')
size_cfg = self.train_cfg if training else self.test_cfg
size = size_cfg.get('size', None)
size_divisor = size_cfg.get('size_divisor', None)
inputs, data_samples = stack_batch(
inputs=inputs,
data_samples=data_samples,
size=size,
size_divisor=size_divisor,
pad_val=self.pad_val,
seg_pad_val=self.seg_pad_val)
if self.batch_augments is not None:
inputs, data_samples = self.batch_augments(inputs, data_samples)
# pad images when testing
if self.test_cfg:
inputs, data_samples = stack_batch(
inputs=inputs,
data_samples=data_samples,
size=self.test_cfg.get('size', None),
size_divisor=self.test_cfg.get('size_divisor', None),
pad_val=self.pad_val,
seg_pad_val=self.seg_pad_val)
else:
inputs = torch.stack(inputs, dim=0)
return dict(inputs=inputs, data_samples=data_samples)
......@@ -192,7 +192,8 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
'pred_sem_seg':
PixelData(**{'data': i_seg_pred}),
'gt_sem_seg':
PixelData(**{'data': i_gt_sem_seg})
PixelData() if only_prediction else PixelData(
**{'data': i_gt_sem_seg})
})
return data_samples
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment