diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py
index ef027dc0d95b69527ad77af1e4479097490207b3..68a989b5d77b96c8030081ddcd97c07bb93fb9ab 100644
--- a/mmseg/models/backbones/swin.py
+++ b/mmseg/models/backbones/swin.py
@@ -13,6 +13,7 @@ from torch.nn.modules.linear import Linear
 from torch.nn.modules.normalization import LayerNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
+from mmseg.ops import resize
 from ...utils import get_root_logger
 from ..builder import ATTENTION, BACKBONES
 from ..utils import PatchEmbed, swin_convert
@@ -745,7 +746,7 @@ class SwinTransformer(BaseModule):
                     if L1 != L2:
                         S1 = int(L1**0.5)
                         S2 = int(L2**0.5)
-                        table_pretrained_resized = F.interpolate(
+                        table_pretrained_resized = resize(
                             table_pretrained.permute(1, 0).reshape(
                                 1, nH1, S1, S1),
                             size=(S2, S2),
diff --git a/mmseg/models/backbones/unet.py b/mmseg/models/backbones/unet.py
index a8cbe57f6cb0c226f2da1e9b34010fa641b8a3e3..705dd2b8f8ed62f9928a048bcbf3ec7e722115b3 100644
--- a/mmseg/models/backbones/unet.py
+++ b/mmseg/models/backbones/unet.py
@@ -7,6 +7,7 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
 from mmcv.runner import BaseModule
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
+from mmseg.ops import Upsample
 from ..builder import BACKBONES
 from ..utils import UpConvBlock
 
@@ -203,7 +204,7 @@ class InterpConv(nn.Module):
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
             act_cfg=act_cfg)
-        upsample = nn.Upsample(**upsample_cfg)
+        upsample = Upsample(**upsample_cfg)
         if conv_first:
             self.interp_upsample = nn.Sequential(conv, upsample)
         else:
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 021bf09331dc9e2b442bf3bcb8274b67dbb91243..e4f1839bdb8c270eff7ae30c73f4f7351a3b4f7d 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -3,7 +3,6 @@ import warnings
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
                       normal_init, trunc_normal_init)
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
@@ -11,6 +10,7 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
+from mmseg.ops import resize
 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
 from ..utils import PatchEmbed, vit_convert
@@ -373,7 +373,7 @@ class VisionTransformer(BaseModule):
         pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
         pos_embed_weight = pos_embed_weight.reshape(
             1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
-        pos_embed_weight = F.interpolate(
+        pos_embed_weight = resize(
             pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
         cls_token_weight = cls_token_weight.unsqueeze(1)
         pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
diff --git a/mmseg/models/decode_heads/fpn_head.py b/mmseg/models/decode_heads/fpn_head.py
index 9b6ada005995cc7e7f0b0cd66f4880d6a2bd665b..1e5bfd63fc0ec9a94e9dafccfd89f3a3c9782455 100644
--- a/mmseg/models/decode_heads/fpn_head.py
+++ b/mmseg/models/decode_heads/fpn_head.py
@@ -2,7 +2,7 @@ import numpy as np
 import torch.nn as nn
 from mmcv.cnn import ConvModule
 
-from mmseg.ops import resize
+from mmseg.ops import Upsample, resize
 from ..builder import HEADS
 from .decode_head import BaseDecodeHead
 
@@ -45,7 +45,7 @@ class FPNHead(BaseDecodeHead):
                         act_cfg=self.act_cfg))
                 if feature_strides[i] != feature_strides[0]:
                     scale_head.append(
-                        nn.Upsample(
+                        Upsample(
                             scale_factor=2,
                             mode='bilinear',
                             align_corners=self.align_corners))
diff --git a/mmseg/models/decode_heads/setr_mla_head.py b/mmseg/models/decode_heads/setr_mla_head.py
index 016a82a41c547449b003cea56a8a13bde6875dbd..86e493d2e8cd254d399da035d6229661ffb729dd 100644
--- a/mmseg/models/decode_heads/setr_mla_head.py
+++ b/mmseg/models/decode_heads/setr_mla_head.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 from mmcv.cnn import ConvModule
 
+from mmseg.ops import Upsample
 from ..builder import HEADS
 from .decode_head import BaseDecodeHead
 
@@ -46,7 +47,7 @@ class SETRMLAHead(BaseDecodeHead):
                         padding=1,
                         norm_cfg=self.norm_cfg,
                         act_cfg=self.act_cfg),
-                    nn.Upsample(
+                    Upsample(
                         scale_factor=up_scale,
                         mode='bilinear',
                         align_corners=self.align_corners)))
