From bc27f2410900006b269d5305346646eed19d5069 Mon Sep 17 00:00:00 2001
From: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Date: Thu, 4 Nov 2021 01:36:09 +0800
Subject: [PATCH] =?UTF-8?q?[Fix]=20Fix=20the=20bug=20that=20vit=20cannot?=
 =?UTF-8?q?=20load=20pretrain=20properly=20when=20using=20i=E2=80=A6=20(#9?=
 =?UTF-8?q?99)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [Fix] Fix the bug that vit cannot load pretrain properly when using init_cfg to specify the pretrain scheme

* [Fix] fix the coverage problem

* Update mmseg/models/backbones/vit.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>

* [Fix] make the predicate more concise and clearer

* [Fix] Modified the judgement logic

* Update tests/test_models/test_backbones/test_vit.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>

* add comments

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
---
 mmseg/models/backbones/vit.py                | 22 ++++----
 tests/test_models/test_backbones/test_vit.py | 56 ++++++++++++++++++++
 2 files changed, 69 insertions(+), 9 deletions(-)

diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 5cd3ff24..f5afbb7f 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -170,7 +170,7 @@ class VisionTransformer(BaseModule):
                  with_cp=False,
                  pretrained=None,
                  init_cfg=None):
-        super(VisionTransformer, self).__init__()
+        super(VisionTransformer, self).__init__(init_cfg=init_cfg)
 
         if isinstance(img_size, int):
             img_size = to_2tuple(img_size)
@@ -185,10 +185,13 @@ class VisionTransformer(BaseModule):
             assert with_cls_token is True, f'with_cls_token must be True if' \
                 f'set output_cls_token to True, but got {with_cls_token}'
 
-        if isinstance(pretrained, str) or pretrained is None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
-        else:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')
 
         self.img_size = img_size
@@ -197,7 +200,6 @@ class VisionTransformer(BaseModule):
         self.norm_eval = norm_eval
         self.with_cp = with_cp
         self.pretrained = pretrained
-        self.init_cfg = init_cfg
 
         self.patch_embed = PatchEmbed(
             in_channels=in_channels,
@@ -260,10 +262,12 @@ class VisionTransformer(BaseModule):
         return getattr(self, self.norm1_name)
 
     def init_weights(self):
-        if isinstance(self.pretrained, str):
+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg.get('type') == 'Pretrained'):
             logger = get_root_logger()
             checkpoint = _load_checkpoint(
-                self.pretrained, logger=logger, map_location='cpu')
+                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
+
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
             else:
@@ -283,9 +287,9 @@ class VisionTransformer(BaseModule):
                         (pos_size, pos_size), self.interpolate_mode)
 
             self.load_state_dict(state_dict, False)
-
-        elif self.pretrained is None:
+        elif self.init_cfg is not None:
             super(VisionTransformer, self).init_weights()
+        else:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
             trunc_normal_init(self.pos_embed, std=.02)
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index 5dbb51e6..4ce860c0 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -118,3 +118,59 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[0][0].shape == (1, 768, 14, 14)
     assert feat[0][1].shape == (1, 768)
+
+
+def test_vit_init():
+    path = 'PATH_THAT_DO_NOT_EXIST'
+    # Test all combinations of pretrained and init_cfg
+    # pretrained=None, init_cfg=None
+    model = VisionTransformer(pretrained=None, init_cfg=None)
+    assert model.init_cfg is None
+    model.init_weights()
+
+    # pretrained=None
+    # init_cfg loads pretrain from a non-existent file
+    model = VisionTransformer(
+        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained=None
+    # init_cfg=123, whose type is unsupported
+    model = VisionTransformer(pretrained=None, init_cfg=123)
+    with pytest.raises(TypeError):
+        model.init_weights()
+
+    # pretrained loads pretrain from a non-existent file
+    # init_cfg=None
+    model = VisionTransformer(pretrained=path, init_cfg=None)
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained loads pretrain from a non-existent file
+    # init_cfg loads pretrain from a non-existent file
+    with pytest.raises(AssertionError):
+        model = VisionTransformer(
+            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
+    with pytest.raises(AssertionError):
+        model = VisionTransformer(pretrained=path, init_cfg=123)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=None
+    with pytest.raises(TypeError):
+        model = VisionTransformer(pretrained=123, init_cfg=None)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg loads pretrain from a non-existent file
+    with pytest.raises(AssertionError):
+        model = VisionTransformer(
+            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=123, whose type is unsupported
+    with pytest.raises(AssertionError):
+        model = VisionTransformer(pretrained=123, init_cfg=123)
-- 
GitLab