From 2a9bf2d21bb48ba5f7461f6e107f0a7846697ad6 Mon Sep 17 00:00:00 2001
From: sennnnn <58427300+sennnnn@users.noreply.github.com>
Date: Sun, 20 Jun 2021 06:53:13 +0800
Subject: [PATCH] [Fix] Fix some vit init bugs (#609)

* [Fix] Fix vit init bug

* Add some vit unit tests

* Modify module import

* Fix pretrain weights bug

* Modify pretrained judge

* Add some unit tests to improve code cov

* Optimize code

* Fix vit unit test
---
 mmseg/models/backbones/vit.py                | 69 ++++++++++++--------
 mmseg/models/utils/timm_convert.py           |  1 -
 tests/test_models/test_backbones/test_vit.py | 35 +++++++++-
 3 files changed, 76 insertions(+), 29 deletions(-)

diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 774f555c..a0b945bb 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -1,4 +1,5 @@
 import math
+import warnings
 
 import torch
 import torch.nn as nn
@@ -6,8 +7,7 @@ import torch.nn.functional as F
 from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                       kaiming_init, normal_init, trunc_normal_init)
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
-from mmcv.runner import _load_checkpoint
-from mmcv.runner.base_module import BaseModule, ModuleList
+from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
@@ -140,12 +140,6 @@ class PatchEmbed(BaseModule):
             self.norm = None
 
     def forward(self, x):
-        B, C, H, W = x.shape
-        # FIXME look at relaxing size constraints
-        # assert H == self.img_size[0] and W == self.img_size[1], \
-        #     f"Input image size ({H}*{W}) doesn't " \
-        #     f'match model ({self.img_size[0]}*{self.img_size[1]}).'
-        # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim
         x = self.projection(x).flatten(2).transpose(1, 2)
 
         if self.norm is not None:
@@ -185,8 +179,12 @@ class VisionTransformer(BaseModule):
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
             Defalut: dict(type='GELU').
-        final_norm (bool): Whether to add a additional layer to normalize
+        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
+            Default: False.
+        final_norm (bool): Whether to add an additional layer to normalize
             final feature map. Default: False.
+        out_shape (str): Select the output format of feature information.
+            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -198,6 +196,9 @@
             some memory while slowing down the training speed. Default: False.
         pretrain_style (str): Choose to use timm or mmcls pretrain weights.
             Default: timm.
+        pretrained (str, optional): model pretrained path. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
""" def __init__(self, @@ -216,12 +217,16 @@ class VisionTransformer(BaseModule): with_cls_token=True, norm_cfg=dict(type='LN'), act_cfg=dict(type='GELU'), + patch_norm=False, final_norm=False, + out_shape='NCHW', interpolate_mode='bicubic', num_fcs=2, norm_eval=False, with_cp=False, - pretrain_style='timm'): + pretrain_style='timm', + pretrained=None, + init_cfg=None): super(VisionTransformer, self).__init__() if isinstance(img_size, int): @@ -235,16 +240,32 @@ class VisionTransformer(BaseModule): assert pretrain_style in ['timm', 'mmcls'] - self.pretrain_style = pretrain_style + assert out_shape in ['NLC', + 'NCHW'], 'output shape must be "NLC" or "NCHW".' + + if isinstance(pretrained, str) or pretrained is None: + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + else: + raise TypeError('pretrained must be a str or None') + self.img_size = img_size self.patch_size = patch_size + self.out_shape = out_shape + self.interpolate_mode = interpolate_mode + self.norm_eval = norm_eval + self.with_cp = with_cp + self.pretrain_style = pretrain_style + self.pretrained = pretrained + self.init_cfg = init_cfg self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dims, - norm_cfg=norm_cfg) + norm_cfg=norm_cfg if patch_norm else None) + num_patches = self.patch_embed.num_patches self.with_cls_token = with_cls_token @@ -280,24 +301,20 @@ class VisionTransformer(BaseModule): norm_cfg=norm_cfg, batch_first=True)) - self.interpolate_mode = interpolate_mode self.final_norm = final_norm if final_norm: self.norm1_name, norm1 = build_norm_layer( norm_cfg, embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) - self.norm_eval = norm_eval - self.with_cp = with_cp - @property def norm1(self): return getattr(self, self.norm1_name) - def init_weights(self, pretrained=None): - if isinstance(pretrained, str): + def init_weights(self): + if isinstance(self.pretrained, str): logger = get_root_logger() - checkpoint = _load_checkpoint(pretrained, logger=logger) + checkpoint = _load_checkpoint(self.pretrained, logger=logger) if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] elif 'model' in checkpoint: @@ -325,7 +342,8 @@ class VisionTransformer(BaseModule): self.load_state_dict(state_dict, False) - elif pretrained is None: + elif self.pretrained is None: + super(VisionTransformer, self).init_weights() # We only implement the 'jax_impl' initialization implemented at # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 trunc_normal_init(self.pos_embed, std=.02) @@ -345,8 +363,6 @@ class VisionTransformer(BaseModule): elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): constant_init(m.bias, 0) constant_init(m.weight, 1.0) - else: - raise TypeError('pretrained must be a str or None') def _pos_embeding(self, img, patched_img, pos_embed): """Positiong embeding method. 
@@ -436,10 +452,11 @@ class VisionTransformer(BaseModule):
                     out = x[:, 1:]
                 else:
                     out = x
-                B, _, C = out.shape
-                out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                  inputs.shape[3] // self.patch_size,
-                                  C).permute(0, 3, 1, 2)
+                if self.out_shape == 'NCHW':
+                    B, _, C = out.shape
+                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                      inputs.shape[3] // self.patch_size,
+                                      C).permute(0, 3, 1, 2)
                 outs.append(out)
 
         return tuple(outs)
diff --git a/mmseg/models/utils/timm_convert.py b/mmseg/models/utils/timm_convert.py
index f9a4d311..2ce48b06 100644
--- a/mmseg/models/utils/timm_convert.py
+++ b/mmseg/models/utils/timm_convert.py
@@ -27,7 +27,6 @@ def vit_convert(timm_dict):
             new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
         else:
             new_k = k
-        new_k = f'backbone.{new_k}'
         mmseg_dict[new_k] = v
 
     return mmseg_dict
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index 452eee05..007781f2 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -24,21 +24,35 @@ def test_vit_backbone():
         x = torch.randn(1, 196)
         VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
 
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         # forward inputs must be [N, C, H, W]
         x = torch.randn(3, 30, 30)
         model = VisionTransformer()
        model(x)
 
     with pytest.raises(AssertionError):
+        # The length of the img_size tuple must be less than 3.
         VisionTransformer(img_size=(224, 224, 224))
 
+    with pytest.raises(TypeError):
+        # pretrained must be None or a str.
+        VisionTransformer(pretrained=123)
+
+    with pytest.raises(AssertionError):
+        # out_shape must be 'NLC' or 'NCHW'.
+        VisionTransformer(out_shape='NCL')
+
     # Test img_size isinstance tuple
     imgs = torch.randn(1, 3, 224, 224)
-    model = VisionTransformer(img_size=(224, 224))
+    model = VisionTransformer(img_size=(224, ))
     model.init_weights()
     model(imgs)
 
+    # Test img_size isinstance tuple of length 2
+    imgs = torch.randn(1, 3, 224, 224)
+    model = VisionTransformer(img_size=(224, 224))
+    model(imgs)
+
     # Test norm_eval = True
     model = VisionTransformer(norm_eval=True)
     model.train()
@@ -50,6 +64,11 @@
 
     assert check_norm_state(model.modules(), True)
 
+    # Test normal size input image
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
+
     # Test large size input image
     imgs = torch.randn(1, 3, 256, 256)
     feat = model(imgs)
@@ -81,8 +100,20 @@
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
 
+    # Test out_shape == 'NLC'
+    model = VisionTransformer(out_shape='NLC')
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 196, 768)
+
     # Test final norm
     model = VisionTransformer(final_norm=True)
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test patch norm
+    model = VisionTransformer(patch_norm=True)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
--
GitLab
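
Note: below is a minimal usage sketch of the backbone as patched above, based on the shapes asserted in the unit tests. The import path and the ViT-Base defaults (768-dim embeddings, 16x16 patches) are assumptions not spelled out in the diff, so treat it as an illustration rather than part of the patch.

import torch

# Assumed import path; requires mmsegmentation with this patch applied.
from mmseg.models.backbones import VisionTransformer

imgs = torch.randn(1, 3, 224, 224)

# Default out_shape='NCHW' reshapes the patch tokens back into a feature map.
model = VisionTransformer(img_size=(224, 224), patch_norm=True)
model.init_weights()  # pretrained is None here, so the jax-style random init runs
feat = model(imgs)
print(feat[-1].shape)  # expected per the tests: torch.Size([1, 768, 14, 14])

# out_shape='NLC' keeps the raw token sequence instead of a 2D feature map.
model_nlc = VisionTransformer(out_shape='NLC')
feat_nlc = model_nlc(imgs)
print(feat_nlc[-1].shape)  # expected per the tests: torch.Size([1, 196, 768])

# Pretrained weights are now passed via the constructor rather than to
# init_weights(), e.g. VisionTransformer(pretrained='<path to .pth>',
# pretrain_style='timm') followed by model.init_weights().

The super().init_weights() call added in the random-init branch defers to mmcv's BaseModule, which applies any init_cfg that was passed to the constructor.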