From 50461efe854bc922d1c345b8344a7c3aa59817aa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Miguel=20M=C3=A9ndez?= <miguelmndez@gmail.com>
Date: Wed, 28 Jul 2021 10:56:22 +0200
Subject: [PATCH] [Fix] Replace interpolate with resize (#731)

* Replace interpolate with resize

* Replace nn.Upsample with ops.Upsample

* Fix test
---
 mmseg/models/backbones/swin.py                |  3 ++-
 mmseg/models/backbones/unet.py                |  3 ++-
 mmseg/models/backbones/vit.py                 |  4 ++--
 mmseg/models/decode_heads/fpn_head.py         |  4 ++--
 mmseg/models/decode_heads/setr_mla_head.py    |  3 ++-
 mmseg/models/decode_heads/setr_up_head.py     |  3 ++-
 mmseg/models/necks/fpn.py                     |  6 +++---
 mmseg/models/necks/multilevel_neck.py         |  4 ++--
 tests/test_models/test_backbones/test_unet.py | 10 +++++-----
 tools/deploy_test.py                          |  5 +++--
 tools/pytorch2onnx.py                         |  6 ++----
 11 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py
index ef027dc0..68a989b5 100644
--- a/mmseg/models/backbones/swin.py
+++ b/mmseg/models/backbones/swin.py
@@ -13,6 +13,7 @@ from torch.nn.modules.linear import Linear
 from torch.nn.modules.normalization import LayerNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
+from mmseg.ops import resize
 from ...utils import get_root_logger
 from ..builder import ATTENTION, BACKBONES
 from ..utils import PatchEmbed, swin_convert
@@ -745,7 +746,7 @@ class SwinTransformer(BaseModule):
                     if L1 != L2:
                         S1 = int(L1**0.5)
                         S2 = int(L2**0.5)
-                        table_pretrained_resized = F.interpolate(
+                        table_pretrained_resized = resize(
                             table_pretrained.permute(1, 0).reshape(
                                 1, nH1, S1, S1),
                             size=(S2, S2),
diff --git a/mmseg/models/backbones/unet.py b/mmseg/models/backbones/unet.py
index a8cbe57f..705dd2b8 100644
--- a/mmseg/models/backbones/unet.py
+++ b/mmseg/models/backbones/unet.py
@@ -7,6 +7,7 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
 from mmcv.runner import BaseModule
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
+from mmseg.ops import Upsample
 from ..builder import BACKBONES
 from ..utils import UpConvBlock
 
@@ -203,7 +204,7 @@ class InterpConv(nn.Module):
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
             act_cfg=act_cfg)
-        upsample = nn.Upsample(**upsample_cfg)
+        upsample = Upsample(**upsample_cfg)
         if conv_first:
             self.interp_upsample = nn.Sequential(conv, upsample)
         else:
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 021bf093..e4f1839b 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -3,7 +3,6 @@ import warnings
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
                       normal_init, trunc_normal_init)
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
@@ -11,6 +10,7 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
+from mmseg.ops import resize
 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
 from ..utils import PatchEmbed, vit_convert
@@ -373,7 +373,7 @@ class VisionTransformer(BaseModule):
         pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
         pos_embed_weight = pos_embed_weight.reshape(
             1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
-        pos_embed_weight = F.interpolate(
+        pos_embed_weight = resize(
             pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
         cls_token_weight = cls_token_weight.unsqueeze(1)
         pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
diff --git a/mmseg/models/decode_heads/fpn_head.py b/mmseg/models/decode_heads/fpn_head.py
index 9b6ada00..1e5bfd63 100644
--- a/mmseg/models/decode_heads/fpn_head.py
+++ b/mmseg/models/decode_heads/fpn_head.py
@@ -2,7 +2,7 @@ import numpy as np
 import torch.nn as nn
 from mmcv.cnn import ConvModule
 
-from mmseg.ops import resize
+from mmseg.ops import Upsample, resize
 from ..builder import HEADS
 from .decode_head import BaseDecodeHead
 
@@ -45,7 +45,7 @@ class FPNHead(BaseDecodeHead):
                         act_cfg=self.act_cfg))
                 if feature_strides[i] != feature_strides[0]:
                     scale_head.append(
-                        nn.Upsample(
+                        Upsample(
                             scale_factor=2,
                             mode='bilinear',
                             align_corners=self.align_corners))
diff --git a/mmseg/models/decode_heads/setr_mla_head.py b/mmseg/models/decode_heads/setr_mla_head.py
index 016a82a4..86e493d2 100644
--- a/mmseg/models/decode_heads/setr_mla_head.py
+++ b/mmseg/models/decode_heads/setr_mla_head.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 from mmcv.cnn import ConvModule
 
+from mmseg.ops import Upsample
 from ..builder import HEADS
 from .decode_head import BaseDecodeHead
 
@@ -46,7 +47,7 @@ class SETRMLAHead(BaseDecodeHead):
                         padding=1,
                         norm_cfg=self.norm_cfg,
                         act_cfg=self.act_cfg),
-                    nn.Upsample(
+                    Upsample(
                         scale_factor=up_scale,
                         mode='bilinear',
                         align_corners=self.align_corners)))
diff --git a/mmseg/models/decode_heads/setr_up_head.py b/mmseg/models/decode_heads/setr_up_head.py
index 322a56dc..d64896f7 100644
--- a/mmseg/models/decode_heads/setr_up_head.py
+++ b/mmseg/models/decode_heads/setr_up_head.py
@@ -1,6 +1,7 @@
 import torch.nn as nn
 from mmcv.cnn import ConvModule, build_norm_layer
 
+from mmseg.ops import Upsample
 from ..builder import HEADS
 from .decode_head import BaseDecodeHead
 
@@ -59,7 +60,7 @@ class SETRUPHead(BaseDecodeHead):
                         padding=int(kernel_size - 1) // 2,
                         norm_cfg=self.norm_cfg,
                         act_cfg=self.act_cfg),