diff --git a/mmseg/models/decode_heads/setr_up_head.py b/mmseg/models/decode_heads/setr_up_head.py
index 322a56dc79c4b54afe3fc10396f27a15d4c096f6..d64896f76b2ca23568bc4c6c4d47384d24b09616 100644
--- a/mmseg/models/decode_heads/setr_up_head.py
+++ b/mmseg/models/decode_heads/setr_up_head.py
@@ -1,6 +1,7 @@
 import torch.nn as nn
 from mmcv.cnn import ConvModule, build_norm_layer
 
+from mmseg.ops import Upsample
 from ..builder import HEADS
 from .decode_head import BaseDecodeHead
 
@@ -59,7 +60,7 @@ class SETRUPHead(BaseDecodeHead):
                         padding=int(kernel_size - 1) // 2,
                         norm_cfg=self.norm_cfg,
                         act_cfg=self.act_cfg),
-                    nn.Upsample(
+                    Upsample(
                         scale_factor=up_scale,
                         mode='bilinear',
                         align_corners=self.align_corners)))
diff --git a/mmseg/models/necks/fpn.py b/mmseg/models/necks/fpn.py
index 4ba128ed48c9a078707b30b2fddb47132387b27a..5e1bd21836d4d620e55c63207dfc58ea77a45af9 100644
--- a/mmseg/models/necks/fpn.py
+++ b/mmseg/models/necks/fpn.py
@@ -3,6 +3,7 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule, auto_fp16
 
+from mmseg.ops import resize
 from ..builder import NECKS
 
 
@@ -173,11 +174,10 @@ class FPN(BaseModule):
             # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
             #  it cannot co-exist with `size` in `F.interpolate`.
             if 'scale_factor' in self.upsample_cfg:
-                laterals[i - 1] += F.interpolate(laterals[i],
-                                                 **self.upsample_cfg)
+                laterals[i - 1] += resize(laterals[i], **self.upsample_cfg)
             else:
                 prev_shape = laterals[i - 1].shape[2:]
-                laterals[i - 1] += F.interpolate(
+                laterals[i - 1] += resize(
                     laterals[i], size=prev_shape, **self.upsample_cfg)
 
         # build outputs
diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py
index eb32240bc6bad8bbe0d8996e783415502ed7d1b4..9f638932f4040179bf6106b7d39f20f2309f6972 100644
--- a/mmseg/models/necks/multilevel_neck.py
+++ b/mmseg/models/necks/multilevel_neck.py
@@ -1,7 +1,7 @@
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn import ConvModule, xavier_init
 
+from mmseg.ops import resize
 from ..builder import NECKS
 
 
@@ -70,7 +70,7 @@ class MultiLevelNeck(nn.Module):
             inputs = [inputs[0] for _ in range(self.num_outs)]
         outs = []
         for i in range(self.num_outs):
-            x_resize = F.interpolate(
+            x_resize = resize(
                 inputs[i], scale_factor=self.scales[i], mode='bilinear')
             outs.append(self.convs[i](x_resize))
         return tuple(outs)
diff --git a/tests/test_models/test_backbones/test_unet.py b/tests/test_models/test_backbones/test_unet.py
index defdf392169a4734ba2a5c25d62c6d4a19065fcd..52f2123a3c7fffbd2dab3aef8f48e5782a1d48fc 100644
--- a/tests/test_models/test_backbones/test_unet.py
+++ b/tests/test_models/test_backbones/test_unet.py
@@ -1,10 +1,10 @@
 import pytest
 import torch
 from mmcv.cnn import ConvModule
-from torch import nn
 
 from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
                                          InterpConv, UNet, UpConvBlock)
+from mmseg.ops import Upsample
 from .utils import check_norm_state
 
 
@@ -145,7 +145,7 @@ def test_interp_conv():
     block = InterpConv(64, 32, conv_first=False)
     x = torch.randn(1, 64, 128, 128)
     x_out = block(x)
-    assert isinstance(block.interp_upsample[0], nn.Upsample)
+    assert isinstance(block.interp_upsample[0], Upsample)
     assert isinstance(block.interp_upsample[1], ConvModule)
     assert x_out.shape == torch.Size([1, 32, 256, 256])
 
@@ -154,7 +154,7 @@ def test_interp_conv():
     x = torch.randn(1, 64, 128, 128)
     x_out = block(x)
     assert isinstance(block.interp_upsample[0], ConvModule)
-    assert isinstance(block.interp_upsample[1], nn.Upsample)
+    assert isinstance(block.interp_upsample[1], Upsample)
     assert x_out.shape == torch.Size([1, 32, 256, 256])
 
     # test InterpConv with bilinear upsample for upsample 2X.
@@ -166,7 +166,7 @@ def test_interp_conv():
             scale_factor=2, mode='bilinear', align_corners=False))
     x = torch.randn(1, 64, 128, 128)
     x_out = block(x)
