diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index b140700a9b8b1e5b816abdafc2932c997b3b9a91..774f555c49841e0562ae1d70f5ead99fd87f2e45 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -1,294 +1,257 @@
-"""Modified from https://github.com/rwightman/pytorch-image-
-models/blob/master/timm/models/vision_transformer.py."""
-
 import math
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as cp
-from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
-                      constant_init, kaiming_init, normal_init)
-from mmcv.runner import BaseModule, _load_checkpoint
-from mmcv.utils.parrots_wrapper import _BatchNorm
+from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+                      kaiming_init, normal_init, trunc_normal_init)
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.runner import _load_checkpoint
+from mmcv.runner.base_module import BaseModule, ModuleList
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.utils import _pair as to_2tuple
 
 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
-from ..utils import DropPath, trunc_normal_
+from ..utils import vit_convert
 
 
-class Mlp(nn.Module):
-    """MLP layer for Encoder block.
+class TransformerEncoderLayer(BaseModule):
+    """Implements one encoder layer in Vision Transformer.
 
     Args:
-        in_features(int): Input dimension for the first fully
-            connected layer.
-        hidden_features(int): Output dimension for the first fully
-            connected layer.
-        out_features(int): Output dementsion for the second fully
-            connected layer.
-        act_cfg(dict): Config dict for activation layer.
-            Default: dict(type='GELU').
-        drop(float): Drop rate for the dropout layer. Dropout rate has
-            to be between 0 and 1. Default: 0.
+        embed_dims (int): The feature dimension
+        num_heads (int): Parallel attention heads
+        feedforward_channels (int): The hidden dimension for FFNs
+        drop_rate (float): Probability of an element to be zeroed
+            after the feed forward layer. Default 0.0
+        attn_drop_rate (float): The drop out rate for attention layer.
+            Default 0.0
+        drop_path_rate (float): stochastic depth rate. Default 0.0.
+        num_fcs (int): The number of fully-connected layers for FFNs. Default 2
+        qkv_bias (bool): enable bias for qkv if True. Default True
+        act_cfg (dict): The activation config for FFNs. Defalut GELU
+        norm_cfg (dict): Config dict for normalization layer. Default
+            layer normalization
+        batch_first (bool): Key, Query and Value are shape of
+            (batch, n, embed_dim)
+            or (n, batch, embed_dim). Default to False.
+        init_cfg (dict, optional): Initialization config dict
     """
 
     def __init__(self,
-                 in_features,
-                 hidden_features=None,
-                 out_features=None,
-                 act_cfg=dict(type='GELU'),
-                 drop=0.):
-        super(Mlp, self).__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = Linear(in_features, hidden_features)
-        self.act = build_activation_layer(act_cfg)
-        self.fc2 = Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
-class Attention(nn.Module):
-    """Attention layer for Encoder block.
-
-    Args:
-        dim (int): Dimension for the input vector.
-        num_heads (int): Number of parallel attention heads.
-        qkv_bias (bool): Enable bias for qkv if True. Default: False.
-        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
-        attn_drop (float): Drop rate for attention output weights.
-            Default: 0.
-        proj_drop (float): Drop rate for output weights. Default: 0.
-    """
-
-    def __init__(self,
-                 dim,
-                 num_heads=8,
-                 qkv_bias=False,
-                 qk_scale=None,
-                 attn_drop=0.,
-                 proj_drop=0.):
-        super(Attention, self).__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x):
-        b, n, c = x.shape
-        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
-                                  c // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
-class Block(nn.Module):
-    """Implements encoder block with residual connection.
-
-    Args:
-        dim (int): The feature dimension.
-        num_heads (int): Number of parallel attention heads.
-        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
-        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
-        drop (float): Drop rate for mlp output weights. Default: 0.
-        attn_drop (float): Drop rate for attention output weights.
-            Default: 0.
-        proj_drop (float): Drop rate for attn layer output weights.
-            Default: 0.
-        drop_path (float): Drop rate for paths of model.
-            Default: 0.
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='GELU').
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='LN', requires_grad=True).
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 dim,
+                 embed_dims,
                  num_heads,
