From 84edf6c19029d5576e83929b9085172f85638696 Mon Sep 17 00:00:00 2001 From: Junjun2016 <hejunjun@sjtu.edu.cn> Date: Thu, 30 Sep 2021 22:50:44 +0800 Subject: [PATCH] fix load ckpt bug in swin (#928) --- mmseg/models/backbones/swin.py | 4 +++- tools/train.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py index 9133d8ce..59f4616c 100644 --- a/mmseg/models/backbones/swin.py +++ b/mmseg/models/backbones/swin.py @@ -680,7 +680,7 @@ class SwinTransformer(BaseModule): f'`init_cfg` in ' \ f'{self.__class__.__name__} ' ckpt = _load_checkpoint( - self.init_cfg.checkpoint, logger=logger, map_location='cpu') + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') if 'state_dict' in ckpt: _state_dict = ckpt['state_dict'] elif 'model' in ckpt: @@ -692,6 +692,8 @@ class SwinTransformer(BaseModule): for k, v in _state_dict.items(): if k.startswith('backbone.'): state_dict[k[9:]] = v + else: + state_dict[k] = v # strip prefix of state_dict if list(state_dict.keys())[0].startswith('module.'): diff --git a/tools/train.py b/tools/train.py index 05bd205c..208ca5ee 100644 --- a/tools/train.py +++ b/tools/train.py @@ -96,7 +96,7 @@ def main(): else: distributed = True init_dist(args.launcher, **cfg.dist_params) - # gpu_ids is used to calculate iter when resuming checkpoint, + # gpu_ids is used to calculate iter when resuming checkpoint _, world_size = get_dist_info() cfg.gpu_ids = range(world_size) -- GitLab