From f6d329106426417bcf559a31fdfd820677bff41e Mon Sep 17 00:00:00 2001
From: "linfangjian.vendor" <linfangjian@pjlab.org.cn>
Date: Wed, 25 May 2022 01:58:37 +0000
Subject: [PATCH] [Refactor] Refactor RandomCrop and PhotoMetricDistortion

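Port RandomCrop and PhotoMetricDistortion to the mmcv 2.x transform
interface: both transforms now inherit from BaseTransform and implement
transform() instead of __call__(), and RandomCrop draws its crop box in
a method decorated with @cache_randomness so the randomness can be
shared across inputs.

RandomCrop no longer loops over 'seg_fields' itself. Each dataset
config instead wraps it in a TransformBroadcaster that routes both
'img' and 'gt_semantic_seg' through the same crop; with
share_random_params=True the cached crop box is reused, so the image
and its segmentation map stay aligned. crop_size now also accepts a
single integer for square crops, and unit tests for both transforms
are added under tests/test_datasets/test_pipelines.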
---
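Note for reviewers: below is a minimal sketch of the new usage pattern,
adapted from the unit test added in this patch. The random arrays and
the crop size are illustrative only; roughly, share_random_params=True
makes the wrapper record the value returned by the
@cache_randomness-decorated crop_bbox() on the first pass and replay it
on the second, so both fields receive the identical crop.

    import numpy as np
    from mmcv.transforms.wrappers import TransformBroadcaster

    from mmseg.datasets.pipelines import RandomCrop

    # Stand-ins for a loaded image and its segmentation map.
    img = np.random.randint(0, 256, (512, 1024, 3), dtype=np.uint8)
    seg = np.random.randint(0, 19, (512, 1024), dtype=np.uint8)
    results = dict(img=img, gt_semantic_seg=seg, img_shape=img.shape)

    pipeline = TransformBroadcaster(
        # The inner transform only knows about 'img'; the wrapper runs
        # it once per mapped source ('img', then 'gt_semantic_seg').
        transforms=[RandomCrop(crop_size=(256, 256))],
        mapping={
            'img': ['img', 'gt_semantic_seg'],
            'img_shape': [..., 'img_shape']
        },
        auto_remap=True,
        share_random_params=True)

    results = pipeline(results)
    assert results['img'].shape[:2] == (256, 256)
    assert results['gt_semantic_seg'].shape == (256, 256)
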
 configs/_base_/datasets/ade20k.py             |  15 +-
 configs/_base_/datasets/ade20k_640x640.py     |  15 +-
 configs/_base_/datasets/chase_db1.py          |  15 +-
 configs/_base_/datasets/cityscapes.py         |  15 +-
 .../_base_/datasets/cityscapes_1024x1024.py   |  15 +-
 configs/_base_/datasets/cityscapes_768x768.py |  15 +-
 configs/_base_/datasets/cityscapes_769x769.py |  15 +-
 configs/_base_/datasets/cityscapes_832x832.py |  15 +-
 configs/_base_/datasets/coco-stuff10k.py      |  15 +-
 configs/_base_/datasets/coco-stuff164k.py     |  15 +-
 configs/_base_/datasets/drive.py              |  15 +-
 configs/_base_/datasets/hrf.py                |  15 +-
 configs/_base_/datasets/isaid.py              |  15 +-
 configs/_base_/datasets/loveda.py             |  15 +-
 configs/_base_/datasets/pascal_context.py     |  15 +-
 configs/_base_/datasets/pascal_context_59.py  |  15 +-
 configs/_base_/datasets/pascal_voc12.py       |  15 +-
 configs/_base_/datasets/potsdam.py            |  15 +-
 configs/_base_/datasets/stare.py              |  15 +-
 configs/_base_/datasets/vaihingen.py          |  15 +-
 mmseg/datasets/pipelines/transforms.py        | 205 ++++++++++++++----
 .../test_pipelines/test_transforms.py         |  64 ++++++
 22 files changed, 501 insertions(+), 68 deletions(-)
 create mode 100644 tests/test_datasets/test_pipelines/test_transforms.py

diff --git a/configs/_base_/datasets/ade20k.py b/configs/_base_/datasets/ade20k.py
index efc8b4bb..c7742285 100644
--- a/configs/_base_/datasets/ade20k.py
+++ b/configs/_base_/datasets/ade20k.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/ade20k_640x640.py b/configs/_base_/datasets/ade20k_640x640.py
index 14a4bb09..3907f6fe 100644
--- a/configs/_base_/datasets/ade20k_640x640.py
+++ b/configs/_base_/datasets/ade20k_640x640.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/chase_db1.py b/configs/_base_/datasets/chase_db1.py
index 298594ea..f8eb4fdb 100644
--- a/configs/_base_/datasets/chase_db1.py
+++ b/configs/_base_/datasets/chase_db1.py
@@ -9,7 +9,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes.py b/configs/_base_/datasets/cityscapes.py
index f21867c6..4a645e1b 100644
--- a/configs/_base_/datasets/cityscapes.py
+++ b/configs/_base_/datasets/cityscapes.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes_1024x1024.py b/configs/_base_/datasets/cityscapes_1024x1024.py
index f98d9297..57d09289 100644
--- a/configs/_base_/datasets/cityscapes_1024x1024.py
+++ b/configs/_base_/datasets/cityscapes_1024x1024.py
@@ -6,7 +6,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes_768x768.py b/configs/_base_/datasets/cityscapes_768x768.py
index fde9d7c7..8735ef55 100644
--- a/configs/_base_/datasets/cityscapes_768x768.py
+++ b/configs/_base_/datasets/cityscapes_768x768.py
@@ -6,7 +6,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes_769x769.py b/configs/_base_/datasets/cityscapes_769x769.py
index 336c7b25..d04ac0a8 100644
--- a/configs/_base_/datasets/cityscapes_769x769.py
+++ b/configs/_base_/datasets/cityscapes_769x769.py
@@ -6,7 +6,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/cityscapes_832x832.py b/configs/_base_/datasets/cityscapes_832x832.py
index b9325cc0..d65c16f5 100644
--- a/configs/_base_/datasets/cityscapes_832x832.py
+++ b/configs/_base_/datasets/cityscapes_832x832.py
@@ -6,7 +6,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/coco-stuff10k.py b/configs/_base_/datasets/coco-stuff10k.py
index ec049692..ceec06dc 100644
--- a/configs/_base_/datasets/coco-stuff10k.py
+++ b/configs/_base_/datasets/coco-stuff10k.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/coco-stuff164k.py b/configs/_base_/datasets/coco-stuff164k.py
index a6a38f2a..29a33894 100644
--- a/configs/_base_/datasets/coco-stuff164k.py
+++ b/configs/_base_/datasets/coco-stuff164k.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/drive.py b/configs/_base_/datasets/drive.py
index 06e8ff60..6b00bc75 100644
--- a/configs/_base_/datasets/drive.py
+++ b/configs/_base_/datasets/drive.py
@@ -9,7 +9,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/hrf.py b/configs/_base_/datasets/hrf.py
index 242d790e..2c1ad741 100644
--- a/configs/_base_/datasets/hrf.py
+++ b/configs/_base_/datasets/hrf.py
@@ -9,7 +9,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/isaid.py b/configs/_base_/datasets/isaid.py
index 8e4c26ab..29e731fb 100644
--- a/configs/_base_/datasets/isaid.py
+++ b/configs/_base_/datasets/isaid.py
@@ -16,7 +16,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(896, 896), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/loveda.py b/configs/_base_/datasets/loveda.py
index e5533569..bcdc4f15 100644
--- a/configs/_base_/datasets/loveda.py
+++ b/configs/_base_/datasets/loveda.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/pascal_context.py b/configs/_base_/datasets/pascal_context.py
index ff65bad1..0f803a6b 100644
--- a/configs/_base_/datasets/pascal_context.py
+++ b/configs/_base_/datasets/pascal_context.py
@@ -11,7 +11,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/pascal_context_59.py b/configs/_base_/datasets/pascal_context_59.py
index 37585aba..4e1865f0 100644
--- a/configs/_base_/datasets/pascal_context_59.py
+++ b/configs/_base_/datasets/pascal_context_59.py
@@ -11,7 +11,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/pascal_voc12.py b/configs/_base_/datasets/pascal_voc12.py
index ba1d42d0..aa864390 100644
--- a/configs/_base_/datasets/pascal_voc12.py
+++ b/configs/_base_/datasets/pascal_voc12.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/potsdam.py b/configs/_base_/datasets/potsdam.py
index f74c4a56..0690578c 100644
--- a/configs/_base_/datasets/potsdam.py
+++ b/configs/_base_/datasets/potsdam.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/stare.py b/configs/_base_/datasets/stare.py
index 3f71b254..afba5127 100644
--- a/configs/_base_/datasets/stare.py
+++ b/configs/_base_/datasets/stare.py
@@ -9,7 +9,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations'),
     dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/_base_/datasets/vaihingen.py b/configs/_base_/datasets/vaihingen.py
index c0df282c..03987c00 100644
--- a/configs/_base_/datasets/vaihingen.py
+++ b/configs/_base_/datasets/vaihingen.py
@@ -8,7 +8,20 @@ train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadAnnotations', reduce_zero_label=True),
     dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
-    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(
+        type='TransformBroadcaster',
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True,
+        transforms=[
+            dict(
+                type='mmseg.RandomCrop',
+                crop_size=crop_size,
+                cat_max_ratio=0.75),
+        ]),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PhotoMetricDistortion'),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py
index e65f9857..bfb776b4 100644
--- a/mmseg/datasets/pipelines/transforms.py
+++ b/mmseg/datasets/pipelines/transforms.py
@@ -1,8 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+from typing import Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
+from mmcv.transforms.base import BaseTransform
+from mmcv.transforms.utils import cache_randomness
 from mmcv.utils import deprecated_api_warning, is_tuple_of
 from numpy import random
 
@@ -581,51 +584,79 @@ class CLAHE(object):
 
 
 @TRANSFORMS.register_module()
-class RandomCrop(object):
+class RandomCrop(BaseTransform):
     """Random crop the image & seg.
 
+    Required Keys:
+
+    - img
+    - gt_semantic_seg
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - gt_semantic_seg
+
+
     Args:
-        crop_size (tuple): Expected size after cropping, (h, w).
+        crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
+            with the format of (h, w). If set to an integer, the cropped
+            height and width are both equal to this integer.
         cat_max_ratio (float): The maximum ratio that single category could
             occupy.
+        ignore_index (int): The label index to be ignored. Default: 255
     """
 
-    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
+    def __init__(self,
+                 crop_size: Union[int, Tuple[int, int]],
+                 cat_max_ratio: float = 1.,
+                 ignore_index: int = 255):
+        super().__init__()
+        assert isinstance(crop_size, int) or (
+            isinstance(crop_size, tuple) and len(crop_size) == 2
+        ), ('The expected crop_size is an integer, or a tuple containing '
+            'two integers')
+
+        if isinstance(crop_size, int):
+            crop_size = (crop_size, crop_size)
         assert crop_size[0] > 0 and crop_size[1] > 0
         self.crop_size = crop_size
         self.cat_max_ratio = cat_max_ratio
         self.ignore_index = ignore_index
 
-    def get_crop_bbox(self, img):
-        """Randomly get a crop bounding box."""
-        margin_h = max(img.shape[0] - self.crop_size[0], 0)
-        margin_w = max(img.shape[1] - self.crop_size[1], 0)
-        offset_h = np.random.randint(0, margin_h + 1)
-        offset_w = np.random.randint(0, margin_w + 1)
-        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
-        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
-
-        return crop_y1, crop_y2, crop_x1, crop_x2
-
-    def crop(self, img, crop_bbox):
-        """Crop from ``img``"""
-        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
-        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
-        return img
-
-    def __call__(self, results):
-        """Call function to randomly crop images, semantic segmentation maps.
+    @cache_randomness
+    def crop_bbox(self, results: dict) -> tuple:
+        """get a crop bounding box.
 
         Args:
             results (dict): Result dict from loading pipeline.
 
         Returns:
-            dict: Randomly cropped results, 'img_shape' key in result dict is
-                updated according to crop size.
+            tuple: The crop bbox as (y1, y2, x1, x2).
         """
 
+        def generate_crop_bbox(img: np.ndarray) -> tuple:
+            """Randomly get a crop bounding box.
+
+            Args:
+                img (np.ndarray): Original input image.
+
+            Returns:
+                tuple: The crop bbox as (y1, y2, x1, x2).
+            """
+
+            margin_h = max(img.shape[0] - self.crop_size[0], 0)
+            margin_w = max(img.shape[1] - self.crop_size[1], 0)
+            offset_h = np.random.randint(0, margin_h + 1)
+            offset_w = np.random.randint(0, margin_w + 1)
+            crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
+            crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
+
+            return crop_y1, crop_y2, crop_x1, crop_x2
+
         img = results['img']
-        crop_bbox = self.get_crop_bbox(img)
+        crop_bbox = generate_crop_bbox(img)
         if self.cat_max_ratio < 1.:
             # Repeat 10 times
             for _ in range(10):
@@ -635,18 +666,45 @@ class RandomCrop(object):
                 if len(cnt) > 1 and np.max(cnt) / np.sum(
                         cnt) < self.cat_max_ratio:
                     break
-                crop_bbox = self.get_crop_bbox(img)
+                crop_bbox = generate_crop_bbox(img)
+
+        return crop_bbox
+
+    def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
+        """Crop from ``img``
+
+        Args:
+            img (np.ndarray): Original input image.
+            crop_bbox (tuple): The crop bbox as (y1, y2, x1, x2).
+
+        Returns:
+            np.ndarray: The cropped image.
+        """
+
+        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
+        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+        return img
+
+    def transform(self, results: dict) -> dict:
+        """Transform function to randomly crop images, semantic segmentation
+        maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Randomly cropped results, 'img_shape' key in result dict is
+                updated according to crop size.
+        """
+
+        img = results['img']
+        crop_bbox = self.crop_bbox(results)
 
         # crop the image
         img = self.crop(img, crop_bbox)
         img_shape = img.shape
         results['img'] = img
         results['img_shape'] = img_shape
-
-        # crop semantic seg
-        for key in results.get('seg_fields', []):
-            results[key] = self.crop(results[key], crop_bbox)
-
         return results
 
     def __repr__(self):
@@ -858,7 +916,7 @@ class SegRescale(object):
 
 
 @TRANSFORMS.register_module()
-class PhotoMetricDistortion(object):
+class PhotoMetricDistortion(BaseTransform):
     """Apply photometric distortion to image sequentially, every transformation
     is applied with a probability of 0.5. The position of random contrast is in
     second or second to last.
@@ -871,6 +929,14 @@ class PhotoMetricDistortion(object):
     6. convert color from HSV to BGR
     7. random contrast (mode 1)
 
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img
+
     Args:
         brightness_delta (int): delta of brightness.
         contrast_range (tuple): range of contrast.
@@ -879,23 +945,45 @@ class PhotoMetricDistortion(object):
     """
 
     def __init__(self,
-                 brightness_delta=32,
-                 contrast_range=(0.5, 1.5),
-                 saturation_range=(0.5, 1.5),
-                 hue_delta=18):
+                 brightness_delta: int = 32,
+                 contrast_range: Sequence[float] = (0.5, 1.5),
+                 saturation_range: Sequence[float] = (0.5, 1.5),
+                 hue_delta: int = 18):
         self.brightness_delta = brightness_delta
         self.contrast_lower, self.contrast_upper = contrast_range
         self.saturation_lower, self.saturation_upper = saturation_range
         self.hue_delta = hue_delta
 
-    def convert(self, img, alpha=1, beta=0):
-        """Multiple with alpha and add beat with clip."""
+    def convert(self,
+                img: np.ndarray,
+                alpha: float = 1,
+                beta: float = 0) -> np.ndarray:
+        """Multiple with alpha and add beat with clip.
+
+        Args:
+            img (np.ndarray): The input image.
+            alpha (float): Weight multiplied with the image; changes the
+                contrast/saturation of the image. Default: 1
+            beta (float): Bias added to the image; changes the brightness
+                of the image. Default: 0
+
+        Returns:
+            np.ndarray: The transformed image.
+        """
+
         img = img.astype(np.float32) * alpha + beta
         img = np.clip(img, 0, 255)
         return img.astype(np.uint8)
 
-    def brightness(self, img):
-        """Brightness distortion."""
+    def brightness(self, img: np.ndarray) -> np.ndarray:
+        """Brightness distortion.
+
+        Args:
+            img (np.ndarray): The input image.
+        Returns:
+            np.ndarray: Image after brightness change.
+        """
+
         if random.randint(2):
             return self.convert(
                 img,
@@ -903,16 +991,30 @@ class PhotoMetricDistortion(object):
                                     self.brightness_delta))
         return img
 
-    def contrast(self, img):
-        """Contrast distortion."""
+    def contrast(self, img: np.ndarray) -> np.ndarray:
+        """Contrast distortion.
+
+        Args:
+            img (np.ndarray): The input image.
+        Returns:
+            np.ndarray: Image after contrast change.
+        """
+
         if random.randint(2):
             return self.convert(
                 img,
                 alpha=random.uniform(self.contrast_lower, self.contrast_upper))
         return img
 
-    def saturation(self, img):
-        """Saturation distortion."""
+    def saturation(self, img: np.ndarray) -> np.ndarray:
+        """Saturation distortion.
+
+        Args:
+            img (np.ndarray): The input image.
+        Returns:
+            np.ndarray: Image after saturation change.
+        """
+
         if random.randint(2):
             img = mmcv.bgr2hsv(img)
             img[:, :, 1] = self.convert(
@@ -922,8 +1024,15 @@ class PhotoMetricDistortion(object):
             img = mmcv.hsv2bgr(img)
         return img
 
-    def hue(self, img):
-        """Hue distortion."""
+    def hue(self, img: np.ndarray) -> np.ndarray:
+        """Hue distortion.
+
+        Args:
+            img (np.ndarray): The input image.
+        Returns:
+            np.ndarray: Image after hue change.
+        """
+
         if random.randint(2):
             img = mmcv.bgr2hsv(img)
             img[:, :,
@@ -932,8 +1041,8 @@ class PhotoMetricDistortion(object):
             img = mmcv.hsv2bgr(img)
         return img
 
-    def __call__(self, results):
-        """Call function to perform photometric distortion on images.
+    def transform(self, results: dict) -> dict:
+        """Transform function to perform photometric distortion on images.
 
         Args:
             results (dict): Result dict from loading pipeline.
diff --git a/tests/test_datasets/test_pipelines/test_transforms.py b/tests/test_datasets/test_pipelines/test_transforms.py
new file mode 100644
index 00000000..03da1266
--- /dev/null
+++ b/tests/test_datasets/test_pipelines/test_transforms.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import mmcv
+import numpy as np
+import pytest
+from mmcv.transforms.wrappers import TransformBroadcaster
+from PIL import Image
+
+from mmseg.datasets.pipelines import PhotoMetricDistortion, RandomCrop
+
+
+def test_random_crop():
+    # test assertion for invalid random crop
+    with pytest.raises(AssertionError):
+        RandomCrop(crop_size=(-1, 0))
+
+    results = dict()
+    img = mmcv.imread(osp.join('tests/data/color.jpg'), 'color')
+    seg = np.array(Image.open(osp.join('tests/data/seg.png')))
+    results['img'] = img
+    results['gt_semantic_seg'] = seg
+    results['seg_fields'] = ['gt_semantic_seg']
+    results['img_shape'] = img.shape
+    results['ori_shape'] = img.shape
+    # Set initial values for default meta_keys
+    results['pad_shape'] = img.shape
+    results['scale_factor'] = 1.0
+
+    h, w, _ = img.shape
+    pipeline = TransformBroadcaster(
+        transforms=[RandomCrop(crop_size=(h - 20, w - 20))],
+        mapping={
+            'img': ['img', 'gt_semantic_seg'],
+            'img_shape': [..., 'img_shape']
+        },
+        auto_remap=True,
+        share_random_params=True)
+    results = pipeline(results)
+    assert results['img'].shape[:2] == (h - 20, w - 20)
+    assert results['img_shape'][:2] == (h - 20, w - 20)
+    assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)
+
+
+def test_photo_metric_distortion():
+
+    results = dict()
+    img = mmcv.imread(osp.join('tests/data/color.jpg'), 'color')
+    seg = np.array(Image.open(osp.join('tests/data/seg.png')))
+    results['img'] = img
+    results['gt_semantic_seg'] = seg
+    results['seg_fields'] = ['gt_semantic_seg']
+    results['img_shape'] = img.shape
+    results['ori_shape'] = img.shape
+    # Set initial values for default meta_keys
+    results['pad_shape'] = img.shape
+    results['scale_factor'] = 1.0
+
+    pipeline = PhotoMetricDistortion()
+    results = pipeline(results)
+
+    assert not ((results['img'] == img).all())
+    assert (results['gt_semantic_seg'] == seg).all()
+    assert results['img_shape'] == img.shape
-- 
GitLab