From f15a21a30dac01053b84b5374690304f6dceb76d Mon Sep 17 00:00:00 2001
From: MengzhangLI <mcmong@pku.edu.cn>
Date: Mon, 28 Mar 2022 23:50:39 +0800
Subject: [PATCH] [Enhance] Support random seed for distributed sampler (#1411)

* support random seed for distributed sampler

* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py

* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py

* change dist sampler

* change dist sampler

* fix docstring in sync_random_seed
---
 mmseg/core/utils/__init__.py                  |  6 +-
 mmseg/core/utils/dist_util.py                 | 46 ++++++++++++
 mmseg/datasets/builder.py                     |  6 +-
 mmseg/datasets/samplers/__init__.py           |  4 ++
 .../datasets/samplers/distributed_sampler.py  | 71 +++++++++++++++++++
 5 files changed, 130 insertions(+), 3 deletions(-)
 create mode 100644 mmseg/core/utils/dist_util.py
 create mode 100644 mmseg/datasets/samplers/__init__.py
 create mode 100644 mmseg/datasets/samplers/distributed_sampler.py

diff --git a/mmseg/core/utils/__init__.py b/mmseg/core/utils/__init__.py
index c8694b55..cb5a0c3f 100644
--- a/mmseg/core/utils/__init__.py
+++ b/mmseg/core/utils/__init__.py
@@ -1,6 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .dist_util import check_dist_init, sync_random_seed
 from .layer_decay_optimizer_constructor import \
     LearningRateDecayOptimizerConstructor
 from .misc import add_prefix
 
-__all__ = ['add_prefix', 'LearningRateDecayOptimizerConstructor']
+__all__ = [
+    'add_prefix', 'LearningRateDecayOptimizerConstructor', 'check_dist_init',
+    'sync_random_seed'
+]
diff --git a/mmseg/core/utils/dist_util.py b/mmseg/core/utils/dist_util.py
new file mode 100644
index 00000000..b3288519
--- /dev/null
+++ b/mmseg/core/utils/dist_util.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+
+def check_dist_init():
+    return dist.is_available() and dist.is_initialized()
+
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed. All workers must call
+    this function, otherwise it will deadlock. This method is generally used
+    in `DistributedSampler`, because the seed should be identical across all
+    processes in the distributed group.
+
+    In distributed sampling, different ranks should sample non-overlapped
+    data in the dataset. Therefore, this function is used to make sure that
+    each rank shuffles the data indices in the same order based
+    on the same seed. Then different ranks could use different indices
+    to select non-overlapped data from the same data list.
+
+    Args:
+        seed (int, Optional): The seed. Defaults to None.
+        device (str): The device where the seed will be put on.
+            Defaults to 'cuda'.
+    Returns:
+        int: Seed to be used.
+ """ + + if seed is None: + seed = np.random.randint(2**31) + assert isinstance(seed, int) + + rank, world_size = get_dist_info() + + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py index 21004e6e..4d852d36 100644 --- a/mmseg/datasets/builder.py +++ b/mmseg/datasets/builder.py @@ -9,7 +9,9 @@ import torch from mmcv.parallel import collate from mmcv.runner import get_dist_info from mmcv.utils import Registry, build_from_cfg, digit_version -from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data import DataLoader + +from .samplers import DistributedSampler if platform.system() != 'Windows': # https://github.com/pytorch/pytorch/issues/973 @@ -129,7 +131,7 @@ def build_dataloader(dataset, rank, world_size = get_dist_info() if dist: sampler = DistributedSampler( - dataset, world_size, rank, shuffle=shuffle) + dataset, world_size, rank, shuffle=shuffle, seed=seed) shuffle = False batch_size = samples_per_gpu num_workers = workers_per_gpu diff --git a/mmseg/datasets/samplers/__init__.py b/mmseg/datasets/samplers/__init__.py new file mode 100644 index 00000000..da09effa --- /dev/null +++ b/mmseg/datasets/samplers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .distributed_sampler import DistributedSampler + +__all__ = ['DistributedSampler'] diff --git a/mmseg/datasets/samplers/distributed_sampler.py b/mmseg/datasets/samplers/distributed_sampler.py new file mode 100644 index 00000000..d1a13c71 --- /dev/null +++ b/mmseg/datasets/samplers/distributed_sampler.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from __future__ import division +from typing import Iterator, Optional + +import torch +from torch.utils.data import Dataset +from torch.utils.data import DistributedSampler as _DistributedSampler + +from mmseg.core.utils import sync_random_seed + + +class DistributedSampler(_DistributedSampler): + """DistributedSampler inheriting from + `torch.utils.data.DistributedSampler`. + + Args: + datasets (Dataset): the dataset will be loaded. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, world_size is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within num_replicas. + By default, rank is retrieved from the current distributed group. + shuffle (bool): If True (default), sampler will shuffle the indices. + seed (int): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + """ + + def __init__(self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed=0) -> None: + super().__init__( + dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, this function + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. 
+        self.seed = sync_random_seed(seed)
+
+    def __iter__(self) -> Iterator:
+        """
+        Yields:
+            Iterator: iterator of indices for rank.
+        """
+        # deterministically shuffle based on epoch
+        if self.shuffle:
+            g = torch.Generator()
+            # When :attr:`shuffle=True`, this ensures all replicas
+            # use a different random ordering for each epoch.
+            # Otherwise, the next iteration of this sampler will
+            # yield the same ordering.
+            g.manual_seed(self.epoch + self.seed)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
-- 
GitLab
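
The sketch below is not part of the patch; it is a minimal, single-process way to check the behaviour the diff describes. It assumes the patch is applied and that mmseg, mmcv and torch are importable; ToyDataset is a hypothetical stand-in dataset. With no process group initialized, sync_random_seed() simply returns the seed passed in, so two sampler instances built with explicit num_replicas/rank emulate two ranks: sharing one seed, they shuffle identically per epoch and therefore draw disjoint, non-overlapping index subsets.

# Minimal sanity check, assuming the patch is applied and mmseg is importable.
# ToyDataset is a hypothetical stand-in; only its length matters here.
from torch.utils.data import Dataset

from mmseg.datasets.samplers import DistributedSampler


class ToyDataset(Dataset):

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx


dataset = ToyDataset()
# Emulate two ranks without initializing torch.distributed: with no process
# group, sync_random_seed() just returns the seed passed in (42 for both).
samplers = [
    DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=True, seed=42)
    for r in range(2)
]

for epoch in range(2):
    per_rank = []
    for sampler in samplers:
        sampler.set_epoch(epoch)  # shuffling order depends on epoch + seed
        per_rank.append(list(iter(sampler)))
    # Same seed -> same permutation on every "rank", so the interleaved
    # subsets are disjoint and together cover the whole dataset.
    assert set(per_rank[0]).isdisjoint(per_rank[1])
    assert set(per_rank[0]) | set(per_rank[1]) == set(range(len(dataset)))
    print(f'epoch {epoch}:', per_rank)

In a real distributed run, passing dist=True and a seed to build_dataloader forwards that seed to every rank's sampler (as in the builder.py hunk above), and sync_random_seed broadcasts rank 0's value so all ranks agree even when no explicit seed is configured.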