Skip to content
Snippets Groups Projects
Commit da6bb2c8 authored by Rockey's avatar Rockey Committed by GitHub
Browse files

[Feature]: Add diff seeds to diff ranks and set torch seed in worker_init_fn (#1362)

parent 98b8ed37
No related branches found
No related tags found
No related merge requests found
......@@ -186,3 +186,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
......@@ -8,6 +8,7 @@ import warnings
import mmcv
import torch
import torch.distributed as dist
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash
......@@ -50,6 +51,10 @@ def parse_args():
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--diff_seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
......@@ -180,6 +185,7 @@ def main():
# set random seeds
seed = init_random_seed(args.seed)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment