From 2e5260b58a0a0aa7d1247c15d86611ba0f629896 Mon Sep 17 00:00:00 2001
From: robin Han <drcut@users.noreply.github.com>
Date: Thu, 3 Sep 2020 19:59:13 +0800
Subject: [PATCH] Onnx upsample (#100)

* add customized Upsample which can convert to ONNX

* support multiply decode head for hrnet

* support size for Upsample
---
 mmseg/models/backbones/hrnet.py |  4 ++--
 mmseg/ops/__init__.py           |  4 ++--
 mmseg/ops/wrappers.py           | 32 ++++++++++++++++++++++++++++++--
 tools/pytorch2onnx.py           | 16 +++++++++++-----
 4 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/mmseg/models/backbones/hrnet.py b/mmseg/models/backbones/hrnet.py
index e4247ba6..33f3ba86 100644
--- a/mmseg/models/backbones/hrnet.py
+++ b/mmseg/models/backbones/hrnet.py
@@ -4,7 +4,7 @@ from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
 from mmcv.runner import load_checkpoint
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
-from mmseg.ops import resize
+from mmseg.ops import Upsample, resize
 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
 from .resnet import BasicBlock, Bottleneck
@@ -141,7 +141,7 @@ class HRModule(nn.Module):
                                 bias=False),
                             build_norm_layer(self.norm_cfg, in_channels[i])[1],
                             # we set align_corners=False for HRNet
-                            nn.Upsample(
+                            Upsample(
                                 scale_factor=2**(j - i),
                                 mode='bilinear',
                                 align_corners=False)))
diff --git a/mmseg/ops/__init__.py b/mmseg/ops/__init__.py
index 54b0d0b7..7a0b930c 100644
--- a/mmseg/ops/__init__.py
+++ b/mmseg/ops/__init__.py
@@ -1,5 +1,5 @@
 from .encoding import Encoding
 from .separable_conv_module import DepthwiseSeparableConvModule
-from .wrappers import resize
+from .wrappers import Upsample, resize
 
-__all__ = ['resize', 'DepthwiseSeparableConvModule', 'Encoding']
+__all__ = ['Upsample', 'resize', 'DepthwiseSeparableConvModule', 'Encoding']
diff --git a/mmseg/ops/wrappers.py b/mmseg/ops/wrappers.py
index 0b319767..a6d75527 100644
--- a/mmseg/ops/wrappers.py
+++ b/mmseg/ops/wrappers.py
@@ -1,5 +1,7 @@
 import warnings
 
+import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 
@@ -11,8 +13,8 @@ def resize(input,
            warning=True):
     if warning:
         if size is not None and align_corners:
-            input_h, input_w = input.shape[2:]
-            output_h, output_w = size
+            input_h, input_w = tuple(int(x) for x in input.shape[2:])
+            output_h, output_w = tuple(int(x) for x in size)
             if output_h > input_h or output_w > output_h:
                 if ((output_h > 1 and output_w > 1 and input_h > 1
                      and input_w > 1) and (output_h - 1) % (input_h - 1)
@@ -22,4 +24,30 @@ def resize(input,
                         'the output would more aligned if '
                         f'input size {(input_h, input_w)} is `x+1` and '
                         f'out size {(output_h, output_w)} is `nx+1`')
+    if isinstance(size, torch.Size):
+        size = tuple(int(x) for x in size)
     return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class Upsample(nn.Module):
+
+    def __init__(self,
+                 size=None,
+                 scale_factor=None,
+                 mode='nearest',
+                 align_corners=None):
+        super(Upsample, self).__init__()
+        self.size = size
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        if not self.size:
+            size = [int(t * self.scale_factor) for t in x.shape[-2:]]
+        else:
+            size = self.size
+        return resize(x, size, None, self.mode, self.align_corners)
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index df84eeb9..b2453667 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -5,6 +5,7 @@ import mmcv
 import numpy as np
 import onnxruntime as rt
 import torch
+from torch import nn
 import torch._C
 import torch.serialization
 from mmcv.onnx import register_extra_symbolics
@@ -88,7 +89,10 @@ def pytorch2onnx(model,
     """
     model.cpu().eval()
 
-    num_classes = model.decode_head.num_classes
+    if isinstance(model.decode_head, nn.ModuleList):
+        num_classes = model.decode_head[-1].num_classes
+    else:
+        num_classes = model.decode_head.num_classes
 
     mm_inputs = _demo_mm_inputs(input_shape, num_classes)
 
@@ -142,7 +146,7 @@ def pytorch2onnx(model,
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Convert MMDet to ONNX')
+    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('--checkpoint', help='checkpoint file', default=None)
     parser.add_argument('--show', action='store_true', help='show onnx graph')
@@ -182,11 +186,13 @@ if __name__ == '__main__':
     # convert SyncBN to BN
     segmentor = _convert_batchnorm(segmentor)
 
-    num_classes = segmentor.decode_head.num_classes
+    if isinstance(segmentor.decode_head, nn.ModuleList):
+        num_classes = segmentor.decode_head[-1].num_classes
+    else:
+        num_classes = segmentor.decode_head.num_classes
 
     if args.checkpoint:
-        checkpoint = load_checkpoint(
-            segmentor, args.checkpoint, map_location='cpu')
+        load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
 
     # conver model to onnx file
     pytorch2onnx(
-- 
GitLab