From 5c195db1bd1eb5f4903d7d39d5af4a82f6a435d4 Mon Sep 17 00:00:00 2001
From: sennnnn <58427300+sennnnn@users.noreply.github.com>
Date: Thu, 6 May 2021 13:49:28 +0800
Subject: [PATCH] Add option for output shape of ViT (#530)

* Add arg: final_reshape to control whether the output feature map is
  converted from NLC to NCHW;

* Fix the default value of final_reshape;

* Modify arg: final_reshape to arg: out_shape;

* Fix some unit test bugs;
---
 mmseg/models/backbones/vit.py                | 17 +++++++++++++----
 tests/test_models/test_backbones/test_vit.py | 10 ++++++++++
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 1d730d86..37768572 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -234,6 +234,8 @@ class VisionTransformer(nn.Module):
             and its variants only. Default: False.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Default: False.
+        out_shape (str): Select the output format of the feature map.
+            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         with_cls_token (bool): If concatenating class token into image tokens
@@ -261,6 +263,7 @@ class VisionTransformer(nn.Module):
                  act_cfg=dict(type='GELU'),
                  norm_eval=False,
                  final_norm=False,
+                 out_shape='NCHW',
                  with_cls_token=True,
                  interpolate_mode='bicubic',
                  with_cp=False):
@@ -303,6 +306,11 @@ class VisionTransformer(nn.Module):
                 with_cp=with_cp) for i in range(depth)
         ])
 
+        assert out_shape in ['NLC',
+                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+
+        self.out_shape = out_shape
+
         self.interpolate_mode = interpolate_mode
         self.final_norm = final_norm
         if final_norm:
@@ -443,10 +451,11 @@ class VisionTransformer(nn.Module):
                     out = x[:, 1:]
                 else:
                     out = x
-                B, _, C = out.shape
-                out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                  inputs.shape[3] // self.patch_size,
-                                  C).permute(0, 3, 1, 2)
+                if self.out_shape == 'NCHW':
+                    B, _, C = out.shape
+                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                      inputs.shape[3] // self.patch_size,
+                                      C).permute(0, 3, 1, 2)
                 outs.append(out)
 
         return tuple(outs)
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index c36894ec..1ec42d34 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -30,6 +30,10 @@ def test_vit_backbone():
         model = VisionTransformer()
         model(x)
 
+    with pytest.raises(AssertionError):
+        # out_shape must be 'NLC' or 'NCHW'
+        VisionTransformer(out_shape='NCL')
+
     # Test img_size isinstance int
     imgs = torch.randn(1, 3, 224, 224)
     model = VisionTransformer(img_size=224)
@@ -72,3 +76,9 @@ def test_vit_backbone():
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test out_shape arg
+    imgs = torch.randn(1, 3, 224, 224)
+    model = VisionTransformer(out_shape='NLC')
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 196, 768)
--
GitLab
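
For reference, a minimal usage sketch of the new out_shape argument, mirroring the added unit tests. It assumes the patch above is applied and the backbone is left at its ViT-Base defaults (patch_size=16, embed_dims=768, as implied by the expected shapes in the tests):

    import torch

    from mmseg.models.backbones.vit import VisionTransformer

    imgs = torch.randn(1, 3, 224, 224)

    # Default out_shape='NCHW': the token sequence is reshaped back to a
    # spatial feature map; 224 / 16 = 14 patches per side.
    model = VisionTransformer(img_size=224, out_shape='NCHW')
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)

    # out_shape='NLC': the raw token sequence is returned as-is;
    # 14 * 14 = 196 tokens, each with 768 channels.
    model = VisionTransformer(img_size=224, out_shape='NLC')
    feat = model(imgs)
    assert feat[-1].shape == (1, 196, 768)

The 'NLC' option is useful for decode heads that consume the token sequence directly, while 'NCHW' keeps the previous behaviour for convolutional heads.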