From 2a9bf2d21bb48ba5f7461f6e107f0a7846697ad6 Mon Sep 17 00:00:00 2001
From: sennnnn <58427300+sennnnn@users.noreply.github.com>
Date: Sun, 20 Jun 2021 06:53:13 +0800
Subject: [PATCH] [Fix] Fix some vit init bugs (#609)

* [Fix] Fix vit init bug

* Add some vit unit tests

* Modify module import

* Fix pretrain weights bug

* Modify pretrained check

* Add some unit tests to improve code cov

* Optimize code

* Fix vit unit test
---
 mmseg/models/backbones/vit.py                | 69 ++++++++++++--------
 mmseg/models/utils/timm_convert.py           |  1 -
 tests/test_models/test_backbones/test_vit.py | 35 +++++++++-
 3 files changed, 76 insertions(+), 29 deletions(-)
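
Usage sketch of the reworked interface (illustrative only: the checkpoint path
is a placeholder and the import assumes VisionTransformer is exported from
mmseg.models.backbones):

    import torch

    from mmseg.models.backbones import VisionTransformer

    # Random init: init_weights() no longer takes a `pretrained` argument.
    model = VisionTransformer(img_size=224, patch_norm=True, out_shape='NLC')
    model.init_weights()
    feats = model(torch.randn(1, 3, 224, 224))
    # With out_shape='NLC', feats[-1] keeps the token layout: (1, 196, 768).

    # Pretrained init: the path is now passed to the constructor instead of
    # init_weights(pretrained=...); './vit_checkpoint.pth' is hypothetical.
    model = VisionTransformer(pretrained='./vit_checkpoint.pth',
                              pretrain_style='timm')
    model.init_weights()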

diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 774f555c..a0b945bb 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -1,4 +1,5 @@
 import math
+import warnings
 
 import torch
 import torch.nn as nn
@@ -6,8 +7,7 @@ import torch.nn.functional as F
 from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                       kaiming_init, normal_init, trunc_normal_init)
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
-from mmcv.runner import _load_checkpoint
-from mmcv.runner.base_module import BaseModule, ModuleList
+from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
@@ -140,12 +140,6 @@ class PatchEmbed(BaseModule):
             self.norm = None
 
     def forward(self, x):
-        B, C, H, W = x.shape
-        # FIXME look at relaxing size constraints
-        # assert H == self.img_size[0] and W == self.img_size[1], \
-        #     f"Input image size ({H}*{W}) doesn't " \
-        #     f'match model ({self.img_size[0]}*{self.img_size[1]}).'
-        # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim
         x = self.projection(x).flatten(2).transpose(1, 2)
 
         if self.norm is not None:
@@ -185,8 +179,12 @@ class VisionTransformer(BaseModule):
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
             Defalut: dict(type='GELU').
-        final_norm (bool):  Whether to add a additional layer to normalize
+        patch_norm (bool): Whether to add a norm layer in the PatchEmbed
+            block. Default: False.
+        final_norm (bool): Whether to add an additional layer to normalize
             final feature map. Default: False.
+        out_shape (str): Select the output format of the features,
+            either 'NLC' or 'NCHW'. Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -198,6 +196,9 @@ class VisionTransformer(BaseModule):
             some memory while slowing down the training speed. Default: False.
         pretrain_style (str): Choose to use timm or mmcls pretrain weights.
             Default: timm.
+        pretrained (str, optional): Path to pre-trained model. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
     """
 
     def __init__(self,
@@ -216,12 +217,16 @@ class VisionTransformer(BaseModule):
                  with_cls_token=True,
                  norm_cfg=dict(type='LN'),
                  act_cfg=dict(type='GELU'),
+                 patch_norm=False,
                  final_norm=False,
+                 out_shape='NCHW',
                  interpolate_mode='bicubic',
                  num_fcs=2,
                  norm_eval=False,
                  with_cp=False,
-                 pretrain_style='timm'):
+                 pretrain_style='timm',
+                 pretrained=None,
+                 init_cfg=None):
         super(VisionTransformer, self).__init__()
 
         if isinstance(img_size, int):
@@ -235,16 +240,32 @@ class VisionTransformer(BaseModule):
 
         assert pretrain_style in ['timm', 'mmcls']
 
-        self.pretrain_style = pretrain_style
+        assert out_shape in ['NLC',
+                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+
+        if isinstance(pretrained, str) or pretrained is None:
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+        else:
+            raise TypeError('pretrained must be a str or None')
+
         self.img_size = img_size
         self.patch_size = patch_size
+        self.out_shape = out_shape
+        self.interpolate_mode = interpolate_mode
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+        self.pretrain_style = pretrain_style
+        self.pretrained = pretrained
+        self.init_cfg = init_cfg
 
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=embed_dims,
-            norm_cfg=norm_cfg)
+            norm_cfg=norm_cfg if patch_norm else None)
+
         num_patches = self.patch_embed.num_patches
 
         self.with_cls_token = with_cls_token
@@ -280,24 +301,20 @@ class VisionTransformer(BaseModule):
                     norm_cfg=norm_cfg,
                     batch_first=True))
 
-        self.interpolate_mode = interpolate_mode
         self.final_norm = final_norm
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
             self.add_module(self.norm1_name, norm1)
 
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
 
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
             logger = get_root_logger()
-            checkpoint = _load_checkpoint(pretrained, logger=logger)
+            checkpoint = _load_checkpoint(self.pretrained, logger=logger)
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
             elif 'model' in checkpoint:
@@ -325,7 +342,8 @@ class VisionTransformer(BaseModule):
 
             self.load_state_dict(state_dict, False)
 
-        elif pretrained is None:
+        elif self.pretrained is None:
+            super(VisionTransformer, self).init_weights()
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
             trunc_normal_init(self.pos_embed, std=.02)
@@ -345,8 +363,6 @@ class VisionTransformer(BaseModule):
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                     constant_init(m.bias, 0)
                     constant_init(m.weight, 1.0)
-        else:
-            raise TypeError('pretrained must be a str or None')
 
     def _pos_embeding(self, img, patched_img, pos_embed):
         """Positiong embeding method.
@@ -436,10 +452,11 @@ class VisionTransformer(BaseModule):
                     out = x[:, 1:]
                 else:
                     out = x
-                B, _, C = out.shape
-                out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                  inputs.shape[3] // self.patch_size,
-                                  C).permute(0, 3, 1, 2)
+                if self.out_shape == 'NCHW':
+                    B, _, C = out.shape
+                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                      inputs.shape[3] // self.patch_size,
+                                      C).permute(0, 3, 1, 2)
                 outs.append(out)
 
         return tuple(outs)
diff --git a/mmseg/models/utils/timm_convert.py b/mmseg/models/utils/timm_convert.py
index f9a4d311..2ce48b06 100644
--- a/mmseg/models/utils/timm_convert.py
+++ b/mmseg/models/utils/timm_convert.py
@@ -27,7 +27,6 @@ def vit_convert(timm_dict):
                 new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
         else:
             new_k = k
-        new_k = f'backbone.{new_k}'
         mmseg_dict[new_k] = v
 
     return mmseg_dict
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index 452eee05..007781f2 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -24,21 +24,35 @@ def test_vit_backbone():
         x = torch.randn(1, 196)
         VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
 
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         # forward inputs must be [N, C, H, W]
         x = torch.randn(3, 30, 30)
         model = VisionTransformer()
         model(x)
 
     with pytest.raises(AssertionError):
+        # The length of the img_size tuple must not exceed 2.
         VisionTransformer(img_size=(224, 224, 224))
 
+    with pytest.raises(TypeError):
+        # pretrained must be a str or None.
+        VisionTransformer(pretrained=123)
+
+    with pytest.raises(AssertionError):
+        # out_shape must be 'NLC' or 'NCHW'.
+        VisionTransformer(out_shape='NCL')
+
     # Test img_size isinstance tuple
     imgs = torch.randn(1, 3, 224, 224)
-    model = VisionTransformer(img_size=(224, 224))
+    model = VisionTransformer(img_size=(224, ))
     model.init_weights()
     model(imgs)
 
+    # Test img_size as a 2-tuple
+    imgs = torch.randn(1, 3, 224, 224)
+    model = VisionTransformer(img_size=(224, 224))
+    model(imgs)
+
     # Test norm_eval = True
     model = VisionTransformer(norm_eval=True)
     model.train()
@@ -50,6 +64,11 @@ def test_vit_backbone():
 
     assert check_norm_state(model.modules(), True)
 
+    # Test normal size input image
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
+
     # Test large size input image
     imgs = torch.randn(1, 3, 256, 256)
     feat = model(imgs)
@@ -81,8 +100,20 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
 
+    # Test out_shape == 'NLC'
+    model = VisionTransformer(out_shape='NLC')
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 196, 768)
+
     # Test final norm
     model = VisionTransformer(final_norm=True)
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test patch norm
+    model = VisionTransformer(patch_norm=True)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
-- 
GitLab
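
A related sketch of how timm weights flow into the backbone now that
vit_convert emits backbone-level keys (no 'backbone.' prefix). The file name
is a placeholder and the function is assumed importable from
mmseg.models.utils.timm_convert:

    import torch

    from mmseg.models.backbones import VisionTransformer
    from mmseg.models.utils.timm_convert import vit_convert

    # 'vit_timm_raw.pth' stands in for a state dict exported from timm.
    timm_dict = torch.load('vit_timm_raw.pth', map_location='cpu')
    state_dict = vit_convert(timm_dict)

    # The converted keys now match the backbone directly, so they load into
    # a bare VisionTransformer without stripping any prefix.
    model = VisionTransformer()
    model.load_state_dict(state_dict, strict=False)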