From f6d329106426417bcf559a31fdfd820677bff41e Mon Sep 17 00:00:00 2001
From: "linfangjian.vendor" <linfangjian@pjlab.org.cn>
Date: Wed, 25 May 2022 01:58:37 +0000
Subject: [PATCH] [Refactor] Refactor RandomCrop and PhotoMetricDistortion

---
 configs/_base_/datasets/ade20k.py             |  15 +-
 configs/_base_/datasets/ade20k_640x640.py     |  15 +-
 configs/_base_/datasets/chase_db1.py          |  15 +-
 configs/_base_/datasets/cityscapes.py         |  15 +-
 .../_base_/datasets/cityscapes_1024x1024.py   |  15 +-
 configs/_base_/datasets/cityscapes_768x768.py |  15 +-
 configs/_base_/datasets/cityscapes_769x769.py |  15 +-
 configs/_base_/datasets/cityscapes_832x832.py |  15 +-
 configs/_base_/datasets/coco-stuff10k.py      |  15 +-
 configs/_base_/datasets/coco-stuff164k.py     |  15 +-
 configs/_base_/datasets/drive.py              |  15 +-
 configs/_base_/datasets/hrf.py                |  15 +-
 configs/_base_/datasets/isaid.py              |  15 +-
 configs/_base_/datasets/loveda.py             |  15 +-
 configs/_base_/datasets/pascal_context.py     |  15 +-
 configs/_base_/datasets/pascal_context_59.py  |  15 +-
 configs/_base_/datasets/pascal_voc12.py       |  15 +-
 configs/_base_/datasets/potsdam.py            |  15 +-
 configs/_base_/datasets/stare.py              |  15 +-
 configs/_base_/datasets/vaihingen.py          |  15 +-
 mmseg/datasets/pipelines/transforms.py        | 205 ++++++++++++++----
 .../test_pipelines/test_transforms.py         |  64 ++++++
 22 files changed, 501 insertions(+), 68 deletions(-)
 create mode 100644 tests/test_datasets/test_pipelines/test_transforms.py

diff --git a/configs/_base_/datasets/ade20k.py b/configs/_base_/datasets/ade20k.py
index efc8b4bb..c7742285 100644
--- a/configs/_base_/datasets/ade20k.py
+++ b/configs/_base_/datasets/ade20k.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/ade20k_640x640.py b/configs/_base_/datasets/ade20k_640x640.py
index 14a4bb09..3907f6fe 100644
--- a/configs/_base_/datasets/ade20k_640x640.py
+++ b/configs/_base_/datasets/ade20k_640x640.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/chase_db1.py b/configs/_base_/datasets/chase_db1.py
index 298594ea..f8eb4fdb 100644
--- a/configs/_base_/datasets/chase_db1.py
+++ b/configs/_base_/datasets/chase_db1.py
@@ -9,7 +9,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes.py b/configs/_base_/datasets/cityscapes.py
index f21867c6..4a645e1b 100644
--- a/configs/_base_/datasets/cityscapes.py
+++ b/configs/_base_/datasets/cityscapes.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes_1024x1024.py b/configs/_base_/datasets/cityscapes_1024x1024.py
index f98d9297..57d09289 100644
--- a/configs/_base_/datasets/cityscapes_1024x1024.py
+++ b/configs/_base_/datasets/cityscapes_1024x1024.py
@@ -6,7 +6,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes_768x768.py b/configs/_base_/datasets/cityscapes_768x768.py
index fde9d7c7..8735ef55 100644
--- a/configs/_base_/datasets/cityscapes_768x768.py
+++ b/configs/_base_/datasets/cityscapes_768x768.py
@@ -6,7 +6,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes_769x769.py b/configs/_base_/datasets/cityscapes_769x769.py
index 336c7b25..d04ac0a8 100644
--- a/configs/_base_/datasets/cityscapes_769x769.py
+++ b/configs/_base_/datasets/cityscapes_769x769.py
@@ -6,7 +6,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes_832x832.py b/configs/_base_/datasets/cityscapes_832x832.py
index b9325cc0..d65c16f5 100644
--- a/configs/_base_/datasets/cityscapes_832x832.py
+++ b/configs/_base_/datasets/cityscapes_832x832.py
@@ -6,7 +6,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/coco-stuff10k.py b/configs/_base_/datasets/coco-stuff10k.py
index ec049692..ceec06dc 100644
--- a/configs/_base_/datasets/coco-stuff10k.py
+++ b/configs/_base_/datasets/coco-stuff10k.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/coco-stuff164k.py b/configs/_base_/datasets/coco-stuff164k.py
index a6a38f2a..29a33894 100644
--- a/configs/_base_/datasets/coco-stuff164k.py
+++ b/configs/_base_/datasets/coco-stuff164k.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/drive.py b/configs/_base_/datasets/drive.py
index 06e8ff60..6b00bc75 100644
--- a/configs/_base_/datasets/drive.py
+++ b/configs/_base_/datasets/drive.py
@@ -9,7 +9,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/hrf.py b/configs/_base_/datasets/hrf.py
index 242d790e..2c1ad741 100644
--- a/configs/_base_/datasets/hrf.py
+++ b/configs/_base_/datasets/hrf.py
@@ -9,7 +9,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/isaid.py b/configs/_base_/datasets/isaid.py
index 8e4c26ab..29e731fb 100644
--- a/configs/_base_/datasets/isaid.py
+++ b/configs/_base_/datasets/isaid.py
@@ -16,7 +16,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(896, 896), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/loveda.py b/configs/_base_/datasets/loveda.py
index e5533569..bcdc4f15 100644
--- a/configs/_base_/datasets/loveda.py
+++ b/configs/_base_/datasets/loveda.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/pascal_context.py b/configs/_base_/datasets/pascal_context.py
index ff65bad1..0f803a6b 100644
--- a/configs/_base_/datasets/pascal_context.py
+++ b/configs/_base_/datasets/pascal_context.py
@@ -11,7 +11,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/pascal_context_59.py b/configs/_base_/datasets/pascal_context_59.py
index 37585aba..4e1865f0 100644
--- a/configs/_base_/datasets/pascal_context_59.py
+++ b/configs/_base_/datasets/pascal_context_59.py
@@ -11,7 +11,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/pascal_voc12.py b/configs/_base_/datasets/pascal_voc12.py
index ba1d42d0..aa864390 100644
--- a/configs/_base_/datasets/pascal_voc12.py
+++ b/configs/_base_/datasets/pascal_voc12.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/potsdam.py b/configs/_base_/datasets/potsdam.py
index f74c4a56..0690578c 100644
--- a/configs/_base_/datasets/potsdam.py
+++ b/configs/_base_/datasets/potsdam.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/stare.py b/configs/_base_/datasets/stare.py
index 3f71b254..afba5127 100644
--- a/configs/_base_/datasets/stare.py
+++ b/configs/_base_/datasets/stare.py
@@ -9,7 +9,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/vaihingen.py b/configs/_base_/datasets/vaihingen.py
index c0df282c..03987c00 100644
--- a/configs/_base_/datasets/vaihingen.py
+++ b/configs/_base_/datasets/vaihingen.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py
index e65f9857..bfb776b4 100644
--- a/mmseg/datasets/pipelines/transforms.py
+++ b/mmseg/datasets/pipelines/transforms.py
@@ -1,8 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+from typing import Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
+from mmcv.transforms.base import BaseTransform
+from mmcv.transforms.utils import cache_randomness
 from mmcv.utils import deprecated_api_warning, is_tuple_of
 from numpy import random
 
@@ -581,51 +584,79 @@ class CLAHE(object):
 
 
 @TRANSFORMS.register_module()
-class RandomCrop(object):
+class RandomCrop(BaseTransform):
     """Random crop the image & seg.
 
+    Required Keys:
+
+    - img
+    - gt_semantic_seg
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - gt_semantic_seg
+
     Args:
-        crop_size (tuple): Expected size after cropping, (h, w).
+        crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
+            with the format of (h, w). If set to an integer, the cropping
+            width and height are both equal to this integer.
         cat_max_ratio (float): The maximum ratio that single category could
            occupy.
+        ignore_index (int): The label index to be ignored. Default: 255
     """
 
-    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
+    def __init__(self,
+                 crop_size: Union[int, Tuple[int, int]],
+                 cat_max_ratio: float = 1.,
+                 ignore_index: int = 255):
+        super().__init__()
+        assert isinstance(crop_size, int) or (
+            isinstance(crop_size, tuple) and len(crop_size) == 2
+        ), ('The expected crop_size is an integer, or a tuple containing '
+            'two integers')
+
+        if isinstance(crop_size, int):
+            crop_size = (crop_size, crop_size)
         assert crop_size[0] > 0 and crop_size[1] > 0
         self.crop_size = crop_size
         self.cat_max_ratio = cat_max_ratio
         self.ignore_index = ignore_index
 
-    def get_crop_bbox(self, img):
-        """Randomly get a crop bounding box."""
-        margin_h = max(img.shape[0] - self.crop_size[0], 0)
-        margin_w = max(img.shape[1] - self.crop_size[1], 0)
-        offset_h = np.random.randint(0, margin_h + 1)
-        offset_w = np.random.randint(0, margin_w + 1)
-        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
-        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
-
-        return crop_y1, crop_y2, crop_x1, crop_x2
-
-    def crop(self, img, crop_bbox):
-        """Crop from ``img``"""
-        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
-        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
-        return img
-
-    def __call__(self, results):
-        """Call function to randomly crop images, semantic segmentation maps.
+    @cache_randomness
+    def crop_bbox(self, results: dict) -> tuple:
+        """Get a crop bounding box.
 
         Args:
             results (dict): Result dict from loading pipeline.
 
         Returns:
-            dict: Randomly cropped results, 'img_shape' key in result dict is
-                updated according to crop size.
+            tuple: Coordinates of the cropped image.
         """
 
+        def generate_crop_bbox(img: np.ndarray) -> tuple:
+            """Randomly get a crop bounding box.
+
+            Args:
+                img (np.ndarray): Original input image.
+
+            Returns:
+                tuple: Coordinates of the cropped image.
+            """
+
+            margin_h = max(img.shape[0] - self.crop_size[0], 0)
+            margin_w = max(img.shape[1] - self.crop_size[1], 0)
+            offset_h = np.random.randint(0, margin_h + 1)
+            offset_w = np.random.randint(0, margin_w + 1)
+            crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
+            crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
+
+            return crop_y1, crop_y2, crop_x1, crop_x2
+
         img = results['img']
-        crop_bbox = self.get_crop_bbox(img)
+        crop_bbox = generate_crop_bbox(img)
         if self.cat_max_ratio < 1.:
             # Repeat 10 times
             for _ in range(10):
@@ -635,18 +666,45 @@ class RandomCrop(object):
                 if len(cnt) > 1 and np.max(cnt) / np.sum(
                         cnt) < self.cat_max_ratio:
                     break
-                crop_bbox = self.get_crop_bbox(img)
+                crop_bbox = generate_crop_bbox(img)
+
+        return crop_bbox
+
+    def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
+        """Crop from ``img``
+
+        Args:
+            img (np.ndarray): Original input image.
+            crop_bbox (tuple): Coordinates of the cropped image.
+
+        Returns:
+            np.ndarray: The cropped image.
+        """
+
+        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
+        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+        return img
+
+    def transform(self, results: dict) -> dict:
+        """Transform function to randomly crop images, semantic segmentation
+        maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Randomly cropped results, 'img_shape' key in result dict is
+                updated according to crop size.
+        """
+
+        img = results['img']
+        crop_bbox = self.crop_bbox(results)
 
         # crop the image
         img = self.crop(img, crop_bbox)
         img_shape = img.shape
         results['img'] = img
         results['img_shape'] = img_shape
-
-        # crop semantic seg
-        for key in results.get('seg_fields', []):
-            results[key] = self.crop(results[key], crop_bbox)
-
         return results
 
     def __repr__(self):
@@ -858,7 +916,7 @@ class SegRescale(object):
 
 
 @TRANSFORMS.register_module()
-class PhotoMetricDistortion(object):
+class PhotoMetricDistortion(BaseTransform):
     """Apply photometric distortion to image sequentially, every
     transformation is applied with a probability of 0.5. The position of
     random contrast is in second or second to last.
@@ -871,6 +929,14 @@ class PhotoMetricDistortion(object):
 
     1. random brightness
     2. random contrast (mode 0)
    3. convert color from BGR to HSV
     4. random saturation
     5. random hue
     6. convert color from HSV to BGR
     7. random contrast (mode 1)
 
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img
+
     Args:
         brightness_delta (int): delta of brightness.
         contrast_range (tuple): range of contrast.
@@ -879,23 +945,45 @@ class PhotoMetricDistortion(object):
         saturation_range (tuple): range of saturation.
         hue_delta (int): delta of hue.
     """
 
     def __init__(self,
-                 brightness_delta=32,
-                 contrast_range=(0.5, 1.5),
-                 saturation_range=(0.5, 1.5),
-                 hue_delta=18):
+                 brightness_delta: int = 32,
+                 contrast_range: Sequence[float] = (0.5, 1.5),
+                 saturation_range: Sequence[float] = (0.5, 1.5),
+                 hue_delta: int = 18):
         self.brightness_delta = brightness_delta
         self.contrast_lower, self.contrast_upper = contrast_range
         self.saturation_lower, self.saturation_upper = saturation_range
         self.hue_delta = hue_delta
 
-    def convert(self, img, alpha=1, beta=0):
-        """Multiple with alpha and add beat with clip."""
+    def convert(self,
+                img: np.ndarray,
+                alpha: int = 1,
+                beta: int = 0) -> np.ndarray:
+        """Multiply with alpha and add beta with clip.
+
+        Args:
+            img (np.ndarray): The input image.
+            alpha (int): Image weights, change the contrast/saturation
+                of the image. Default: 1
+            beta (int): Image bias, change the brightness of the
+                image. Default: 0
+
+        Returns:
+            np.ndarray: The transformed image.
+        """
+
         img = img.astype(np.float32) * alpha + beta
         img = np.clip(img, 0, 255)
         return img.astype(np.uint8)
 
-    def brightness(self, img):
-        """Brightness distortion."""
+    def brightness(self, img: np.ndarray) -> np.ndarray:
+        """Brightness distortion.
+
+        Args:
+            img (np.ndarray): The input image.
+        Returns:
+            np.ndarray: Image after brightness change.
+        """
+
         if random.randint(2):
             return self.convert(
                 img,
@@ -903,16 +991,30 @@ class PhotoMetricDistortion(object):
                                     self.brightness_delta))
         return img
 
-    def contrast(self, img):
-        """Contrast distortion."""
+    def contrast(self, img: np.ndarray) -> np.ndarray:
+        """Contrast distortion.
+
+        Args:
+            img (np.ndarray): The input image.
+        Returns:
+            np.ndarray: Image after contrast change.
+        """
+
         if random.randint(2):
             return self.convert(
                 img,
                 alpha=random.uniform(self.contrast_lower,
                                      self.contrast_upper))
         return img
 
-    def saturation(self, img):
-        """Saturation distortion."""
+    def saturation(self, img: np.ndarray) -> np.ndarray:
+        """Saturation distortion.
+
+        Args:
+            img (np.ndarray): The input image.
+        Returns:
+            np.ndarray: Image after saturation change.
+        """
+
         if random.randint(2):
             img = mmcv.bgr2hsv(img)
             img[:, :, 1] = self.convert(
@@ -922,8 +1024,15 @@ class PhotoMetricDistortion(object):
             img = mmcv.hsv2bgr(img)
         return img
 
-    def hue(self, img):
-        """Hue distortion."""
+    def hue(self, img: np.ndarray) -> np.ndarray:
+        """Hue distortion.
+
+        Args:
+            img (np.ndarray): The input image.
+        Returns:
+            np.ndarray: Image after hue change.
+        """
+
         if random.randint(2):
             img = mmcv.bgr2hsv(img)
             img[:, :,
@@ -932,8 +1041,8 @@ class PhotoMetricDistortion(object):
             img = mmcv.hsv2bgr(img)
         return img
 
-    def __call__(self, results):
-        """Call function to perform photometric distortion on images.
+    def transform(self, results: dict) -> dict:
+        """Transform function to perform photometric distortion on images.
 
         Args:
             results (dict): Result dict from loading pipeline.
diff --git a/tests/test_datasets/test_pipelines/test_transforms.py b/tests/test_datasets/test_pipelines/test_transforms.py
new file mode 100644
index 00000000..03da1266
--- /dev/null
+++ b/tests/test_datasets/test_pipelines/test_transforms.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import mmcv
+import numpy as np
+import pytest
+from mmcv.transforms.wrappers import TransformBroadcaster
+from PIL import Image
+
+from mmseg.datasets.pipelines import PhotoMetricDistortion, RandomCrop
+
+
+def test_random_crop():
+    # test assertion for invalid random crop
+    with pytest.raises(AssertionError):
+        RandomCrop(crop_size=(-1, 0))
+
+    results = dict()
+    img = mmcv.imread(osp.join('tests/data/color.jpg'), 'color')
+    seg = np.array(Image.open(osp.join('tests/data/seg.png')))
+    results['img'] = img
+    results['gt_semantic_seg'] = seg
+    results['seg_fields'] = ['gt_semantic_seg']
+    results['img_shape'] = img.shape
+    results['ori_shape'] = img.shape
+    # Set initial values for default meta_keys
+    results['pad_shape'] = img.shape
+    results['scale_factor'] = 1.0
+
+    h, w, _ = img.shape
+    pipeline = TransformBroadcaster(
+        transforms=[RandomCrop(crop_size=(h - 20, w - 20))],
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True)
+    results = pipeline(results)
+    assert results['img'].shape[:2] == (h - 20, w - 20)
+    assert results['img_shape'][:2] == (h - 20, w - 20)
+    assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)
+
+
+def test_photo_metric_distortion():
+
+    results = dict()
+    img = mmcv.imread(osp.join('tests/data/color.jpg'), 'color')
+    seg = np.array(Image.open(osp.join('tests/data/seg.png')))
+    results['img'] = img
+    results['gt_semantic_seg'] = seg
+    results['seg_fields'] = ['gt_semantic_seg']
+    results['img_shape'] = img.shape
+    results['ori_shape'] = img.shape
+    # Set initial values for default meta_keys
+    results['pad_shape'] = img.shape
+    results['scale_factor'] = 1.0
+
+    pipeline = PhotoMetricDistortion()
+    results = pipeline(results)
+
+    assert not ((results['img'] == img).all())
+    assert (results['gt_semantic_seg'] == seg).all()
+    assert results['img_shape'] == img.shape
--
GitLab
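Usage note (editor's addition, not part of the patch): after this refactor, RandomCrop only touches 'img' and 'img_shape'; the segmentation map is cropped by wrapping the transform in mmcv's TransformBroadcaster, which temporarily maps 'gt_semantic_seg' into the 'img' slot and, because crop_bbox is decorated with @cache_randomness and share_random_params=True is set, replays the exact same crop on both passes. The sketch below mirrors the configs and the new test; it is a minimal sketch assuming mmcv >= 2.0.0rc0 and the mmseg branch this patch targets, and the input arrays are synthetic stand-ins for a loaded image and its annotation.

import numpy as np
from mmcv.transforms.wrappers import TransformBroadcaster

from mmseg.datasets.pipelines import RandomCrop

# Synthetic inputs standing in for LoadImageFromFile / LoadAnnotations output.
results = dict(
    img=np.random.randint(0, 255, (512, 1024, 3), dtype=np.uint8),
    gt_semantic_seg=np.random.randint(0, 19, (512, 1024), dtype=np.uint8))
results['img_shape'] = results['img'].shape

# The broadcaster applies RandomCrop twice: once to 'img' (updating
# 'img_shape') and once with 'gt_semantic_seg' mapped into the 'img' slot.
# share_random_params=True makes the cached crop bbox identical across both
# passes, so image and label stay aligned.
crop = TransformBroadcaster(
    transforms=[RandomCrop(crop_size=(384, 384), cat_max_ratio=0.75)],
    mapping={
        'img': ['img', 'gt_semantic_seg'],
        'img_shape': [..., 'img_shape']
    },
    auto_remap=True,
    share_random_params=True)

results = crop(results)
assert results['img'].shape[:2] == (384, 384)
assert results['gt_semantic_seg'].shape[:2] == (384, 384)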