-    assert isinstance(block.interp_upsample[0], nn.Upsample)
+    assert isinstance(block.interp_upsample[0], Upsample)
     assert isinstance(block.interp_upsample[1], ConvModule)
     assert x_out.shape == torch.Size([1, 32, 256, 256])
     assert block.interp_upsample[0].mode == 'bilinear'
@@ -179,7 +179,7 @@ def test_interp_conv():
         upsample_cfg=dict(scale_factor=2, mode='nearest'))
     x = torch.randn(1, 64, 128, 128)
     x_out = block(x)
-    assert isinstance(block.interp_upsample[0], nn.Upsample)
+    assert isinstance(block.interp_upsample[0], Upsample)
     assert isinstance(block.interp_upsample[1], ConvModule)
     assert x_out.shape == torch.Size([1, 32, 256, 256])
     assert block.interp_upsample[0].mode == 'nearest'
diff --git a/tools/deploy_test.py b/tools/deploy_test.py
index bef3512d71d585ae4984097ca0467fad458ae320..51f16b4a2a2b9a2586c15ecf1c0bfca7d0f5658d 100644
--- a/tools/deploy_test.py
+++ b/tools/deploy_test.py
@@ -14,6 +14,7 @@ from mmcv.utils import DictAction
 from mmseg.apis import single_gpu_test
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.models.segmentors.base import BaseSegmentor
+from mmseg.ops import resize
 
 
 class ONNXRuntimeSegmentor(BaseSegmentor):
@@ -79,7 +80,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
         if not (ori_shape[0] == seg_pred.shape[-2]
                 and ori_shape[1] == seg_pred.shape[-1]):
             seg_pred = torch.from_numpy(seg_pred).float()
-            seg_pred = torch.nn.functional.interpolate(
+            seg_pred = resize(
                 seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
             seg_pred = seg_pred.long().detach().cpu().numpy()
         seg_pred = seg_pred[0]
@@ -127,7 +128,7 @@ class TensorRTSegmentor(BaseSegmentor):
         if not (ori_shape[0] == seg_pred.shape[-2]
                 and ori_shape[1] == seg_pred.shape[-1]):
             seg_pred = torch.from_numpy(seg_pred).float()
-            seg_pred = torch.nn.functional.interpolate(
+            seg_pred = resize(
                 seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
             seg_pred = seg_pred.long().detach().cpu().numpy()
         seg_pred = seg_pred[0]
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index 14f25056d512399921210f0db538f93568eb8b2a..17f10932a67c760b04cc322ffbf33ccab238aea0 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -16,6 +16,7 @@ from mmseg.apis import show_result_pyplot
 from mmseg.apis.inference import LoadImage
 from mmseg.datasets.pipelines import Compose
 from mmseg.models import build_segmentor
+from mmseg.ops import resize
 
 torch.manual_seed(3)
 
@@ -210,10 +211,7 @@ def pytorch2onnx(model,
 
         if dynamic_export and test_mode == 'whole':
             # scale image for dynamic shape test
-            img_list = [
-                nn.functional.interpolate(_, scale_factor=1.5)
-                for _ in img_list
-            ]
+            img_list = [resize(_, scale_factor=1.5) for _ in img_list]
             # concate flip image for batch test
             flip_img_list = [_.flip(-1) for _ in img_list]
             img_list = [