-                 mlp_ratio=4,
-                 qkv_bias=False,
-                 qk_scale=None,
-                 drop=0.,
-                 attn_drop=0.,
-                 proj_drop=0.,
-                 drop_path=0.,
+                 feedforward_channels,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 num_fcs=2,
+                 qkv_bias=True,
                  act_cfg=dict(type='GELU'),
-                 norm_cfg=dict(type='LN', eps=1e-6),
-                 with_cp=False):
-        super(Block, self).__init__()
-        self.with_cp = with_cp
-        _, self.norm1 = build_norm_layer(norm_cfg, dim)
-        self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
-                              proj_drop)
-        self.drop_path = DropPath(
-            drop_path) if drop_path > 0. else nn.Identity()
-        _, self.norm2 = build_norm_layer(norm_cfg, dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(
-            in_features=dim,
-            hidden_features=mlp_hidden_dim,
-            act_cfg=act_cfg,
-            drop=drop)
+                 norm_cfg=dict(type='LN'),
+                 batch_first=False):
+        super(TransformerEncoderLayer, self).__init__()
+
+        self.norm1_name, norm1 = build_norm_layer(
+            norm_cfg, embed_dims, postfix=1)
+        self.add_module(self.norm1_name, norm1)
+
+        self.attn = MultiheadAttention(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            attn_drop=attn_drop_rate,
+            proj_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            batch_first=batch_first,
+            bias=qkv_bias)
+
+        self.norm2_name, norm2 = build_norm_layer(
+            norm_cfg, embed_dims, postfix=2)
+        self.add_module(self.norm2_name, norm2)
+
+        self.ffn = FFN(
+            embed_dims=embed_dims,
+            feedforward_channels=feedforward_channels,
+            num_fcs=num_fcs,
+            ffn_drop=drop_rate,
+            dropout_layer=None,
+            act_cfg=act_cfg)
+
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        return getattr(self, self.norm2_name)
 
     def forward(self, x):
-
-        def _inner_forward(x):
-            out = x + self.drop_path(self.attn(self.norm1(x)))
-            out = out + self.drop_path(self.mlp(self.norm2(out)))
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
+        x = self.attn(self.norm1(x), identity=x)
+        x = self.ffn(self.norm2(x), identity=x)
+        return x
 
 
-class PatchEmbed(nn.Module):
+# Modified from pytorch-image-models
+class PatchEmbed(BaseModule):
     """Image to Patch Embedding.
 
     Args:
-        img_size (int | tuple): Input image size.
-            default: 224.
-        patch_size (int): Width and height for a patch.
-            default: 16.
-        in_channels (int): Input channels for images. Default: 3.
-        embed_dim (int): The embedding dimension. Default: 768.
+        img_size (int | tuple): The size of input image.
+        patch_size (int): The size of one patch
+        in_channels (int): The num of input channels.
+        embed_dim (int): The dimensions of embedding.
+        norm_cfg (dict, optional): Config dict for normalization layer.
+        conv_cfg (dict, optional): The config dict for conv layers.
+            Default: None.
     """
 
     def __init__(self,
                  img_size=224,
                  patch_size=16,
                  in_channels=3,
-                 embed_dim=768):
+                 embed_dim=768,
+                 norm_cfg=None,
+                 conv_cfg=None):
         super(PatchEmbed, self).__init__()
-        if isinstance(img_size, int):
-            self.img_size = (img_size, img_size)
-        elif isinstance(img_size, tuple):
-            self.img_size = img_size
+
+        self.img_size = img_size
+        self.patch_size = to_2tuple(patch_size)
+
+        patches_resolution = [
+            img_size[0] // self.patch_size[0],
+            img_size[1] // self.patch_size[1]
+        ]
+        num_patches = patches_resolution[0] * patches_resolution[1]
+        self.patches_resolution = patches_resolution
+        self.num_patches = num_patches
+
+        # Use conv layer to embed
+        self.projection = build_conv_layer(
+            conv_cfg,
+            in_channels,
+            embed_dim,
+            kernel_size=patch_size,
+            stride=patch_size)
+
+        if norm_cfg is not None:
+            self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
         else:
-            raise TypeError('img_size must be type of int or tuple')
-        h, w = self.img_size
-        self.patch_size = (patch_size, patch_size)
-        self.num_patches = (h // patch_size) * (w // patch_size)
-        self.proj = Conv2d(
-            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+            self.norm = None
 
     def forward(self, x):
-        return self.proj(x).flatten(2).transpose(1, 2)
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #     f"Input image size ({H}*{W}) doesn't " \
+        #     f'match model ({self.img_size[0]}*{self.img_size[1]}).'
+        # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim
+        x = self.projection(x).flatten(2).transpose(1, 2)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x
 
 
 @BACKBONES.register_module()
 class VisionTransformer(BaseModule):
-    """Vision transformer backbone.
+    """Vision Transformer.
 
-    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
-        Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
+    A PyTorch implement of : `An Image is Worth 16x16 Words:
+    Transformers for Image Recognition at Scale` -
+        https://arxiv.org/abs/2010.11929
 
     Args:
-        img_size (tuple): input image size. Default: (224, 224).
-        patch_size (int, tuple): patch size. Default: 16.
-        in_channels (int): number of input channels. Default: 3.
-        embed_dim (int): embedding dimension. Default: 768.
-        depth (int): depth of transformer. Default: 12.
+        img_size (int | tuple): Input image size. Default: 224.
+        patch_size (int): The patch size. Default: 16.
+        in_channels (int): Number of input channels. Default: 3.
+        embed_dims (int): embedding dimension. Default: 768.
+        num_layers (int): depth of transformer. Default: 12.
         num_heads (int): number of attention heads. Default: 12.
         mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
             Default: 4.
         out_indices (list | tuple | int): Output from which stages.
             Default: -1.
         qkv_bias (bool): enable bias for qkv if True. Default: True.
-        qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
-        drop_rate (float): dropout rate. Default: 0.
-        attn_drop_rate (float): attention dropout rate. Default: 0.
-        drop_path_rate (float): Rate of DropPath. Default: 0.
+        drop_rate (float): Probability of an element to be zeroed.
+            Default 0.0
+        attn_drop_rate (float): The drop out rate for attention layer.
+            Default 0.0
+        drop_path_rate (float): stochastic depth rate. Default 0.0
+        with_cls_token (bool): If concatenating class token into image tokens
+            as transformer input. Default: True.
         norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='LN', eps=1e-6, requires_grad=True).
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='GELU').
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
+            Default: dict(type='LN')
+        act_cfg (dict): The activation config for FFNs.
+            Defalut: dict(type='GELU').
         final_norm (bool):  Whether to add a additional layer to normalize
             final feature map. Default: False.
