From 50461efe854bc922d1c345b8344a7c3aa59817aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20M=C3=A9ndez?= <miguelmndez@gmail.com> Date: Wed, 28 Jul 2021 10:56:22 +0200 Subject: [PATCH] [Fix] Replace interpolate with resize (#731) * Replace interpolate with resize * Replace nn.Upsample with ops.Upsample * Fix test --- mmseg/models/backbones/swin.py | 3 ++- mmseg/models/backbones/unet.py | 3 ++- mmseg/models/backbones/vit.py | 4 ++-- mmseg/models/decode_heads/fpn_head.py | 4 ++-- mmseg/models/decode_heads/setr_mla_head.py | 3 ++- mmseg/models/decode_heads/setr_up_head.py | 3 ++- mmseg/models/necks/fpn.py | 6 +++--- mmseg/models/necks/multilevel_neck.py | 4 ++-- tests/test_models/test_backbones/test_unet.py | 10 +++++----- tools/deploy_test.py | 5 +++-- tools/pytorch2onnx.py | 6 ++---- 11 files changed, 27 insertions(+), 24 deletions(-) diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py index ef027dc0..68a989b5 100644 --- a/mmseg/models/backbones/swin.py +++ b/mmseg/models/backbones/swin.py @@ -13,6 +13,7 @@ from torch.nn.modules.linear import Linear from torch.nn.modules.normalization import LayerNorm from torch.nn.modules.utils import _pair as to_2tuple +from mmseg.ops import resize from ...utils import get_root_logger from ..builder import ATTENTION, BACKBONES from ..utils import PatchEmbed, swin_convert @@ -745,7 +746,7 @@ class SwinTransformer(BaseModule): if L1 != L2: S1 = int(L1**0.5) S2 = int(L2**0.5) - table_pretrained_resized = F.interpolate( + table_pretrained_resized = resize( table_pretrained.permute(1, 0).reshape( 1, nH1, S1, S1), size=(S2, S2), diff --git a/mmseg/models/backbones/unet.py b/mmseg/models/backbones/unet.py index a8cbe57f..705dd2b8 100644 --- a/mmseg/models/backbones/unet.py +++ b/mmseg/models/backbones/unet.py @@ -7,6 +7,7 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, from mmcv.runner import BaseModule from mmcv.utils.parrots_wrapper import _BatchNorm +from mmseg.ops import Upsample from ..builder import BACKBONES from ..utils import UpConvBlock @@ -203,7 +204,7 @@ class InterpConv(nn.Module): conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) - upsample = nn.Upsample(**upsample_cfg) + upsample = Upsample(**upsample_cfg) if conv_first: self.interp_upsample = nn.Sequential(conv, upsample) else: diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 021bf093..e4f1839b 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -3,7 +3,6 @@ import warnings import torch import torch.nn as nn -import torch.nn.functional as F from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init, normal_init, trunc_normal_init) from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention @@ -11,6 +10,7 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.utils import _pair as to_2tuple +from mmseg.ops import resize from mmseg.utils import get_root_logger from ..builder import BACKBONES from ..utils import PatchEmbed, vit_convert @@ -373,7 +373,7 @@ class VisionTransformer(BaseModule): pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] pos_embed_weight = pos_embed_weight.reshape( 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) - pos_embed_weight = F.interpolate( + pos_embed_weight = resize( pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) cls_token_weight = cls_token_weight.unsqueeze(1) pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) diff --git a/mmseg/models/decode_heads/fpn_head.py b/mmseg/models/decode_heads/fpn_head.py index 9b6ada00..1e5bfd63 100644 --- a/mmseg/models/decode_heads/fpn_head.py +++ b/mmseg/models/decode_heads/fpn_head.py @@ -2,7 +2,7 @@ import numpy as np import torch.nn as nn from mmcv.cnn import ConvModule -from mmseg.ops import resize +from mmseg.ops import Upsample, resize from ..builder import HEADS from .decode_head import BaseDecodeHead @@ -45,7 +45,7 @@ class FPNHead(BaseDecodeHead): act_cfg=self.act_cfg)) if feature_strides[i] != feature_strides[0]: scale_head.append( - nn.Upsample( + Upsample( scale_factor=2, mode='bilinear', align_corners=self.align_corners)) diff --git a/mmseg/models/decode_heads/setr_mla_head.py b/mmseg/models/decode_heads/setr_mla_head.py index 016a82a4..86e493d2 100644 --- a/mmseg/models/decode_heads/setr_mla_head.py +++ b/mmseg/models/decode_heads/setr_mla_head.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule +from mmseg.ops import Upsample from ..builder import HEADS from .decode_head import BaseDecodeHead @@ -46,7 +47,7 @@ class SETRMLAHead(BaseDecodeHead): padding=1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg), - nn.Upsample( + Upsample( scale_factor=up_scale, mode='bilinear', align_corners=self.align_corners))) diff --git a/mmseg/models/decode_heads/setr_up_head.py b/mmseg/models/decode_heads/setr_up_head.py index 322a56dc..d64896f7 100644 --- a/mmseg/models/decode_heads/setr_up_head.py +++ b/mmseg/models/decode_heads/setr_up_head.py @@ -1,6 +1,7 @@ import torch.nn as nn from mmcv.cnn import ConvModule, build_norm_layer +from mmseg.ops import Upsample from ..builder import HEADS from .decode_head import BaseDecodeHead @@ -59,7 +60,7 @@ class SETRUPHead(BaseDecodeHead): padding=int(kernel_size - 1) // 2, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg), - nn.Upsample( + Upsample( scale_factor=up_scale, mode='bilinear', align_corners=self.align_corners))) diff --git a/mmseg/models/necks/fpn.py b/mmseg/models/necks/fpn.py index 4ba128ed..5e1bd218 100644 --- a/mmseg/models/necks/fpn.py +++ b/mmseg/models/necks/fpn.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmcv.runner import BaseModule, auto_fp16 +from mmseg.ops import resize from ..builder import NECKS @@ -173,11 +174,10 @@ class FPN(BaseModule): # In some cases, fixing `scale factor` (e.g. 2) is preferred, but # it cannot co-exist with `size` in `F.interpolate`. if 'scale_factor' in self.upsample_cfg: - laterals[i - 1] += F.interpolate(laterals[i], - **self.upsample_cfg) + laterals[i - 1] += resize(laterals[i], **self.upsample_cfg) else: prev_shape = laterals[i - 1].shape[2:] - laterals[i - 1] += F.interpolate( + laterals[i - 1] += resize( laterals[i], size=prev_shape, **self.upsample_cfg) # build outputs diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py index eb32240b..9f638932 100644 --- a/mmseg/models/necks/multilevel_neck.py +++ b/mmseg/models/necks/multilevel_neck.py @@ -1,7 +1,7 @@ import torch.nn as nn -import torch.nn.functional as F from mmcv.cnn import ConvModule, xavier_init +from mmseg.ops import resize from ..builder import NECKS @@ -70,7 +70,7 @@ class MultiLevelNeck(nn.Module): inputs = [inputs[0] for _ in range(self.num_outs)] outs = [] for i in range(self.num_outs): - x_resize = F.interpolate( + x_resize = resize( inputs[i], scale_factor=self.scales[i], mode='bilinear') outs.append(self.convs[i](x_resize)) return tuple(outs) diff --git a/tests/test_models/test_backbones/test_unet.py b/tests/test_models/test_backbones/test_unet.py index defdf392..52f2123a 100644 --- a/tests/test_models/test_backbones/test_unet.py +++ b/tests/test_models/test_backbones/test_unet.py @@ -1,10 +1,10 @@ import pytest import torch from mmcv.cnn import ConvModule -from torch import nn from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule, InterpConv, UNet, UpConvBlock) +from mmseg.ops import Upsample from .utils import check_norm_state @@ -145,7 +145,7 @@ def test_interp_conv(): block = InterpConv(64, 32, conv_first=False) x = torch.randn(1, 64, 128, 128) x_out = block(x) - assert isinstance(block.interp_upsample[0], nn.Upsample) + assert isinstance(block.interp_upsample[0], Upsample) assert isinstance(block.interp_upsample[1], ConvModule) assert x_out.shape == torch.Size([1, 32, 256, 256]) @@ -154,7 +154,7 @@ def test_interp_conv(): x = torch.randn(1, 64, 128, 128) x_out = block(x) assert isinstance(block.interp_upsample[0], ConvModule) - assert isinstance(block.interp_upsample[1], nn.Upsample) + assert isinstance(block.interp_upsample[1], Upsample) assert x_out.shape == torch.Size([1, 32, 256, 256]) # test InterpConv with bilinear upsample for upsample 2X. @@ -166,7 +166,7 @@ def test_interp_conv(): scale_factor=2, mode='bilinear', align_corners=False)) x = torch.randn(1, 64, 128, 128) x_out = block(x) - assert isinstance(block.interp_upsample[0], nn.Upsample) + assert isinstance(block.interp_upsample[0], Upsample) assert isinstance(block.interp_upsample[1], ConvModule) assert x_out.shape == torch.Size([1, 32, 256, 256]) assert block.interp_upsample[0].mode == 'bilinear' @@ -179,7 +179,7 @@ def test_interp_conv(): upsample_cfg=dict(scale_factor=2, mode='nearest')) x = torch.randn(1, 64, 128, 128) x_out = block(x) - assert isinstance(block.interp_upsample[0], nn.Upsample) + assert isinstance(block.interp_upsample[0], Upsample) assert isinstance(block.interp_upsample[1], ConvModule) assert x_out.shape == torch.Size([1, 32, 256, 256]) assert block.interp_upsample[0].mode == 'nearest' diff --git a/tools/deploy_test.py b/tools/deploy_test.py index bef3512d..51f16b4a 100644 --- a/tools/deploy_test.py +++ b/tools/deploy_test.py @@ -14,6 +14,7 @@ from mmcv.utils import DictAction from mmseg.apis import single_gpu_test from mmseg.datasets import build_dataloader, build_dataset from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.ops import resize class ONNXRuntimeSegmentor(BaseSegmentor): @@ -79,7 +80,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor): if not (ori_shape[0] == seg_pred.shape[-2] and ori_shape[1] == seg_pred.shape[-1]): seg_pred = torch.from_numpy(seg_pred).float() - seg_pred = torch.nn.functional.interpolate( + seg_pred = resize( seg_pred, size=tuple(ori_shape[:2]), mode='nearest') seg_pred = seg_pred.long().detach().cpu().numpy() seg_pred = seg_pred[0] @@ -127,7 +128,7 @@ class TensorRTSegmentor(BaseSegmentor): if not (ori_shape[0] == seg_pred.shape[-2] and ori_shape[1] == seg_pred.shape[-1]): seg_pred = torch.from_numpy(seg_pred).float() - seg_pred = torch.nn.functional.interpolate( + seg_pred = resize( seg_pred, size=tuple(ori_shape[:2]), mode='nearest') seg_pred = seg_pred.long().detach().cpu().numpy() seg_pred = seg_pred[0] diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 14f25056..17f10932 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -16,6 +16,7 @@ from mmseg.apis import show_result_pyplot from mmseg.apis.inference import LoadImage from mmseg.datasets.pipelines import Compose from mmseg.models import build_segmentor +from mmseg.ops import resize torch.manual_seed(3) @@ -210,10 +211,7 @@ def pytorch2onnx(model, if dynamic_export and test_mode == 'whole': # scale image for dynamic shape test - img_list = [ - nn.functional.interpolate(_, scale_factor=1.5) - for _ in img_list - ] + img_list = [resize(_, scale_factor=1.5) for _ in img_list] # concate flip image for batch test flip_img_list = [_.flip(-1) for _ in img_list] img_list = [ -- GitLab