-                    nn.Upsample(
+                    Upsample(
                         scale_factor=up_scale,
                         mode='bilinear',
                         align_corners=self.align_corners)))
diff --git a/mmseg/models/necks/fpn.py b/mmseg/models/necks/fpn.py
index 4ba128ed..5e1bd218 100644
--- a/mmseg/models/necks/fpn.py
+++ b/mmseg/models/necks/fpn.py
@@ -3,6 +3,7 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule, auto_fp16
 
+from mmseg.ops import resize
 from ..builder import NECKS
 
 
@@ -173,11 +174,10 @@ class FPN(BaseModule):
             # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
             #  it cannot co-exist with `size` in `F.interpolate`.
             if 'scale_factor' in self.upsample_cfg:
-                laterals[i - 1] += F.interpolate(laterals[i],
-                                                 **self.upsample_cfg)
+                laterals[i - 1] += resize(laterals[i], **self.upsample_cfg)
             else:
                 prev_shape = laterals[i - 1].shape[2:]
-                laterals[i - 1] += F.interpolate(
+                laterals[i - 1] += resize(
                     laterals[i], size=prev_shape, **self.upsample_cfg)
 
         # build outputs
diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py
index eb32240b..9f638932 100644
--- a/mmseg/models/necks/multilevel_neck.py
+++ b/mmseg/models/necks/multilevel_neck.py
@@ -1,7 +1,7 @@
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn import ConvModule, xavier_init
 
+from mmseg.ops import resize
 from ..builder import NECKS
 
 
@@ -70,7 +70,7 @@ class MultiLevelNeck(nn.Module):
             inputs = [inputs[0] for _ in range(self.num_outs)]
         outs = []
         for i in range(self.num_outs):
-            x_resize = F.interpolate(
+            x_resize = resize(
                 inputs[i], scale_factor=self.scales[i], mode='bilinear')
             outs.append(self.convs[i](x_resize))
         return tuple(outs)
diff --git a/tests/test_models/test_backbones/test_unet.py b/tests/test_models/test_backbones/test_unet.py
index defdf392..52f2123a 100644
--- a/tests/test_models/test_backbones/test_unet.py
+++ b/tests/test_models/test_backbones/test_unet.py
@@ -1,10 +1,10 @@
 import pytest
 import torch
 from mmcv.cnn import ConvModule
-from torch import nn
 
 from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
                                          InterpConv, UNet, UpConvBlock)
+from mmseg.ops import Upsample
 from .utils import check_norm_state
 
 
@@ -145,7 +145,7 @@ def test_interp_conv():
     block = InterpConv(64, 32, conv_first=False)
     x = torch.randn(1, 64, 128, 128)
     x_out = block(x)
-    assert isinstance(block.interp_upsample[0], nn.Upsample)
+    assert isinstance(block.interp_upsample[0], Upsample)
     assert isinstance(block.interp_upsample[1], ConvModule)
     assert x_out.shape == torch.Size([1, 32, 256, 256])
 
@@ -154,7 +154,7 @@ def test_interp_conv():
     x = torch.randn(1, 64, 128, 128)
     x_out = block(x)
     assert isinstance(block.interp_upsample[0], ConvModule)
-    assert isinstance(block.interp_upsample[1], nn.Upsample)
+    assert isinstance(block.interp_upsample[1], Upsample)
     assert x_out.shape == torch.Size([1, 32, 256, 256])
 
     # test InterpConv with bilinear upsample for upsample 2X.
@@ -166,7 +166,7 @@ def test_interp_conv():
             scale_factor=2, mode='bilinear', align_corners=False))
     x = torch.randn(1, 64, 128, 128)
     x_out = block(x)