-        out_reshape (str): Select the output format of feature information.
-            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
-        with_cls_token (bool): If concatenating class token into image tokens
-            as transformer input. Default: True.
-        with_cp (bool): Use checkpoint or not. Using checkpoint
-            will save some memory while slowing down the training speed.
-            Default: False.
-        pretrained (str, optional): model pretrained path. Default: None
-        init_cfg (dict or list[dict], optional): Initialization config dict.
-            Default: None
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Default: 2.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed. Default: False.
+        pretrain_style (str): Choose to use timm or mmcls pretrain weights.
+            Default: timm.
     """
 
     def __init__(self,
-                 img_size=(224, 224),
+                 img_size=224,
                  patch_size=16,
                  in_channels=3,
-                 embed_dim=768,
-                 depth=12,
+                 embed_dims=768,
+                 num_layers=12,
                  num_heads=12,
                  mlp_ratio=4,
                  out_indices=11,
                  qkv_bias=True,
-                 qk_scale=None,
                  drop_rate=0.,
                  attn_drop_rate=0.,
                  drop_path_rate=0.,
-                 norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
+                 with_cls_token=True,
+                 norm_cfg=dict(type='LN'),
                  act_cfg=dict(type='GELU'),
-                 norm_eval=False,
                  final_norm=False,
-                 out_shape='NCHW',
-                 with_cls_token=True,
                  interpolate_mode='bicubic',
+                 num_fcs=2,
+                 norm_eval=False,
                  with_cp=False,
-                 pretrained=None,
-                 init_cfg=None):
-        super(VisionTransformer, self).__init__(init_cfg)
-        self.pretrained = pretrained
+                 pretrain_style='timm'):
+        super(VisionTransformer, self).__init__()
+
+        if isinstance(img_size, int):
+            img_size = to_2tuple(img_size)
+        elif isinstance(img_size, tuple):
+            if len(img_size) == 1:
+                img_size = to_2tuple(img_size[0])
+            assert len(img_size) == 2, \
+                f'The size of image should have length 1 or 2, ' \
+                f'but got {len(img_size)}'
 
+        assert pretrain_style in ['timm', 'mmcls']
+
+        self.pretrain_style = pretrain_style
         self.img_size = img_size
         self.patch_size = patch_size
-        self.features = self.embed_dim = embed_dim
+
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_channels=in_channels,
-            embed_dim=embed_dim)
+            embed_dim=embed_dims,
+            norm_cfg=norm_cfg)
+        num_patches = self.patch_embed.num_patches
 
         self.with_cls_token = with_cls_token
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
         self.pos_embed = nn.Parameter(
-            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
-        self.pos_drop = nn.Dropout(p=drop_rate)
+            torch.zeros(1, num_patches + 1, embed_dims))
+        self.drop_after_pos = nn.Dropout(p=drop_rate)
 
         if isinstance(out_indices, int):
             self.out_indices = [out_indices]
@@ -297,37 +260,41 @@ class VisionTransformer(BaseModule):
         else:
             raise TypeError('out_indices must be type of int, list or tuple')
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
-               ]  # stochastic depth decay rule
-        self.blocks = nn.ModuleList([
-            Block(
-                dim=embed_dim,
-                num_heads=num_heads,
-                mlp_ratio=mlp_ratio,
-                qkv_bias=qkv_bias,
-                qk_scale=qk_scale,
-                drop=dpr[i],
-                attn_drop=attn_drop_rate,
-                act_cfg=act_cfg,
-                norm_cfg=norm_cfg,
-                with_cp=with_cp) for i in range(depth)
-        ])
-
-        assert out_shape in ['NLC',
-                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
-
-        self.out_shape = out_shape
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
+        ]  # stochastic depth decay rule
+
+        self.layers = ModuleList()
+        for i in range(num_layers):
+            self.layers.append(
+                TransformerEncoderLayer(
+                    embed_dims=embed_dims,
+                    num_heads=num_heads,
+                    feedforward_channels=mlp_ratio * embed_dims,
+                    attn_drop_rate=attn_drop_rate,
+                    drop_rate=drop_rate,
+                    drop_path_rate=dpr[i],
+                    num_fcs=num_fcs,
+                    qkv_bias=qkv_bias,
+                    act_cfg=act_cfg,
+                    norm_cfg=norm_cfg,
+                    batch_first=True))
 
         self.interpolate_mode = interpolate_mode
         self.final_norm = final_norm
         if final_norm:
-            _, self.norm = build_norm_layer(norm_cfg, embed_dim)
+            self.norm1_name, norm1 = build_norm_layer(
+                norm_cfg, embed_dims, postfix=1)
+            self.add_module(self.norm1_name, norm1)
 
         self.norm_eval = norm_eval
         self.with_cp = with_cp
 
-    def init_weights(self):
-        pretrained = self.pretrained
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
             logger = get_root_logger()
             checkpoint = _load_checkpoint(pretrained, logger=logger)
@@ -338,10 +305,17 @@ class VisionTransformer(BaseModule):
             else:
                 state_dict = checkpoint
 
+            if self.pretrain_style == 'timm':
+                # Because the refactor of vit is blocked by mmcls,
+                # so we firstly use timm pretrain weights to train
+                # downstream model.
+                state_dict = vit_convert(state_dict)
+
             if 'pos_embed' in state_dict.keys():
                 if self.pos_embed.shape != state_dict['pos_embed'].shape:
-                    logger.info(msg=f'Resize the pos_embed shape from \
-{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
+                    logger.info(msg=f'Resize the pos_embed shape from '
+                                f'{state_dict["pos_embed"].shape} to '
+                                f'{self.pos_embed.shape}')
                     h, w = self.img_size
                     pos_size = int(
                         math.sqrt(state_dict['pos_embed'].shape[1] - 1))
@@ -354,17 +328,17 @@ class VisionTransformer(BaseModule):
         elif pretrained is None:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
-            trunc_normal_(self.pos_embed, std=.02)
-            trunc_normal_(self.cls_token, std=.02)
+            trunc_normal_init(self.pos_embed, std=.02)
+            trunc_normal_init(self.cls_token, std=.02)
             for n, m in self.named_modules():
-                if isinstance(m, Linear):
-                    trunc_normal_(m.weight, std=.02)
+                if isinstance(m, nn.Linear):
+                    trunc_normal_init(m.weight, std=.02)
                     if m.bias is not None:
-                        if 'mlp' in n:
+                        if 'ffn' in n:
                             normal_init(m.bias, std=1e-6)
                         else:
                             constant_init(m.bias, 0)
-                elif isinstance(m, Conv2d):
+                elif isinstance(m, nn.Conv2d):
                     kaiming_init(m.weight, mode='fan_in')
                     if m.bias is not None:
                         constant_init(m.bias, 0)
@@ -404,7 +378,7 @@ class VisionTransformer(BaseModule):
             pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
                                               (pos_h, pos_w), self.patch_size,
                                               self.interpolate_mode)
-        return self.pos_drop(patched_img + pos_embed)
+        return self.drop_after_pos(patched_img + pos_embed)
 
     @staticmethod
     def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
@@ -441,31 +415,31 @@ class VisionTransformer(BaseModule):
 
         x = self.patch_embed(inputs)
 
+        # stole cls_tokens impl from Phil Wang, thanks
         cls_tokens = self.cls_token.expand(B, -1, -1)
         x = torch.cat((cls_tokens, x), dim=1)
         x = self._pos_embeding(inputs, x, self.pos_embed)
 
         if not self.with_cls_token:
-            # Remove class token for transformer input
+            # Remove class token for transformer encoder input
             x = x[:, 1:]
 
         outs = []
-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
-            if i == len(self.blocks) - 1:
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+            if i == len(self.layers) - 1:
                 if self.final_norm:
-                    x = self.norm(x)
+                    x = self.norm1(x)
             if i in self.out_indices:
                 if self.with_cls_token:
                     # Remove class token and reshape token for decoder head
                     out = x[:, 1:]
                 else:
                     out = x
-                if self.out_shape == 'NCHW':
-                    B, _, C = out.shape
-                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                      inputs.shape[3] // self.patch_size,
-                                      C).permute(0, 3, 1, 2)
+                B, _, C = out.shape
+                out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                  inputs.shape[3] // self.patch_size,
+                                  C).permute(0, 3, 1, 2)
                 outs.append(out)
 
         return tuple(outs)
diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py
index 3d3bdd349b9f2ae499a2fcb2ac1d2e3c77befebe..be11d77f4e7b193d8fc006a8672fe50059468ee6 100644
--- a/mmseg/models/utils/__init__.py
+++ b/mmseg/models/utils/__init__.py
@@ -4,10 +4,10 @@ from .make_divisible import make_divisible
 from .res_layer import ResLayer
 from .se_layer import SELayer
 from .self_attention_block import SelfAttentionBlock
+from .timm_convert import vit_convert
 from .up_conv_block import UpConvBlock
-from .weight_init import trunc_normal_
 
 __all__ = [
     'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
-    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
+    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'vit_convert'
 ]
diff --git a/mmseg/models/utils/timm_convert.py b/mmseg/models/utils/timm_convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a4d311090339aba71fd79271542c0ab76bbbc9
--- /dev/null
+++ b/mmseg/models/utils/timm_convert.py
@@ -0,0 +1,33 @@
+from collections import OrderedDict
+
+
+def vit_convert(timm_dict):
+
+    mmseg_dict = OrderedDict()
+
+    for k, v in timm_dict.items():
+        if k.startswith('head'):
+            continue
+        if k.startswith('norm'):
+            new_k = k.replace('norm.', 'ln1.')
+        elif k.startswith('patch_embed'):
+            if 'proj' in k:
+                new_k = k.replace('proj', 'projection')
+        elif k.startswith('blocks'):
+            new_k = k.replace('blocks.', 'layers.')
+            if 'norm' in new_k:
+                new_k = new_k.replace('norm', 'ln')
+            elif 'mlp.fc1' in new_k:
+                new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
+            elif 'mlp.fc2' in new_k:
+                new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
+            elif 'attn.qkv' in new_k:
+                new_k = new_k.replace('attn.qkv.', 'attn.attn.in_proj_')
+            elif 'attn.proj' in new_k:
+                new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
+        else:
+            new_k = k
+        new_k = f'backbone.{new_k}'
+        mmseg_dict[new_k] = v
+
+    return mmseg_dict
diff --git a/mmseg/models/utils/weight_init.py b/mmseg/models/utils/weight_init.py
deleted file mode 100644
index 38141ba3d61f64ddfc0a31574b4648cbad96d7dd..0000000000000000000000000000000000000000
--- a/mmseg/models/utils/weight_init.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""Modified from https://github.com/rwightman/pytorch-image-
-models/blob/master/timm/models/layers/drop.py."""
-
-import math
-import warnings
-
-import torch
-
-
-def _no_grad_trunc_normal_(tensor, mean, std, a, b):
-    """Reference: https://people.sc.fsu.edu/~jburkardt/presentations
-    /truncated_normal.pdf"""
-
-    def norm_cdf(x):
-        # Computes standard normal cumulative distribution function
-        return (1. + math.erf(x / math.sqrt(2.))) / 2.
-
-    if (mean < a - 2 * std) or (mean > b + 2 * std):
-        warnings.warn(
-            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
-            'The distribution of values may be incorrect.',
-            stacklevel=2)
-
-    with torch.no_grad():
-        # Values are generated by using a truncated uniform distribution and
-        # then using the inverse CDF for the normal distribution.
-        # Get upper and lower cdf values
-        lower_bound = norm_cdf((a - mean) / std)
-        upper_bound = norm_cdf((b - mean) / std)
-
-        # Uniformly fill tensor with values from [l, u], then translate to
-        # [2l-1, 2u-1].
-        tensor.uniform_(2 * lower_bound - 1, 2 * upper_bound - 1)
-
-        # Use inverse cdf transform for normal distribution to get truncated
-        # standard normal
-        tensor.erfinv_()
-
-        # Transform to proper mean, std
-        tensor.mul_(std * math.sqrt(2.))
-        tensor.add_(mean)
-
-        # Clamp to ensure it's in the proper range
-        tensor.clamp_(min=a, max=b)
-        return tensor
-
-
-def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
-    r"""Fills the input Tensor with values drawn from a truncated
-    normal distribution. The values are effectively drawn from the
-    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
-    with values outside :math:`[a, b]` redrawn until they are within
-    the bounds. The method used for generating the random values works
-    best when :math:`a \leq \text{mean} \leq b`.
-    Args:
-        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`
-        mean (float): the mean of the normal distribution
-        std (float): the standard deviation of the normal distribution
-        a (float): the minimum cutoff value
-        b (float): the maximum cutoff value
-    """
-    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index 1ec42d34ea4a80a0b46aa2d2e192ed48248c5926..452eee05d8ccb4264c0810b4d399f0734888ce30 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -24,19 +24,18 @@ def test_vit_backbone():
         x = torch.randn(1, 196)
         VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
 
