From 84edf6c19029d5576e83929b9085172f85638696 Mon Sep 17 00:00:00 2001
From: Junjun2016 <hejunjun@sjtu.edu.cn>
Date: Thu, 30 Sep 2021 22:50:44 +0800
Subject: [PATCH] fix load ckpt bug in swin (#928)

---
 mmseg/models/backbones/swin.py | 4 +++-
 tools/train.py                 | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py
index 9133d8ce..59f4616c 100644
--- a/mmseg/models/backbones/swin.py
+++ b/mmseg/models/backbones/swin.py
@@ -680,7 +680,7 @@ class SwinTransformer(BaseModule):
                                                   f'`init_cfg` in ' \
                                                   f'{self.__class__.__name__} '
             ckpt = _load_checkpoint(
-                self.init_cfg.checkpoint, logger=logger, map_location='cpu')
+                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
             if 'state_dict' in ckpt:
                 _state_dict = ckpt['state_dict']
             elif 'model' in ckpt:
@@ -692,6 +692,8 @@ class SwinTransformer(BaseModule):
             for k, v in _state_dict.items():
                 if k.startswith('backbone.'):
                     state_dict[k[9:]] = v
+                else:
+                    state_dict[k] = v
 
             # strip prefix of state_dict
             if list(state_dict.keys())[0].startswith('module.'):
diff --git a/tools/train.py b/tools/train.py
index 05bd205c..208ca5ee 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -96,7 +96,7 @@ def main():
     else:
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)
-        # gpu_ids is used to calculate iter when resuming checkpoint,
+        # gpu_ids is used to calculate iter when resuming checkpoint
         _, world_size = get_dist_info()
         cfg.gpu_ids = range(world_size)
 
-- 
GitLab