diff --git a/configs/_base_/models/upernet_vit-b16_ln_mln.py b/configs/_base_/models/upernet_vit-b16_ln_mln.py
index 573612e13a63c5a7c635221435807e230b007a3a..1a5a56972908d018cadceb46d5e7efa2c109f6d2 100644
--- a/configs/_base_/models/upernet_vit-b16_ln_mln.py
+++ b/configs/_base_/models/upernet_vit-b16_ln_mln.py
@@ -21,7 +21,6 @@ model = dict(
         norm_cfg=dict(type='LN', eps=1e-6),
         act_cfg=dict(type='GELU'),
         norm_eval=False,
-        out_shape='NCHW',
         interpolate_mode='bicubic'),
     neck=dict(
         type='MultiLevelNeck',
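With `out_shape` gone, the backbone unconditionally returns NCHW feature maps, which is the layout `MultiLevelNeck` already expects, so the config only loses a key. A minimal sketch of the resulting behavior, assuming mmseg with this patch installed and the ViT-B/16 defaults exercised by the updated tests:

```python
# Minimal sketch, assuming mmseg with this patch and the ViT-B/16 defaults
# used in the updated tests; not part of the change itself.
import torch
from mmseg.models.backbones.vit import VisionTransformer

backbone = VisionTransformer()                 # embed_dims=768, patch_size=16
feats = backbone(torch.randn(1, 3, 224, 224))
assert feats[-1].shape == (1, 768, 14, 14)     # always NCHW: (B, C, H/16, W/16)
```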
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 1ad20a1ca6ac18b8583fbc7c3b419df4d33477ed..33176351ea00cd37c72e8e2c2323419c7369c183 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -118,8 +118,10 @@ class VisionTransformer(BaseModule):
         attn_drop_rate (float): The drop out rate for attention layer.
             Default 0.0
         drop_path_rate (float): stochastic depth rate. Default 0.0
-        with_cls_token (bool): If concatenating class token into image tokens
-            as transformer input. Default: True.
+        with_cls_token (bool): Whether to concatenate a class token into the
+            image tokens as transformer input. Default: True.
+        output_cls_token (bool): Whether to output the cls_token. If set to
+            True, `with_cls_token` must also be True. Default: False.
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
@@ -128,8 +130,6 @@ class VisionTransformer(BaseModule):
             Default: False.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Default: False.
-        out_shape (str): Select the output format of feature information.
-            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -160,11 +160,11 @@ class VisionTransformer(BaseModule):
                  attn_drop_rate=0.,
                  drop_path_rate=0.,
                  with_cls_token=True,
+                 output_cls_token=False,
                  norm_cfg=dict(type='LN'),
                  act_cfg=dict(type='GELU'),
                  patch_norm=False,
                  final_norm=False,
-                 out_shape='NCHW',
                  interpolate_mode='bicubic',
                  num_fcs=2,
                  norm_eval=False,
@@ -185,8 +185,9 @@ class VisionTransformer(BaseModule):
 
         assert pretrain_style in ['timm', 'mmcls']
 
-        assert out_shape in ['NLC',
-                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+        if output_cls_token:
+            assert with_cls_token is True, 'with_cls_token must be True if ' \
+                f'output_cls_token is set to True, but got {with_cls_token}'
 
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -196,7 +197,6 @@ class VisionTransformer(BaseModule):
 
         self.img_size = img_size
         self.patch_size = patch_size
-        self.out_shape = out_shape
         self.interpolate_mode = interpolate_mode
         self.norm_eval = norm_eval
         self.with_cp = with_cp
@@ -218,6 +218,7 @@ class VisionTransformer(BaseModule):
             (img_size[1] // patch_size)
 
         self.with_cls_token = with_cls_token
+        self.output_cls_token = output_cls_token
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
         self.pos_embed = nn.Parameter(
             torch.zeros(1, num_patches + 1, embed_dims))
@@ -253,7 +254,6 @@ class VisionTransformer(BaseModule):
                     batch_first=True))
 
         self.final_norm = final_norm
-        self.out_shape = out_shape
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
@@ -290,8 +290,9 @@ class VisionTransformer(BaseModule):
                     pos_size = int(
                         math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                     state_dict['pos_embed'] = self.resize_pos_embed(
-                        state_dict['pos_embed'], (h, w), (pos_size, pos_size),
-                        self.patch_size, self.interpolate_mode)
+                        state_dict['pos_embed'],
+                        (h // self.patch_size, w // self.patch_size),
+                        (pos_size, pos_size), self.interpolate_mode)
 
             self.load_state_dict(state_dict, False)
 
@@ -317,16 +318,15 @@ class VisionTransformer(BaseModule):
                     constant_init(m.bias, 0)
                     constant_init(m.weight, 1.0)
 
-    def _pos_embeding(self, img, patched_img, pos_embed):
+    def _pos_embeding(self, patched_img, hw_shape, pos_embed):
         """Positiong embeding method.
 
         Resize the pos_embed, if the input image size doesn't match
             the training size.
         Args:
-            img (torch.Tensor): The inference image tensor, the shape
-                must be [B, C, H, W].
             patched_img (torch.Tensor): The patched image, it should be
                 shape of [B, L1, C].
+            hw_shape (tuple): The downsampled image resolution.
             pos_embed (torch.Tensor): The pos_embed weighs, it should be
                 shape of [B, L2, c].
         Return:
@@ -344,36 +344,36 @@ class VisionTransformer(BaseModule):
                 raise ValueError(
                     'Unexpected shape of pos_embed, got {}.'.format(
                         pos_embed.shape))
-            pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
-                                              (pos_h, pos_w), self.patch_size,
+            pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
+                                              (pos_h, pos_w),
                                               self.interpolate_mode)
         return self.drop_after_pos(patched_img + pos_embed)
 
     @staticmethod
-    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+    def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
         """Resize pos_embed weights.
 
         Resize pos_embed using bicubic interpolate method.
         Args:
-            pos_embed (torch.Tensor): pos_embed weights.
-            input_shpae (tuple): Tuple for (input_h, intput_w).
-            pos_shape (tuple): Tuple for (pos_h, pos_w).
-            patch_size (int): Patch size.
+            pos_embed (torch.Tensor): Position embedding weights.
+            input_shape (tuple): Tuple for (downsampled input image height,
+                downsampled input image width).
+            pos_shape (tuple): Tuple for (pos_h, pos_w), the resolution of
+                the downsampled original training image.
+            mode (str): Algorithm used for upsampling:
+                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+                ``'trilinear'``.
         Return:
             torch.Tensor: The resized pos_embed of shape [B, L_new, C]
         """
         assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
-        input_h, input_w = input_shpae
         pos_h, pos_w = pos_shape
         cls_token_weight = pos_embed[:, 0]
         pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
         pos_embed_weight = pos_embed_weight.reshape(
             1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
         pos_embed_weight = F.interpolate(
-            pos_embed_weight,
-            size=[input_h // patch_size, input_w // patch_size],
-            align_corners=False,
-            mode=mode)
+            pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
         cls_token_weight = cls_token_weight.unsqueeze(1)
         pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
         pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
@@ -382,12 +382,12 @@ class VisionTransformer(BaseModule):
     def forward(self, inputs):
         B = inputs.shape[0]
 
         x = self.patch_embed(inputs)
-
+        hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
         # stole cls_tokens impl from Phil Wang, thanks
         cls_tokens = self.cls_token.expand(B, -1, -1)
         x = torch.cat((cls_tokens, x), dim=1)
-        x = self._pos_embeding(inputs, x, self.pos_embed)
+        x = self._pos_embeding(x, hw_shape, self.pos_embed)
 
         if not self.with_cls_token:
             # Remove class token for transformer encoder input
@@ -405,11 +405,11 @@ class VisionTransformer(BaseModule):
                     out = x[:, 1:]
                 else:
                     out = x
-                if self.out_shape == 'NCHW':
-                    B, _, C = out.shape
-                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                      inputs.shape[3] // self.patch_size,
-                                      C).permute(0, 3, 1, 2)
+                B, _, C = out.shape
+                out = out.reshape(B, hw_shape[0], hw_shape[1],
+                                  C).permute(0, 3, 1, 2)
+                if self.output_cls_token:
+                    out = [out, x[:, 0]]
                 outs.append(out)
 
         return tuple(outs)
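Beyond the new `output_cls_token` flag, the behavioral change here is that position-embedding resizing now works in patch-grid units (`hw_shape` taken from the patch embedding) instead of pixel units, so inputs whose height or width is not a multiple of `patch_size` land on the padded patch grid, as the new irregular-input test below checks. A standalone sketch of that resizing logic, written as a hypothetical helper rather than the class method:

```python
# Standalone sketch of the resizing logic above (hypothetical helper, not part
# of mmseg): split off the cls token, interpolate the grid part to the target
# patch-grid resolution, and concatenate the two back together.
import torch
import torch.nn.functional as F


def resize_pos_embed_sketch(pos_embed, grid_hw, pos_shape, mode='bicubic'):
    # pos_embed: [1, 1 + pos_h * pos_w, C]; grid_hw: target (h, w) in patches
    pos_h, pos_w = pos_shape
    cls_tok = pos_embed[:, :1]                                  # [1, 1, C]
    grid = pos_embed[:, 1:].reshape(1, pos_h, pos_w, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=grid_hw, align_corners=False, mode=mode)
    grid = grid.flatten(2).transpose(1, 2)                      # [1, h*w, C]
    return torch.cat((cls_tok, grid), dim=1)


# A 234x345 input with 16x16 patches pads to a 15x22 grid (see the new test).
pos_embed = torch.randn(1, 1 + 14 * 14, 768)   # pretrained at a 14x14 grid
resized = resize_pos_embed_sketch(pos_embed, (15, 22), (14, 14))
assert resized.shape == (1, 1 + 15 * 22, 768)
```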
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index 4577b97b86bb29f14d1bb72deb42c879f80e5192..16d6aba68fca014a1723fcc01dceec96b441d332 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -39,8 +39,8 @@ def test_vit_backbone():
         VisionTransformer(pretrained=123)
 
     with pytest.raises(AssertionError):
-        # out_shape must be 'NLC' or 'NCHW;'
-        VisionTransformer(out_shape='NCL')
+        # with_cls_token must be True when output_cls_token == True
+        VisionTransformer(with_cls_token=False, output_cls_token=True)
 
     # Test img_size isinstance tuple
     imgs = torch.randn(1, 3, 224, 224)
@@ -88,6 +88,11 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 7, 14)
 
+    # Test irregular input image
+    imgs = torch.randn(1, 3, 234, 345)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 15, 22)
+
     # Test with_cp=True
     model = VisionTransformer(with_cp=True)
     imgs = torch.randn(1, 3, 224, 224)
@@ -100,12 +105,6 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
 
-    # Test out_shape == 'NLC'
-    model = VisionTransformer(out_shape='NLC')
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert feat[-1].shape == (1, 196, 768)
-
     # Test final norm
     model = VisionTransformer(final_norm=True)
     imgs = torch.randn(1, 3, 224, 224)
@@ -117,3 +116,10 @@ def test_vit_backbone():
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test output_cls_token
+    model = VisionTransformer(with_cls_token=True, output_cls_token=True)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[0][0].shape == (1, 768, 14, 14)
+    assert feat[0][1].shape == (1, 768)
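For downstream use, the new test above is also the usage pattern: with `output_cls_token=True`, each entry of the returned tuple becomes a two-element list holding the NCHW patch feature map and the `(B, C)` cls token. A short sketch, assuming mmseg with this patch installed:

```python
# Usage sketch mirroring the new test; assumes mmseg with this patch installed.
import torch
from mmseg.models.backbones.vit import VisionTransformer

model = VisionTransformer(with_cls_token=True, output_cls_token=True)
feat = model(torch.randn(1, 3, 224, 224))
patch_feat, cls_token = feat[0]    # each output: [NCHW feature map, cls token]
assert patch_feat.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)
```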