From f15a21a30dac01053b84b5374690304f6dceb76d Mon Sep 17 00:00:00 2001
From: MengzhangLI <mcmong@pku.edu.cn>
Date: Mon, 28 Mar 2022 23:50:39 +0800
Subject: [PATCH] [Enhance] Support random seed for distributed sampler
 (#1411)

* support random seed for distributed sampler

* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py

* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py

* change dist sampler

* change dist sampler

* fix docstring in sync_random_seed
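
A minimal usage sketch (illustrative only, not part of the diff below; it
assumes an initialized torch.distributed process group and that `dataset`,
`world_size` and `rank` already exist):

    # The seed passed to build_dataloader is forwarded to
    # DistributedSampler(..., seed=seed), so every rank shuffles the
    # dataset indices in the same order.
    from mmseg.datasets import build_dataloader

    loader = build_dataloader(
        dataset,
        samples_per_gpu=2,
        workers_per_gpu=2,
        dist=True,
        shuffle=True,
        seed=42)

    # The sampler can also be built directly; sync_random_seed()
    # broadcasts the seed from rank 0 so it is identical on all ranks.
    from mmseg.datasets.samplers import DistributedSampler
    sampler = DistributedSampler(
        dataset, world_size, rank, shuffle=True, seed=42)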
---
 mmseg/core/utils/__init__.py                  |  6 +-
 mmseg/core/utils/dist_util.py                 | 46 ++++++++++++
 mmseg/datasets/builder.py                     |  6 +-
 mmseg/datasets/samplers/__init__.py           |  4 ++
 .../datasets/samplers/distributed_sampler.py  | 71 +++++++++++++++++++
 5 files changed, 130 insertions(+), 3 deletions(-)
 create mode 100644 mmseg/core/utils/dist_util.py
 create mode 100644 mmseg/datasets/samplers/__init__.py
 create mode 100644 mmseg/datasets/samplers/distributed_sampler.py

diff --git a/mmseg/core/utils/__init__.py b/mmseg/core/utils/__init__.py
index c8694b55..cb5a0c3f 100644
--- a/mmseg/core/utils/__init__.py
+++ b/mmseg/core/utils/__init__.py
@@ -1,6 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .dist_util import check_dist_init, sync_random_seed
 from .layer_decay_optimizer_constructor import \
     LearningRateDecayOptimizerConstructor
 from .misc import add_prefix
 
-__all__ = ['add_prefix', 'LearningRateDecayOptimizerConstructor']
+__all__ = [
+    'add_prefix', 'LearningRateDecayOptimizerConstructor', 'check_dist_init',
+    'sync_random_seed'
+]
diff --git a/mmseg/core/utils/dist_util.py b/mmseg/core/utils/dist_util.py
new file mode 100644
index 00000000..b3288519
--- /dev/null
+++ b/mmseg/core/utils/dist_util.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+
+def check_dist_init():
+    return dist.is_available() and dist.is_initialized()
+
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed. All workers must call
+    this function; otherwise it will deadlock. This method is generally used
+    in `DistributedSampler`, because the seed should be identical across all
+    processes in the distributed group.
+
+    In distributed sampling, different ranks should sample non-overlapped
+    data in the dataset. Therefore, this function is used to make sure that
+    each rank shuffles the data indices in the same order based
+    on the same seed. Then different ranks could use different indices
+    to select non-overlapped data from the same data list.
+
+    Args:
+        seed (int, optional): The seed to be synchronized. Defaults to None.
+        device (str): The device where the seed tensor is placed.
+            Defaults to 'cuda'.
+    Returns:
+        int: Seed to be used.
+    """
+
+    if seed is None:
+        seed = np.random.randint(2**31)
+    assert isinstance(seed, int)
+
+    rank, world_size = get_dist_info()
+
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py
index 21004e6e..4d852d36 100644
--- a/mmseg/datasets/builder.py
+++ b/mmseg/datasets/builder.py
@@ -9,7 +9,9 @@ import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg, digit_version
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
+
+from .samplers import DistributedSampler
 
 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -129,7 +131,7 @@ def build_dataloader(dataset,
     rank, world_size = get_dist_info()
     if dist:
         sampler = DistributedSampler(
-            dataset, world_size, rank, shuffle=shuffle)
+            dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
diff --git a/mmseg/datasets/samplers/__init__.py b/mmseg/datasets/samplers/__init__.py
new file mode 100644
index 00000000..da09effa
--- /dev/null
+++ b/mmseg/datasets/samplers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .distributed_sampler import DistributedSampler
+
+__all__ = ['DistributedSampler']
diff --git a/mmseg/datasets/samplers/distributed_sampler.py b/mmseg/datasets/samplers/distributed_sampler.py
new file mode 100644
index 00000000..d1a13c71
--- /dev/null
+++ b/mmseg/datasets/samplers/distributed_sampler.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+from typing import Iterator, Optional
+
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+from mmseg.core.utils import sync_random_seed
+
+
+class DistributedSampler(_DistributedSampler):
+    """DistributedSampler inheriting from
+    `torch.utils.data.DistributedSampler`.
+
+    Args:
+        dataset (Dataset): The dataset to be loaded.
+        num_replicas (int, optional): Number of processes participating in
+            distributed training. By default, world_size is retrieved from the
+            current distributed group.
+        rank (int, optional): Rank of the current process within num_replicas.
+            By default, rank is retrieved from the current distributed group.
+        shuffle (bool): If True (default), sampler will shuffle the indices.
+        seed (int): Random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Default: ``0``.
+    """
+
+    def __init__(self,
+                 dataset: Dataset,
+                 num_replicas: Optional[int] = None,
+                 rank: Optional[int] = None,
+                 shuffle: bool = True,
+                 seed: int = 0) -> None:
+        super().__init__(
+            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+
+        # In distributed sampling, different ranks should sample
+        # non-overlapped data in the dataset. Therefore, this function
+        # is used to make sure that each rank shuffles the data indices
+        # in the same order based on the same seed. Then different ranks
+        # could use different indices to select non-overlapped data from the
+        # same data list.
+        self.seed = sync_random_seed(seed)
+
+    def __iter__(self) -> Iterator:
+        """
+         Yields:
+            Iterator: iterator of indices for rank.
+        """
+        # deterministically shuffle based on epoch
+        if self.shuffle:
+            g = torch.Generator()
+            # When :attr:`shuffle=True`, this ensures all replicas
+            # use a different random ordering for each epoch.
+            # Otherwise, the next iteration of this sampler will
+            # yield the same ordering.
+            g.manual_seed(self.epoch + self.seed)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
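
For reference, a small sketch of how the strided subsample in
DistributedSampler.__iter__ splits indices across ranks (hypothetical
numbers, not part of the change):

    # 10 samples, 4 replicas -> num_samples = 3, total_size = 12.
    indices = list(range(10))
    indices += indices[:12 - len(indices)]  # pad: [0..9, 0, 1]
    rank0 = indices[0:12:4]                 # [0, 4, 8]
    rank1 = indices[1:12:4]                 # [1, 5, 9]
    # Every rank sees a disjoint slice of the same shuffled list.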
-- 
GitLab