Commit 2a9bf2d2 authored by sennnnn, committed by GitHub

[Fix] Fix some vit init bugs (#609)

* [Fix] Fix vit init bug

* Add some vit unit tests

* Modify module import

* Fix pretrain weights bug

* Modify pretrained judge

* Add some unit tests to improve code cov

* Optimize code

* Fix vit unit test
parent 9249dbae
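The core change moves handling of the pretrained argument from init_weights(pretrained=...) into the constructor: a str or None is accepted (with a deprecation warning pointing at init_cfg), and anything else raises a TypeError at construction time. Below is a minimal standalone sketch of that check; the helper name is hypothetical, the real logic lives inline in VisionTransformer.__init__.

import warnings


def check_pretrained_arg(pretrained):
    # Hypothetical standalone helper (not part of mmseg) mirroring the check
    # this commit adds to VisionTransformer.__init__.
    if isinstance(pretrained, str) or pretrained is None:
        # A checkpoint path and None are both accepted, but users are nudged
        # towards the newer init_cfg mechanism.
        warnings.warn('DeprecationWarning: pretrained is deprecated, '
                      'please use "init_cfg" instead')
    else:
        # Anything else, e.g. an int, is rejected up front.
        raise TypeError('pretrained must be a str or None')


check_pretrained_arg('pretrain/vit_base_patch16_224.pth')  # warns, then accepted
check_pretrained_arg(None)                                 # also warns
# check_pretrained_arg(123)                                # would raise TypeError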
import math
import warnings
import torch
import torch.nn as nn
@@ -6,8 +7,7 @@ import torch.nn.functional as F
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init, normal_init, trunc_normal_init)
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.runner import _load_checkpoint
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
@@ -140,12 +140,6 @@ class PatchEmbed(BaseModule):
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't " \
# f'match model ({self.img_size[0]}*{self.img_size[1]}).'
# The output size is (B, N, D), where N = H*W/P/P and D is embed_dim
x = self.projection(x).flatten(2).transpose(1, 2)
if self.norm is not None:
@@ -185,8 +179,12 @@ class VisionTransformer(BaseModule):
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
final_norm (bool): Whether to add an additional layer to normalize
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
final_norm (bool): Whether to add an additional layer to normalize
the final feature map. Default: False.
out_shape (str): Select the output format of the feature map,
either 'NLC' or 'NCHW'. Default: NCHW.
interpolate_mode (str): Select the interpolate mode for resizing the
position embedding vector. Default: bicubic.
num_fcs (int): The number of fully-connected layers for FFNs.
@@ -198,6 +196,9 @@ class VisionTransformer(BaseModule):
some memory while slowing down the training speed. Default: False.
pretrain_style (str): Choose to use timm or mmcls pretrain weights.
Default: timm.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
@@ -216,12 +217,16 @@ class VisionTransformer(BaseModule):
with_cls_token=True,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
final_norm=False,
out_shape='NCHW',
interpolate_mode='bicubic',
num_fcs=2,
norm_eval=False,
with_cp=False,
pretrain_style='timm'):
pretrain_style='timm',
pretrained=None,
init_cfg=None):
super(VisionTransformer, self).__init__()
if isinstance(img_size, int):
@@ -235,16 +240,32 @@ class VisionTransformer(BaseModule):
assert pretrain_style in ['timm', 'mmcls']
self.pretrain_style = pretrain_style
assert out_shape in ['NLC',
'NCHW'], 'output shape must be "NLC" or "NCHW".'
if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
else:
raise TypeError('pretrained must be a str or None')
self.img_size = img_size
self.patch_size = patch_size
self.out_shape = out_shape
self.interpolate_mode = interpolate_mode
self.norm_eval = norm_eval
self.with_cp = with_cp
self.pretrain_style = pretrain_style
self.pretrained = pretrained
self.init_cfg = init_cfg
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=embed_dims,
norm_cfg=norm_cfg)
norm_cfg=norm_cfg if patch_norm else None)
num_patches = self.patch_embed.num_patches
self.with_cls_token = with_cls_token
@@ -280,24 +301,20 @@ class VisionTransformer(BaseModule):
norm_cfg=norm_cfg,
batch_first=True))
self.interpolate_mode = interpolate_mode
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.norm_eval = norm_eval
self.with_cp = with_cp
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
def init_weights(self):
if isinstance(self.pretrained, str):
logger = get_root_logger()
checkpoint = _load_checkpoint(pretrained, logger=logger)
checkpoint = _load_checkpoint(self.pretrained, logger=logger)
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
@@ -325,7 +342,8 @@ class VisionTransformer(BaseModule):
self.load_state_dict(state_dict, False)
elif pretrained is None:
elif self.pretrained is None:
super(VisionTransformer, self).init_weights()
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
trunc_normal_init(self.pos_embed, std=.02)
@@ -345,8 +363,6 @@ class VisionTransformer(BaseModule):
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
else:
raise TypeError('pretrained must be a str or None')
def _pos_embeding(self, img, patched_img, pos_embed):
"""Positiong embeding method.
@@ -436,10 +452,11 @@ class VisionTransformer(BaseModule):
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, inputs.shape[2] // self.patch_size,
inputs.shape[3] // self.patch_size,
C).permute(0, 3, 1, 2)
if self.out_shape == 'NCHW':
B, _, C = out.shape
out = out.reshape(B, inputs.shape[2] // self.patch_size,
inputs.shape[3] // self.patch_size,
C).permute(0, 3, 1, 2)
outs.append(out)
return tuple(outs)
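For reference, a usage sketch of the options this diff adds to the backbone (patch_norm, final_norm, out_shape), assuming an installed mmsegmentation that already contains this version of the class; the shapes follow the defaults (ViT-Base, patch size 16, 224x224 input) and match the unit tests further below.

import torch
from mmseg.models.backbones import VisionTransformer

imgs = torch.randn(1, 3, 224, 224)

# Default 'NCHW' output: the 14 * 14 = 196 patch tokens are reshaped back
# into a spatial map with 768 channels.
model = VisionTransformer(patch_norm=True, final_norm=True)
model.init_weights()
assert model(imgs)[-1].shape == (1, 768, 14, 14)

# 'NLC' output: the tokens are kept as a sequence instead of being reshaped.
model = VisionTransformer(out_shape='NLC')
assert model(imgs)[-1].shape == (1, 196, 768)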
@@ -27,7 +27,6 @@ def vit_convert(timm_dict):
new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
else:
new_k = k
new_k = f'backbone.{new_k}'
mmseg_dict[new_k] = v
return mmseg_dict
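The vit_convert hunk above renames timm checkpoint keys and, after this commit, no longer prefixes them with 'backbone.', so they load directly into the backbone's state_dict. A toy illustration of that remapping on a hypothetical dict; only the rename visible in the diff is reproduced.

# Toy input; keys are illustrative, strings stand in for weight tensors.
timm_dict = {
    'blocks.0.attn.proj.weight': 'tensor_a',
    'cls_token': 'tensor_b',
}

mmseg_dict = {}
for k, v in timm_dict.items():
    if 'attn.proj' in k:
        # timm's attention output projection maps onto the out_proj of the
        # nn.MultiheadAttention wrapped inside mmcv's MultiheadAttention.
        new_k = k.replace('attn.proj', 'attn.attn.out_proj')
    else:
        new_k = k
    mmseg_dict[new_k] = v  # no 'backbone.' prefix any more

print(sorted(mmseg_dict))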
@@ -24,21 +24,35 @@ def test_vit_backbone():
x = torch.randn(1, 196)
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
with pytest.raises(ValueError):
with pytest.raises(RuntimeError):
# forward inputs must be [N, C, H, W]
x = torch.randn(3, 30, 30)
model = VisionTransformer()
model(x)
with pytest.raises(AssertionError):
# The length of the img_size tuple must be at most 2.
VisionTransformer(img_size=(224, 224, 224))
with pytest.raises(TypeError):
# pretrained must be None or a str.
VisionTransformer(pretrained=123)
with pytest.raises(AssertionError):
# out_shape must be 'NLC' or 'NCHW'.
VisionTransformer(out_shape='NCL')
# Test img_size given as a tuple
imgs = torch.randn(1, 3, 224, 224)
model = VisionTransformer(img_size=(224, 224))
model = VisionTransformer(img_size=(224, ))
model.init_weights()
model(imgs)
# Test img_size given as a tuple
imgs = torch.randn(1, 3, 224, 224)
model = VisionTransformer(img_size=(224, 224))
model(imgs)
# Test norm_eval = True
model = VisionTransformer(norm_eval=True)
model.train()
@@ -50,6 +64,11 @@ def test_vit_backbone():
assert check_norm_state(model.modules(), True)
# Test normal size input image
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test large size input image
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
@@ -81,8 +100,20 @@ def test_vit_backbone():
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test out_shape == 'NLC'
model = VisionTransformer(out_shape='NLC')
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 196, 768)
# Test final norm
model = VisionTransformer(final_norm=True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test patch norm
model = VisionTransformer(patch_norm=True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)