diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index b140700a9b8b1e5b816abdafc2932c997b3b9a91..774f555c49841e0562ae1d70f5ead99fd87f2e45 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -1,294 +1,257 @@
-"""Modified from https://github.com/rwightman/pytorch-image-
-models/blob/master/timm/models/vision_transformer.py."""
-
 import math
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as cp
-from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
-                      constant_init, kaiming_init, normal_init)
-from mmcv.runner import BaseModule, _load_checkpoint
-from mmcv.utils.parrots_wrapper import _BatchNorm
+from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+                      kaiming_init, normal_init, trunc_normal_init)
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.runner import _load_checkpoint
+from mmcv.runner.base_module import BaseModule, ModuleList
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.utils import _pair as to_2tuple
 
 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
-from ..utils import DropPath, trunc_normal_
+from ..utils import vit_convert
 
 
-class Mlp(nn.Module):
-    """MLP layer for Encoder block.
+class TransformerEncoderLayer(BaseModule):
+    """Implements one encoder layer in Vision Transformer.
 
     Args:
-        in_features(int): Input dimension for the first fully
-            connected layer.
-        hidden_features(int): Output dimension for the first fully
-            connected layer.
-        out_features(int): Output dementsion for the second fully
-            connected layer.
-        act_cfg(dict): Config dict for activation layer.
-            Default: dict(type='GELU').
-        drop(float): Drop rate for the dropout layer. Dropout rate has
-            to be between 0 and 1. Default: 0.
+        embed_dims (int): The feature dimension
+        num_heads (int): Parallel attention heads
+        feedforward_channels (int): The hidden dimension for FFNs
+        drop_rate (float): Probability of an element to be zeroed
+            after the feed forward layer. Default 0.0
+        attn_drop_rate (float): The drop out rate for attention layer.
+            Default 0.0
+        drop_path_rate (float): stochastic depth rate. Default 0.0.
+        num_fcs (int): The number of fully-connected layers for FFNs. Default 2
+        qkv_bias (bool): enable bias for qkv if True. Default True
+        act_cfg (dict): The activation config for FFNs. Default: GELU
+        norm_cfg (dict): Config dict for normalization layer. Default:
+            layer normalization
+        batch_first (bool): Key, Query and Value are in shape of
+            (batch, n, embed_dim)
+            or (n, batch, embed_dim). Default: False.
+        init_cfg (dict, optional): Initialization config dict
     """
 
     def __init__(self,
-                 in_features,
-                 hidden_features=None,
-                 out_features=None,
-                 act_cfg=dict(type='GELU'),
-                 drop=0.):
-        super(Mlp, self).__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = Linear(in_features, hidden_features)
-        self.act = build_activation_layer(act_cfg)
-        self.fc2 = Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
-class Attention(nn.Module):
-    """Attention layer for Encoder block.
-
-    Args:
-        dim (int): Dimension for the input vector.
-        num_heads (int): Number of parallel attention heads.
-        qkv_bias (bool): Enable bias for qkv if True. Default: False.
-        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
-        attn_drop (float): Drop rate for attention output weights.
-            Default: 0.
-        proj_drop (float): Drop rate for output weights. Default: 0.
-    """
-
-    def __init__(self,
-                 dim,
-                 num_heads=8,
-                 qkv_bias=False,
-                 qk_scale=None,
-                 attn_drop=0.,
-                 proj_drop=0.):
-        super(Attention, self).__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x):
-        b, n, c = x.shape
-        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
-                                  c // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
-class Block(nn.Module):
-    """Implements encoder block with residual connection.
-
-    Args:
-        dim (int): The feature dimension.
-        num_heads (int): Number of parallel attention heads.
-        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
-        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
-        drop (float): Drop rate for mlp output weights. Default: 0.
-        attn_drop (float): Drop rate for attention output weights.
-            Default: 0.
-        proj_drop (float): Drop rate for attn layer output weights.
-            Default: 0.
-        drop_path (float): Drop rate for paths of model.
-            Default: 0.
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='GELU').
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='LN', requires_grad=True).
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 dim,
+                 embed_dims,
                  num_heads,
-                 mlp_ratio=4,
-                 qkv_bias=False,
-                 qk_scale=None,
-                 drop=0.,
-                 attn_drop=0.,
-                 proj_drop=0.,
-                 drop_path=0.,
+                 feedforward_channels,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 num_fcs=2,
+                 qkv_bias=True,
                  act_cfg=dict(type='GELU'),
-                 norm_cfg=dict(type='LN', eps=1e-6),
-                 with_cp=False):
-        super(Block, self).__init__()
-        self.with_cp = with_cp
-        _, self.norm1 = build_norm_layer(norm_cfg, dim)
-        self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
-                              proj_drop)
-        self.drop_path = DropPath(
-            drop_path) if drop_path > 0. else nn.Identity()
-        _, self.norm2 = build_norm_layer(norm_cfg, dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(
-            in_features=dim,
-            hidden_features=mlp_hidden_dim,
-            act_cfg=act_cfg,
-            drop=drop)
+                 norm_cfg=dict(type='LN'),
+                 batch_first=False):
+        super(TransformerEncoderLayer, self).__init__()
+
+        self.norm1_name, norm1 = build_norm_layer(
+            norm_cfg, embed_dims, postfix=1)
+        self.add_module(self.norm1_name, norm1)
+
+        self.attn = MultiheadAttention(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            attn_drop=attn_drop_rate,
+            proj_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            batch_first=batch_first,
+            bias=qkv_bias)
+
+        self.norm2_name, norm2 = build_norm_layer(
+            norm_cfg, embed_dims, postfix=2)
+        self.add_module(self.norm2_name, norm2)
+
+        self.ffn = FFN(
+            embed_dims=embed_dims,
+            feedforward_channels=feedforward_channels,
+            num_fcs=num_fcs,
+            ffn_drop=drop_rate,
+            dropout_layer=None,
+            act_cfg=act_cfg)
+
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        return getattr(self, self.norm2_name)
 
     def forward(self, x):
-
-        def _inner_forward(x):
-            out = x + self.drop_path(self.attn(self.norm1(x)))
-            out = out + self.drop_path(self.mlp(self.norm2(out)))
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
+        x = self.attn(self.norm1(x), identity=x)
+        x = self.ffn(self.norm2(x), identity=x)
+        return x
 
 
-class PatchEmbed(nn.Module):
+# Modified from pytorch-image-models
+class PatchEmbed(BaseModule):
     """Image to Patch Embedding.
 
     Args:
-        img_size (int | tuple): Input image size.
-            default: 224.
-        patch_size (int): Width and height for a patch.
-            default: 16.
-        in_channels (int): Input channels for images. Default: 3.
-        embed_dim (int): The embedding dimension. Default: 768.
+        img_size (int | tuple): The size of input image.
+        patch_size (int): The size of one patch.
+        in_channels (int): The number of input channels.
+        embed_dim (int): The dimension of embedding.
+        norm_cfg (dict, optional): Config dict for normalization layer.
+        conv_cfg (dict, optional): The config dict for conv layers.
+            Default: None.
     """
 
     def __init__(self,
                  img_size=224,
                  patch_size=16,
                  in_channels=3,
-                 embed_dim=768):
+                 embed_dim=768,
+                 norm_cfg=None,
+                 conv_cfg=None):
         super(PatchEmbed, self).__init__()
-        if isinstance(img_size, int):
-            self.img_size = (img_size, img_size)
-        elif isinstance(img_size, tuple):
-            self.img_size = img_size
+
+        self.img_size = img_size
+        self.patch_size = to_2tuple(patch_size)
+
+        patches_resolution = [
+            img_size[0] // self.patch_size[0],
+            img_size[1] // self.patch_size[1]
+        ]
+        num_patches = patches_resolution[0] * patches_resolution[1]
+        self.patches_resolution = patches_resolution
+        self.num_patches = num_patches
+
+        # Use conv layer to embed
+        self.projection = build_conv_layer(
+            conv_cfg,
+            in_channels,
+            embed_dim,
+            kernel_size=patch_size,
+            stride=patch_size)
+
+        if norm_cfg is not None:
+            self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
         else:
-            raise TypeError('img_size must be type of int or tuple')
-        h, w = self.img_size
-        self.patch_size = (patch_size, patch_size)
-        self.num_patches = (h // patch_size) * (w // patch_size)
-        self.proj = Conv2d(
-            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+            self.norm = None
 
     def forward(self, x):
-        return self.proj(x).flatten(2).transpose(1, 2)
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #     f"Input image size ({H}*{W}) doesn't " \
+        #     f'match model ({self.img_size[0]}*{self.img_size[1]}).'
+        # The output size is (B, N, D), where N=H*W/P/P, D is embed_dim
+        x = self.projection(x).flatten(2).transpose(1, 2)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x
 
 
 @BACKBONES.register_module()
 class VisionTransformer(BaseModule):
-    """Vision transformer backbone.
+    """Vision Transformer.
 
-    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
-    Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
+    A PyTorch implementation of: `An Image is Worth 16x16 Words:
+    Transformers for Image Recognition at Scale` -
+        https://arxiv.org/abs/2010.11929
 
     Args:
-        img_size (tuple): input image size. Default: (224, 224).
-        patch_size (int, tuple): patch size. Default: 16.
-        in_channels (int): number of input channels. Default: 3.
-        embed_dim (int): embedding dimension. Default: 768.
-        depth (int): depth of transformer. Default: 12.
+        img_size (int | tuple): Input image size. Default: 224.
+        patch_size (int): The patch size. Default: 16.
+        in_channels (int): Number of input channels. Default: 3.
+        embed_dims (int): embedding dimension. Default: 768.
+        num_layers (int): depth of transformer. Default: 12.
         num_heads (int): number of attention heads. Default: 12.
         mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
             Default: 4.
         out_indices (list | tuple | int): Output from which stages.
             Default: -1.
         qkv_bias (bool): enable bias for qkv if True. Default: True.
-        qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
-        drop_rate (float): dropout rate. Default: 0.
-        attn_drop_rate (float): attention dropout rate. Default: 0.
-        drop_path_rate (float): Rate of DropPath. Default: 0.
+        drop_rate (float): Probability of an element to be zeroed.
+            Default 0.0
+        attn_drop_rate (float): The drop out rate for attention layer.
+            Default 0.0
+        drop_path_rate (float): stochastic depth rate. Default 0.0
+        with_cls_token (bool): Whether to concatenate class token into image
+            tokens as transformer input. Default: True.
         norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='LN', eps=1e-6, requires_grad=True).
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='GELU').
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
+            Default: dict(type='LN')
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='GELU').
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Default: False.
-        out_reshape (str): Select the output format of feature information.
-            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
-        with_cls_token (bool): If concatenating class token into image tokens
-            as transformer input. Default: True.
-        with_cp (bool): Use checkpoint or not. Using checkpoint
-            will save some memory while slowing down the training speed.
-            Default: False.
-        pretrained (str, optional): model pretrained path. Default: None
-        init_cfg (dict or list[dict], optional): Initialization config dict.
-            Default: None
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Default: 2.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed. Default: False.
+        pretrain_style (str): Choose to use timm or mmcls pretrain weights.
+            Default: timm.
     """
 
     def __init__(self,
-                 img_size=(224, 224),
+                 img_size=224,
                  patch_size=16,
                  in_channels=3,
-                 embed_dim=768,
-                 depth=12,
+                 embed_dims=768,
+                 num_layers=12,
                  num_heads=12,
                  mlp_ratio=4,
                  out_indices=11,
                  qkv_bias=True,
-                 qk_scale=None,
                  drop_rate=0.,
                  attn_drop_rate=0.,
                  drop_path_rate=0.,
-                 norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
+                 with_cls_token=True,
+                 norm_cfg=dict(type='LN'),
                  act_cfg=dict(type='GELU'),
-                 norm_eval=False,
                  final_norm=False,
-                 out_shape='NCHW',
-                 with_cls_token=True,
                  interpolate_mode='bicubic',
+                 num_fcs=2,
+                 norm_eval=False,
                  with_cp=False,
-                 pretrained=None,
-                 init_cfg=None):
-        super(VisionTransformer, self).__init__(init_cfg)
-        self.pretrained = pretrained
+                 pretrain_style='timm'):
+        super(VisionTransformer, self).__init__()
+
+        if isinstance(img_size, int):
+            img_size = to_2tuple(img_size)
+        elif isinstance(img_size, tuple):
+            if len(img_size) == 1:
+                img_size = to_2tuple(img_size[0])
+            assert len(img_size) == 2, \
+                f'The size of image should have length 1 or 2, ' \
+                f'but got {len(img_size)}'
+        assert pretrain_style in ['timm', 'mmcls']
+
+        self.pretrain_style = pretrain_style
         self.img_size = img_size
         self.patch_size = patch_size
-        self.features = self.embed_dim = embed_dim
+
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_channels=in_channels,
-            embed_dim=embed_dim)
+            embed_dim=embed_dims,
+            norm_cfg=norm_cfg)
+        num_patches = self.patch_embed.num_patches
 
         self.with_cls_token = with_cls_token
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
         self.pos_embed = nn.Parameter(
-            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
-        self.pos_drop = nn.Dropout(p=drop_rate)
+            torch.zeros(1, num_patches + 1, embed_dims))
+        self.drop_after_pos = nn.Dropout(p=drop_rate)
 
         if isinstance(out_indices, int):
             self.out_indices = [out_indices]
@@ -297,37 +260,41 @@ class VisionTransformer(BaseModule):
         else:
             raise TypeError('out_indices must be type of int, list or tuple')
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
-               ]  # stochastic depth decay rule
-        self.blocks = nn.ModuleList([
-            Block(
-                dim=embed_dim,
-                num_heads=num_heads,
-                mlp_ratio=mlp_ratio,
-                qkv_bias=qkv_bias,
-                qk_scale=qk_scale,
-                drop=dpr[i],
-                attn_drop=attn_drop_rate,
-                act_cfg=act_cfg,
-                norm_cfg=norm_cfg,
-                with_cp=with_cp) for i in range(depth)
-        ])
-
-        assert out_shape in ['NLC',
-                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
-
-        self.out_shape = out_shape
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
+        ]  # stochastic depth decay rule
+
+        self.layers = ModuleList()
+        for i in range(num_layers):
+            self.layers.append(
+                TransformerEncoderLayer(
+                    embed_dims=embed_dims,
+                    num_heads=num_heads,
+                    feedforward_channels=mlp_ratio * embed_dims,
+                    attn_drop_rate=attn_drop_rate,
+                    drop_rate=drop_rate,
+                    drop_path_rate=dpr[i],
+                    num_fcs=num_fcs,
+                    qkv_bias=qkv_bias,
+                    act_cfg=act_cfg,
+                    norm_cfg=norm_cfg,
+                    batch_first=True))
 
         self.interpolate_mode = interpolate_mode
         self.final_norm = final_norm
         if final_norm:
-            _, self.norm = build_norm_layer(norm_cfg, embed_dim)
+            self.norm1_name, norm1 = build_norm_layer(
+                norm_cfg, embed_dims, postfix=1)
+            self.add_module(self.norm1_name, norm1)
 
         self.norm_eval = norm_eval
         self.with_cp = with_cp
 
-    def init_weights(self):
-        pretrained = self.pretrained
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
             logger = get_root_logger()
             checkpoint = _load_checkpoint(pretrained, logger=logger)
@@ -338,10 +305,17 @@ class VisionTransformer(BaseModule):
             else:
                 state_dict = checkpoint
 
+            if self.pretrain_style == 'timm':
+                # Because the refactor of ViT is blocked by mmcls,
+                # we first use timm pretrained weights to train
+                # the downstream model.
+                state_dict = vit_convert(state_dict)
+
             if 'pos_embed' in state_dict.keys():
                 if self.pos_embed.shape != state_dict['pos_embed'].shape:
-                    logger.info(msg=f'Resize the pos_embed shape from \
-{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
+                    logger.info(msg=f'Resize the pos_embed shape from '
+                                f'{state_dict["pos_embed"].shape} to '
+                                f'{self.pos_embed.shape}')
                     h, w = self.img_size
                     pos_size = int(
                         math.sqrt(state_dict['pos_embed'].shape[1] - 1))
@@ -354,17 +328,17 @@ class VisionTransformer(BaseModule):
         elif pretrained is None:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
-            trunc_normal_(self.pos_embed, std=.02)
-            trunc_normal_(self.cls_token, std=.02)
+            trunc_normal_init(self.pos_embed, std=.02)
+            trunc_normal_init(self.cls_token, std=.02)
             for n, m in self.named_modules():
-                if isinstance(m, Linear):
-                    trunc_normal_(m.weight, std=.02)
+                if isinstance(m, nn.Linear):
+                    trunc_normal_init(m.weight, std=.02)
                     if m.bias is not None:
-                        if 'mlp' in n:
+                        if 'ffn' in n:
                             normal_init(m.bias, std=1e-6)
                         else:
                             constant_init(m.bias, 0)
-                elif isinstance(m, Conv2d):
+                elif isinstance(m, nn.Conv2d):
                     kaiming_init(m.weight, mode='fan_in')
                     if m.bias is not None:
                         constant_init(m.bias, 0)
@@ -404,7 +378,7 @@ class VisionTransformer(BaseModule):
             pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
                                               (pos_h, pos_w), self.patch_size,
                                               self.interpolate_mode)
-        return self.pos_drop(patched_img + pos_embed)
+        return self.drop_after_pos(patched_img + pos_embed)
 
     @staticmethod
     def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
@@ -441,31 +415,31 @@ class VisionTransformer(BaseModule):
 
         x = self.patch_embed(inputs)
 
+        # stole cls_tokens impl from Phil Wang, thanks
         cls_tokens = self.cls_token.expand(B, -1, -1)
         x = torch.cat((cls_tokens, x), dim=1)
         x = self._pos_embeding(inputs, x, self.pos_embed)
 
         if not self.with_cls_token:
-            # Remove class token for transformer input
+            # Remove class token for transformer encoder input
            x = x[:, 1:]
 
         outs = []
-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
-            if i == len(self.blocks) - 1:
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+            if i == len(self.layers) - 1:
                 if self.final_norm:
-                    x = self.norm(x)
+                    x = self.norm1(x)
             if i in self.out_indices:
                 if self.with_cls_token:
                     # Remove class token and reshape token for decoder head
                     out = x[:, 1:]
                 else:
                     out = x
-                if self.out_shape == 'NCHW':
-                    B, _, C = out.shape
-                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                      inputs.shape[3] // self.patch_size,
-                                      C).permute(0, 3, 1, 2)
+                B, _, C = out.shape
+                out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                  inputs.shape[3] // self.patch_size,
+                                  C).permute(0, 3, 1, 2)
                 outs.append(out)
 
         return tuple(outs)
diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py
index 3d3bdd349b9f2ae499a2fcb2ac1d2e3c77befebe..be11d77f4e7b193d8fc006a8672fe50059468ee6 100644
--- a/mmseg/models/utils/__init__.py
+++ b/mmseg/models/utils/__init__.py
@@ -4,10 +4,10 @@ from .make_divisible import make_divisible
 from .res_layer import ResLayer
 from .se_layer import SELayer
 from .self_attention_block import SelfAttentionBlock
+from .timm_convert import vit_convert
 from .up_conv_block import UpConvBlock
-from .weight_init import trunc_normal_
 
 __all__ = [
     'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
-    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
+    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'vit_convert'
 ]
diff --git a/mmseg/models/utils/timm_convert.py b/mmseg/models/utils/timm_convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a4d311090339aba71fd79271542c0ab76bbbc9
--- /dev/null
+++ b/mmseg/models/utils/timm_convert.py
@@ -0,0 +1,33 @@
+from collections import OrderedDict
+
+
+def vit_convert(timm_dict):
+
+    mmseg_dict = OrderedDict()
+
+    for k, v in timm_dict.items():
+        if k.startswith('head'):
+            continue
+        if k.startswith('norm'):
+            new_k = k.replace('norm.', 'ln1.')
+        elif k.startswith('patch_embed'):
+            if 'proj' in k:
+                new_k = k.replace('proj', 'projection')
+        elif k.startswith('blocks'):
+            new_k = k.replace('blocks.', 'layers.')
+            if 'norm' in new_k:
+                new_k = new_k.replace('norm', 'ln')
+            elif 'mlp.fc1' in new_k:
+                new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
+            elif 'mlp.fc2' in new_k:
+                new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
+            elif 'attn.qkv' in new_k:
+                new_k = new_k.replace('attn.qkv.', 'attn.attn.in_proj_')
+            elif 'attn.proj' in new_k:
+                new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
+        else:
+            new_k = k
+        new_k = f'backbone.{new_k}'
+        mmseg_dict[new_k] = v
+
+    return mmseg_dict
diff --git a/mmseg/models/utils/weight_init.py b/mmseg/models/utils/weight_init.py
deleted file mode 100644
index 38141ba3d61f64ddfc0a31574b4648cbad96d7dd..0000000000000000000000000000000000000000
--- a/mmseg/models/utils/weight_init.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""Modified from https://github.com/rwightman/pytorch-image-
-models/blob/master/timm/models/layers/drop.py."""
-
-import math
-import warnings
-
-import torch
-
-
-def _no_grad_trunc_normal_(tensor, mean, std, a, b):
-    """Reference: https://people.sc.fsu.edu/~jburkardt/presentations
-    /truncated_normal.pdf"""
-
-    def norm_cdf(x):
-        # Computes standard normal cumulative distribution function
-        return (1. + math.erf(x / math.sqrt(2.))) / 2.
-
-    if (mean < a - 2 * std) or (mean > b + 2 * std):
-        warnings.warn(
-            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
-            'The distribution of values may be incorrect.',
-            stacklevel=2)
-
-    with torch.no_grad():
-        # Values are generated by using a truncated uniform distribution and
-        # then using the inverse CDF for the normal distribution.
-        # Get upper and lower cdf values
-        lower_bound = norm_cdf((a - mean) / std)
-        upper_bound = norm_cdf((b - mean) / std)
-
-        # Uniformly fill tensor with values from [l, u], then translate to
-        # [2l-1, 2u-1].
-        tensor.uniform_(2 * lower_bound - 1, 2 * upper_bound - 1)
-
-        # Use inverse cdf transform for normal distribution to get truncated
-        # standard normal
-        tensor.erfinv_()
-
-        # Transform to proper mean, std
-        tensor.mul_(std * math.sqrt(2.))
-        tensor.add_(mean)
-
-        # Clamp to ensure it's in the proper range
-        tensor.clamp_(min=a, max=b)
-        return tensor
-
-
-def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
-    r"""Fills the input Tensor with values drawn from a truncated
-    normal distribution. The values are effectively drawn from the
-    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
-    with values outside :math:`[a, b]` redrawn until they are within
-    the bounds. The method used for generating the random values works
-    best when :math:`a \leq \text{mean} \leq b`.
-    Args:
-        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`
-        mean (float): the mean of the normal distribution
-        std (float): the standard deviation of the normal distribution
-        a (float): the minimum cutoff value
-        b (float): the maximum cutoff value
-    """
-    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index 1ec42d34ea4a80a0b46aa2d2e192ed48248c5926..452eee05d8ccb4264c0810b4d399f0734888ce30 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -24,19 +24,18 @@ def test_vit_backbone():
         x = torch.randn(1, 196)
         VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
 
-    with pytest.raises(RuntimeError):
+    with pytest.raises(ValueError):
         # forward inputs must be [N, C, H, W]
         x = torch.randn(3, 30, 30)
         model = VisionTransformer()
         model(x)
 
     with pytest.raises(AssertionError):
-        # out_shape must be 'NLC' or 'NCHW;'
-        VisionTransformer(out_shape='NCL')
+        VisionTransformer(img_size=(224, 224, 224))
 
-    # Test img_size isinstance int
+    # Test img_size isinstance tuple
     imgs = torch.randn(1, 3, 224, 224)
-    model = VisionTransformer(img_size=224)
+    model = VisionTransformer(img_size=(224, 224))
     model.init_weights()
     model(imgs)
 
@@ -65,6 +64,11 @@ def test_vit_backbone():
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)
 
+    # Test unbalanced size input image
+    imgs = torch.randn(1, 3, 112, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 7, 14)
+
     # Test with_cp=True
     model = VisionTransformer(with_cp=True)
     imgs = torch.randn(1, 3, 224, 224)
@@ -77,8 +81,8 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
 
-    # Test final reshape arg
+    # Test final norm
+    model = VisionTransformer(final_norm=True)
     imgs = torch.randn(1, 3, 224, 224)
-    model = VisionTransformer(out_shape='NLC')
     feat = model(imgs)
-    assert feat[-1].shape == (1, 196, 768)
+    assert feat[-1].shape == (1, 768, 14, 14)
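
Example usage of the refactored backbone -- a minimal sketch mirroring the updated tests above; the argument names follow this patch, and the checkpoint filename mentioned in the comment is hypothetical:

import torch

from mmseg.models.backbones.vit import VisionTransformer

# Build the backbone with the renamed arguments introduced by this patch.
model = VisionTransformer(
    img_size=(224, 224),
    patch_size=16,
    embed_dims=768,
    num_layers=12,
    num_heads=12,
    final_norm=True)

# Random (timm 'jax_impl') initialization. If a timm checkpoint path is
# passed instead, e.g. model.init_weights('vit_base_p16.pth') with the
# default pretrain_style='timm', init_weights() remaps the keys through
# vit_convert() before loading.
model.init_weights()

feat = model(torch.randn(1, 3, 224, 224))
assert feat[-1].shape == (1, 768, 14, 14)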