-    with pytest.raises(RuntimeError):
+    with pytest.raises(ValueError):
         # forward inputs must be [N, C, H, W]
         x = torch.randn(3, 30, 30)
         model = VisionTransformer()
         model(x)
 
     with pytest.raises(AssertionError):
-        # out_shape must be 'NLC' or 'NCHW;'
-        VisionTransformer(out_shape='NCL')
+        VisionTransformer(img_size=(224, 224, 224))
 
-    # Test img_size isinstance int
+    # Test img_size isinstance tuple
     imgs = torch.randn(1, 3, 224, 224)
-    model = VisionTransformer(img_size=224)
+    model = VisionTransformer(img_size=(224, 224))
     model.init_weights()
     model(imgs)
 
@@ -65,6 +64,11 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
 
+    # Test unbalanced size input image
+    imgs = torch.randn(1, 3, 112, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 7, 14)
+
     # Test with_cp=True
     model = VisionTransformer(with_cp=True)
     imgs = torch.randn(1, 3, 224, 224)
@@ -77,8 +81,8 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
 
-    # Test final reshape arg
+    # Test final norm
+    model = VisionTransformer(final_norm=True)
     imgs = torch.randn(1, 3, 224, 224)
-    model = VisionTransformer(out_shape='NLC')
     feat = model(imgs)
-    assert feat[-1].shape == (1, 196, 768)
+    assert feat[-1].shape == (1, 768, 14, 14)