From da6bb2c8c54552f8529e254a8f363589f0e96b15 Mon Sep 17 00:00:00 2001 From: Rockey <41846794+RockeyCoss@users.noreply.github.com> Date: Thu, 10 Mar 2022 09:41:16 +0800 Subject: [PATCH] [Feature]: Add diff seeds to diff ranks and set torch seed in worker_init_fn (#1362) --- mmseg/datasets/builder.py | 1 + tools/train.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py index 3529ab92..21004e6e 100644 --- a/mmseg/datasets/builder.py +++ b/mmseg/datasets/builder.py @@ -186,3 +186,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed): worker_seed = num_workers * rank + worker_id + seed np.random.seed(worker_seed) random.seed(worker_seed) + torch.manual_seed(worker_seed) diff --git a/tools/train.py b/tools/train.py index 1e1d01ac..6e7adc8d 100644 --- a/tools/train.py +++ b/tools/train.py @@ -8,6 +8,7 @@ import warnings import mmcv import torch +import torch.distributed as dist from mmcv.cnn.utils import revert_sync_batchnorm from mmcv.runner import get_dist_info, init_dist from mmcv.utils import Config, DictAction, get_git_hash @@ -50,6 +51,10 @@ def parse_args(): help='id of gpu to use ' '(only applicable to non-distributed training)') parser.add_argument('--seed', type=int, default=None, help='random seed') + parser.add_argument( + '--diff_seed', + action='store_true', + help='Whether or not set different seeds for different ranks') parser.add_argument( '--deterministic', action='store_true', @@ -180,6 +185,7 @@ def main(): # set random seeds seed = init_random_seed(args.seed) + seed = seed + dist.get_rank() if args.diff_seed else seed logger.info(f'Set random seed to {seed}, ' f'deterministic: {args.deterministic}') set_random_seed(seed, deterministic=args.deterministic) -- GitLab