-    assert isinstance(block.interp_upsample[0], nn.Upsample)
+    assert isinstance(block.interp_upsample[0], Upsample)
     assert isinstance(block.interp_upsample[1], ConvModule)
     assert x_out.shape == torch.Size([1, 32, 256, 256])
     assert block.interp_upsample[0].mode == 'bilinear'
@@ -179,7 +179,7 @@ def test_interp_conv():
         upsample_cfg=dict(scale_factor=2, mode='nearest'))
     x = torch.randn(1, 64, 128, 128)
     x_out = block(x)
-    assert isinstance(block.interp_upsample[0], nn.Upsample)
+    assert isinstance(block.interp_upsample[0], Upsample)
     assert isinstance(block.interp_upsample[1], ConvModule)
     assert x_out.shape == torch.Size([1, 32, 256, 256])
     assert block.interp_upsample[0].mode == 'nearest'
diff --git a/tools/deploy_test.py b/tools/deploy_test.py
index bef3512d..51f16b4a 100644
--- a/tools/deploy_test.py
+++ b/tools/deploy_test.py
@@ -14,6 +14,7 @@ from mmcv.utils import DictAction
 from mmseg.apis import single_gpu_test
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.models.segmentors.base import BaseSegmentor
+from mmseg.ops import resize
 
 
 class ONNXRuntimeSegmentor(BaseSegmentor):
@@ -79,7 +80,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
         if not (ori_shape[0] == seg_pred.shape[-2]
                 and ori_shape[1] == seg_pred.shape[-1]):
             seg_pred = torch.from_numpy(seg_pred).float()
-            seg_pred = torch.nn.functional.interpolate(
+            seg_pred = resize(
                 seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
             seg_pred = seg_pred.long().detach().cpu().numpy()
         seg_pred = seg_pred[0]
@@ -127,7 +128,7 @@ class TensorRTSegmentor(BaseSegmentor):
         if not (ori_shape[0] == seg_pred.shape[-2]
                 and ori_shape[1] == seg_pred.shape[-1]):
             seg_pred = torch.from_numpy(seg_pred).float()
-            seg_pred = torch.nn.functional.interpolate(
+            seg_pred = resize(
                 seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
             seg_pred = seg_pred.long().detach().cpu().numpy()
         seg_pred = seg_pred[0]
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index 14f25056..17f10932 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -16,6 +16,7 @@ from mmseg.apis import show_result_pyplot
 from mmseg.apis.inference import LoadImage
 from mmseg.datasets.pipelines import Compose
 from mmseg.models import build_segmentor
+from mmseg.ops import resize
 
 torch.manual_seed(3)
 
@@ -210,10 +211,7 @@ def pytorch2onnx(model,
 
         if dynamic_export and test_mode == 'whole':
             # scale image for dynamic shape test
-            img_list = [
-                nn.functional.interpolate(_, scale_factor=1.5)
-                for _ in img_list
-            ]
+            img_list = [resize(_, scale_factor=1.5) for _ in img_list]
             # concate flip image for batch test
             flip_img_list = [_.flip(-1) for _ in img_list]
             img_list = [
-- 
GitLab