From 5c195db1bd1eb5f4903d7d39d5af4a82f6a435d4 Mon Sep 17 00:00:00 2001
From: sennnnn <58427300+sennnnn@users.noreply.github.com>
Date: Thu, 6 May 2021 13:49:28 +0800
Subject: [PATCH] Add option for output shape of ViT (#530)

* Add arg final_reshape to control whether the output feature map is converted from NLC to NCHW;

* Fix the default value of final_reshape;

* Rename arg final_reshape to out_shape;

* Fix some unit test bugs;
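
A minimal usage sketch of the new option (assumes VisionTransformer is
importable from mmseg.models.backbones, as in the tests, and the default
224x224 input with 16x16 patches; the shapes mirror the assertions in the
updated tests):

    import torch
    from mmseg.models.backbones import VisionTransformer

    imgs = torch.randn(1, 3, 224, 224)

    # Default out_shape='NCHW': tokens are reshaped to a 2D feature map.
    feat = VisionTransformer(out_shape='NCHW')(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)

    # out_shape='NLC': the token sequence is returned unchanged.
    feat = VisionTransformer(out_shape='NLC')(imgs)
    assert feat[-1].shape == (1, 196, 768)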
---
 mmseg/models/backbones/vit.py                | 17 +++++++++++++----
 tests/test_models/test_backbones/test_vit.py | 10 ++++++++++
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 1d730d86..37768572 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -234,6 +234,8 @@ class VisionTransformer(nn.Module):
             and its variants only. Default: False.
         final_norm (bool): Whether to add an additional layer to normalize
             the final feature map. Default: False.
+        out_shape (str): Select the output format of the feature map,
+            either 'NLC' or 'NCHW'. Default: 'NCHW'.
         interpolate_mode (str): Select the interpolate mode for position
             embedding vector resize. Default: bicubic.
         with_cls_token (bool): If concatenating class token into image tokens
@@ -261,6 +263,7 @@ class VisionTransformer(nn.Module):
                  act_cfg=dict(type='GELU'),
                  norm_eval=False,
                  final_norm=False,
+                 out_shape='NCHW',
                  with_cls_token=True,
                  interpolate_mode='bicubic',
                  with_cp=False):
@@ -303,6 +306,11 @@ class VisionTransformer(nn.Module):
                 with_cp=with_cp) for i in range(depth)
         ])
 
+        assert out_shape in ['NLC',
+                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+
+        self.out_shape = out_shape
+
         self.interpolate_mode = interpolate_mode
         self.final_norm = final_norm
         if final_norm:
@@ -443,10 +451,11 @@ class VisionTransformer(nn.Module):
                     out = x[:, 1:]
                 else:
                     out = x
-                B, _, C = out.shape
-                out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                  inputs.shape[3] // self.patch_size,
-                                  C).permute(0, 3, 1, 2)
+                if self.out_shape == 'NCHW':
+                    B, _, C = out.shape
+                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                      inputs.shape[3] // self.patch_size,
+                                      C).permute(0, 3, 1, 2)
                 outs.append(out)
 
         return tuple(outs)
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index c36894ec..1ec42d34 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -30,6 +30,10 @@ def test_vit_backbone():
         model = VisionTransformer()
         model(x)
 
+    with pytest.raises(AssertionError):
+        # out_shape must be 'NLC' or 'NCHW'
+        VisionTransformer(out_shape='NCL')
+
     # Test img_size isinstance int
     imgs = torch.randn(1, 3, 224, 224)
     model = VisionTransformer(img_size=224)
@@ -72,3 +76,9 @@ def test_vit_backbone():
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test out_shape arg with 'NLC' output
+    imgs = torch.randn(1, 3, 224, 224)
+    model = VisionTransformer(out_shape='NLC')
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 196, 768)
-- 
GitLab