diff --git a/configs/_base_/models/upernet_vit-b16_ln_mln.py b/configs/_base_/models/upernet_vit-b16_ln_mln.py
index 573612e13a63c5a7c635221435807e230b007a3a..1a5a56972908d018cadceb46d5e7efa2c109f6d2 100644
--- a/configs/_base_/models/upernet_vit-b16_ln_mln.py
+++ b/configs/_base_/models/upernet_vit-b16_ln_mln.py
@@ -21,7 +21,6 @@ model = dict(
         norm_cfg=dict(type='LN', eps=1e-6),
         act_cfg=dict(type='GELU'),
         norm_eval=False,
-        out_shape='NCHW',
         interpolate_mode='bicubic'),
     neck=dict(
         type='MultiLevelNeck',
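
Note: `out_shape` is removed from the `VisionTransformer` constructor in the next file, so any downstream config that still passes the key will fail at model construction with an unexpected-keyword error. A hedged sketch of the fix for a hypothetical downstream config (the file name and `_base_` path are illustrative, not part of this patch):

# configs/upernet_my_vit.py -- hypothetical downstream config
_base_ = ['../_base_/models/upernet_vit-b16_ln_mln.py']

model = dict(
    backbone=dict(
        # out_shape='NCHW',   # delete this key; the backbone now always
        #                     # emits NCHW feature maps
        interpolate_mode='bicubic'))
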
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 1ad20a1ca6ac18b8583fbc7c3b419df4d33477ed..33176351ea00cd37c72e8e2c2323419c7369c183 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -118,8 +118,10 @@ class VisionTransformer(BaseModule):
         attn_drop_rate (float): The drop out rate for attention layer.
             Default 0.0
         drop_path_rate (float): stochastic depth rate. Default 0.0
-        with_cls_token (bool): If concatenating class token into image tokens
-            as transformer input. Default: True.
+        with_cls_token (bool): Whether concatenating class token into image
+            tokens as transformer input. Default: True.
+        output_cls_token (bool): Whether output the cls_token. If set True,
+            `with_cls_token` must be True. Default: False.
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
@@ -128,8 +130,6 @@ class VisionTransformer(BaseModule):
             Default: False.
         final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Default: False.
-        out_shape (str): Select the output format of feature information.
-            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
            embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -160,11 +160,11 @@ class VisionTransformer(BaseModule):
                  attn_drop_rate=0.,
                  drop_path_rate=0.,
                  with_cls_token=True,
+                 output_cls_token=False,
                  norm_cfg=dict(type='LN'),
                  act_cfg=dict(type='GELU'),
                  patch_norm=False,
                  final_norm=False,
-                 out_shape='NCHW',
                  interpolate_mode='bicubic',
                  num_fcs=2,
                  norm_eval=False,
@@ -185,8 +185,9 @@ class VisionTransformer(BaseModule):
 
         assert pretrain_style in ['timm', 'mmcls']
 
-        assert out_shape in ['NLC',
-                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+        if output_cls_token:
+            assert with_cls_token is True, f'with_cls_token must be True if' \
+                f'set output_cls_token to True, but got {with_cls_token}'
 
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -196,7 +197,6 @@ class VisionTransformer(BaseModule):
 
         self.img_size = img_size
         self.patch_size = patch_size
-        self.out_shape = out_shape
         self.interpolate_mode = interpolate_mode
         self.norm_eval = norm_eval
         self.with_cp = with_cp
@@ -218,6 +218,7 @@
             (img_size[1] // patch_size)
 
         self.with_cls_token = with_cls_token
+        self.output_cls_token = output_cls_token
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
         self.pos_embed = nn.Parameter(
             torch.zeros(1, num_patches + 1, embed_dims))
@@ -253,7 +254,6 @@
                     batch_first=True))
 
         self.final_norm = final_norm
-        self.out_shape = out_shape
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
@@ -290,8 +290,9 @@
                     pos_size = int(
                         math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                     state_dict['pos_embed'] = self.resize_pos_embed(
-                        state_dict['pos_embed'], (h, w), (pos_size, pos_size),
-                        self.patch_size, self.interpolate_mode)
+                        state_dict['pos_embed'],
+                        (h // self.patch_size, w // self.patch_size),
+                        (pos_size, pos_size), self.interpolate_mode)
 
                 self.load_state_dict(state_dict, False)
 
@@ -317,16 +318,15 @@
                 constant_init(m.bias, 0)
                 constant_init(m.weight, 1.0)
 
-    def _pos_embeding(self, img, patched_img, pos_embed):
+    def _pos_embeding(self, patched_img, hw_shape, pos_embed):
         """Positiong embeding method.
 
         Resize the pos_embed, if the input image size doesn't match the
         training size.
         Args:
-            img (torch.Tensor): The inference image tensor, the shape
-                must be [B, C, H, W].
             patched_img (torch.Tensor): The patched image, it should be
                 shape of [B, L1, C].
+            hw_shape (tuple): The downsampled image resolution.
             pos_embed (torch.Tensor): The pos_embed weighs, it should be
                 shape of [B, L2, c].
         Return:
@@ -344,36 +344,36 @@ class VisionTransformer(BaseModule):
                 raise ValueError(
                     'Unexpected shape of pos_embed, got {}.'.format(
                         pos_embed.shape))
-            pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
-                                              (pos_h, pos_w), self.patch_size,
+            pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
+                                              (pos_h, pos_w),
                                               self.interpolate_mode)
         return self.drop_after_pos(patched_img + pos_embed)
 
     @staticmethod
-    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
         """Resize pos_embed weights.
 
         Resize pos_embed using bicubic interpolate method.
         Args:
-            pos_embed (torch.Tensor): pos_embed weights.
-            input_shpae (tuple): Tuple for (input_h, intput_w).
-            pos_shape (tuple): Tuple for (pos_h, pos_w).
-            patch_size (int): Patch size.
+            pos_embed (torch.Tensor): Position embedding weights.
+            input_shpae (tuple): Tuple for (downsampled input image height,
+                downsampled input image width).
+            pos_shape (tuple): The resolution of downsampled origin training
+                image.
+            mode (str): Algorithm used for upsampling:
+                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+                ``'trilinear'``. Default: ``'nearest'``
         Return:
             torch.Tensor: The resized pos_embed of shape [B, L_new, C]
         """
         assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
-        input_h, input_w = input_shpae
         pos_h, pos_w = pos_shape
         cls_token_weight = pos_embed[:, 0]
         pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
         pos_embed_weight = pos_embed_weight.reshape(
             1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
         pos_embed_weight = F.interpolate(
-            pos_embed_weight,
-            size=[input_h // patch_size, input_w // patch_size],
-            align_corners=False,
-            mode=mode)
+            pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
         cls_token_weight = cls_token_weight.unsqueeze(1)
         pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
         pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
@@ -382,12 +382,12 @@ class VisionTransformer(BaseModule):
     def forward(self, inputs):
         B = inputs.shape[0]
 
-        x = self.patch_embed(inputs)
-
+        x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH,
+                                                 self.patch_embed.DW)
         # stole cls_tokens impl from Phil Wang, thanks
         cls_tokens = self.cls_token.expand(B, -1, -1)
         x = torch.cat((cls_tokens, x), dim=1)
-        x = self._pos_embeding(inputs, x, self.pos_embed)
+        x = self._pos_embeding(x, hw_shape, self.pos_embed)
 
         if not self.with_cls_token:
             # Remove class token for transformer encoder input
@@ -405,11 +405,11 @@ class VisionTransformer(BaseModule):
                     out = x[:, 1:]
                 else:
                     out = x
-                if self.out_shape == 'NCHW':
-                    B, _, C = out.shape
-                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                      inputs.shape[3] // self.patch_size,
-                                      C).permute(0, 3, 1, 2)
+                B, _, C = out.shape
+                out = out.reshape(B, hw_shape[0], hw_shape[1],
+                                  C).permute(0, 3, 1, 2)
+                if self.output_cls_token:
+                    out = [out, x[:, 0]]
                 outs.append(out)
 
         return tuple(outs)
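
With `out_shape` gone, the backbone always returns NCHW feature maps, and `output_cls_token=True` additionally returns the class token next to each selected feature map. A minimal usage sketch, mirroring the new unit test below and assuming the default constructor arguments (ViT-Base, patch size 16):

import torch

from mmseg.models.backbones.vit import VisionTransformer

# Each selected output becomes [patch_feature_map, cls_token] when
# output_cls_token=True; the spatial map is NCHW (224 / 16 = 14).
model = VisionTransformer(with_cls_token=True, output_cls_token=True)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
patch_feats, cls_token = feats[0]
assert patch_feats.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)
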
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index 4577b97b86bb29f14d1bb72deb42c879f80e5192..16d6aba68fca014a1723fcc01dceec96b441d332 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -39,8 +39,8 @@ def test_vit_backbone():
         VisionTransformer(pretrained=123)
 
     with pytest.raises(AssertionError):
-        # out_shape must be 'NLC' or 'NCHW;'
-        VisionTransformer(out_shape='NCL')
+        # with_cls_token must be True when output_cls_token == True
+        VisionTransformer(with_cls_token=False, output_cls_token=True)
 
     # Test img_size isinstance tuple
     imgs = torch.randn(1, 3, 224, 224)
@@ -88,6 +88,11 @@
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 7, 14)
 
+    # Test irregular input image
+    imgs = torch.randn(1, 3, 234, 345)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 15, 22)
+
     # Test with_cp=True
     model = VisionTransformer(with_cp=True)
     imgs = torch.randn(1, 3, 224, 224)
@@ -100,12 +105,6 @@
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
 
-    # Test out_shape == 'NLC'
-    model = VisionTransformer(out_shape='NLC')
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert feat[-1].shape == (1, 196, 768)
-
     # Test final norm
     model = VisionTransformer(final_norm=True)
     imgs = torch.randn(1, 3, 224, 224)
@@ -117,3 +116,10 @@
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test output_cls_token
+    model = VisionTransformer(with_cls_token=True, output_cls_token=True)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[0][0].shape == (1, 768, 14, 14)
+    assert feat[0][1].shape == (1, 768)
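
The refactored `resize_pos_embed` now takes the downsampled patch-grid resolution directly, instead of the raw image size plus patch size. A hedged standalone sketch, assuming a ViT-Base position embedding trained on a 14x14 grid (224 / 16) and resized to the 15x22 grid that the irregular 234x345 test input above produces:

import torch

from mmseg.models.backbones.vit import VisionTransformer

pos_embed = torch.zeros(1, 1 + 14 * 14, 768)  # [cls] + 196 patch positions
resized = VisionTransformer.resize_pos_embed(
    pos_embed, (15, 22), (14, 14), 'bicubic')
# The cls token entry is kept in front; only patch positions are interpolated.
assert resized.shape == (1, 1 + 15 * 22, 768)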