From bc2dc1277a90c64b49cb8274cd71821734702d88 Mon Sep 17 00:00:00 2001
From: "q.yao" <streetyao@live.com>
Date: Tue, 13 Apr 2021 02:54:59 +0800
Subject: [PATCH] add dynamic export and visualize to pytorch2onnx (#463)

* add dynamic export and visualize to pytorch2onnx

* update document

* fix lint

* fix dynamic error and add visualization

* fix lint

* update docstring

* update doc

* Update help info for --show

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>

* fix lint

Co-authored-by: maningsheng <maningsheng@sensetime.com>
Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
---
 docs/useful_tools.md                       |  26 ++-
 mmseg/apis/inference.py                    |  12 +-
 mmseg/models/segmentors/encoder_decoder.py |   7 +-
 mmseg/ops/wrappers.py                      |   3 -
 tools/pytorch2onnx.py                      | 180 ++++++++++++++++++---
 5 files changed, 202 insertions(+), 26 deletions(-)

diff --git a/docs/useful_tools.md b/docs/useful_tools.md
index 7b2e3fde..8286af83 100644
--- a/docs/useful_tools.md
+++ b/docs/useful_tools.md
@@ -46,10 +46,32 @@ The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pt
 
 We provide a script to convert a model to [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between PyTorch and the ONNX model.
 
-```shell
-python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
+```bash
+python tools/pytorch2onnx.py \
+    ${CONFIG_FILE} \
+    --checkpoint ${CHECKPOINT_FILE} \
+    --output-file ${ONNX_FILE} \
+    --input-img ${INPUT_IMG} \
+    --shape ${INPUT_SHAPE} \
+    --show \
+    --verify \
+    --dynamic-export \
+    --cfg-options \
+      model.test_cfg.mode="whole"
 ```
 
+Description of arguments:
+
+- `config`: The path of a model config file.
+- `--checkpoint`: The path of a model checkpoint file.
+- `--output-file`: The path of the output ONNX model. If not specified, it will be set to `tmp.onnx`.
+- `--input-img`: The path of an input image for conversion and visualization.
+- `--shape`: The height and width of the input tensor to the model. If not specified, it will be set to `256 256`.
+- `--show`: Determines whether to print the computation graph of the exported model and show the segmentation results. If not specified, it will be set to `False`.
+- `--verify`: Determines whether to verify the correctness of the exported model. If not specified, it will be set to `False`.
+- `--dynamic-export`: Determines whether to export the ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
+- `--cfg-options`: Override some settings in the used config.
+
 **Note**: This tool is still experimental. Some customized operators are not supported for now.
 
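+For reference, below is a minimal sketch of consuming the exported model with ONNX Runtime directly. The image path, the static `1x3x256x256` shape, and the ImageNet normalization constants are assumptions; take the real values from your config.
+
+```python
+import cv2
+import numpy as np
+import onnxruntime as rt
+
+# Hypothetical input image; mean/std below are the common ImageNet constants.
+img = cv2.imread('demo.png')
+img = cv2.resize(img, (256, 256)).astype(np.float32)
+mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+img = (img[..., ::-1] - mean) / std  # BGR -> RGB, then normalize
+img = np.ascontiguousarray(img.transpose(2, 0, 1)[None])  # HWC -> NCHW
+
+sess = rt.InferenceSession('tmp.onnx')
+seg_map = np.squeeze(sess.run(None, {'input': img})[0])  # (H, W) label map
+print(seg_map.shape, np.unique(seg_map))
+```
+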
 ## Miscellaneous
diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py
index 9052cdd3..bf875cb2 100644
--- a/mmseg/apis/inference.py
+++ b/mmseg/apis/inference.py
@@ -103,7 +103,9 @@ def show_result_pyplot(model,
                        result,
                        palette=None,
                        fig_size=(15, 10),
-                       opacity=0.5):
+                       opacity=0.5,
+                       title='',
+                       block=True):
     """Visualize the segmentation results on the image.
 
     Args:
@@ -117,6 +119,10 @@ def show_result_pyplot(model,
         opacity(float): Opacity of painted segmentation map.
             Default 0.5.
             Must be in (0, 1] range.
+        title (str): The title of the pyplot figure.
+            Default is ''.
+        block (bool): Whether the call should block until the figure
+            window is closed. Default is True.
     """
     if hasattr(model, 'module'):
         model = model.module
@@ -124,4 +130,6 @@ def show_result_pyplot(model,
         img, result, palette=palette, show=False, opacity=opacity)
     plt.figure(figsize=fig_size)
     plt.imshow(mmcv.bgr2rgb(img))
-    plt.show()
+    plt.title(title)
+    plt.tight_layout()
+    plt.show(block=block)
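
The new `title` and `block` arguments make side-by-side comparisons possible, since the first window no longer has to block. A hypothetical usage sketch (config, checkpoint, and image paths are placeholders):

```python
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot

model = init_segmentor('config.py', 'checkpoint.pth', device='cpu')
result = inference_segmentor(model, 'demo.png')
show_result_pyplot(
    model, 'demo.png', result, title='run A', block=False)  # returns at once
show_result_pyplot(
    model, 'demo.png', result, title='run B', block=True)  # blocks here
```
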
diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
index 2284906e..b2d067dc 100644
--- a/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -216,9 +216,14 @@ class EncoderDecoder(BaseSegmentor):
 
         seg_logit = self.encode_decode(img, img_meta)
         if rescale:
+            # support dynamic shape for onnx
+            if torch.onnx.is_in_onnx_export():
+                size = img.shape[2:]
+            else:
+                size = img_meta[0]['ori_shape'][:2]
             seg_logit = resize(
                 seg_logit,
-                size=img_meta[0]['ori_shape'][:2],
+                size=size,
                 mode='bilinear',
                 align_corners=self.align_corners,
                 warning=False)
diff --git a/mmseg/ops/wrappers.py b/mmseg/ops/wrappers.py
index a6d75527..0ed9a0cb 100644
--- a/mmseg/ops/wrappers.py
+++ b/mmseg/ops/wrappers.py
@@ -1,6 +1,5 @@
 import warnings
 
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -24,8 +23,6 @@ def resize(input,
                         'the output would be more aligned if '
                         f'input size {(input_h, input_w)} is `x+1` and '
                         f'out size {(output_h, output_w)} is `nx+1`')
-    if isinstance(size, torch.Size):
-        size = tuple(int(x) for x in size)
     return F.interpolate(input, size, scale_factor, mode, align_corners)
 
 
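Both shape-handling changes above serve dynamic export: during `torch.onnx.export` tracing, sizes read from `tensor.shape` stay symbolic, whereas casting them through `int()` (as the removed `torch.Size` branch did) freezes them into constants baked into the graph. A minimal sketch of the difference, with illustrative function names:

```python
import torch
import torch.nn.functional as F

def upsample_dynamic(x, ref):
    # `ref.shape[2:]` is traced symbolically, so the exported Resize
    # follows the runtime input size.
    return F.interpolate(
        x, size=ref.shape[2:], mode='bilinear', align_corners=False)

def upsample_static(x, ref):
    # int() detaches the sizes from the trace; the exported graph would
    # always resize to the tracing-time values.
    size = tuple(int(s) for s in ref.shape[2:])
    return F.interpolate(
        x, size=size, mode='bilinear', align_corners=False)
```
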
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index 2ec9feb5..71f1bb72 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -7,10 +7,14 @@ import onnxruntime as rt
 import torch
 import torch._C
 import torch.serialization
+from mmcv import DictAction
 from mmcv.onnx import register_extra_symbolics
 from mmcv.runner import load_checkpoint
 from torch import nn
 
+from mmseg.apis import show_result_pyplot
+from mmseg.apis.inference import LoadImage
+from mmseg.datasets.pipelines import Compose
 from mmseg.models import build_segmentor
 
 torch.manual_seed(3)
@@ -67,25 +71,61 @@ def _demo_mm_inputs(input_shape, num_classes):
     return mm_inputs
 
 
+def _prepare_input_img(img_path, test_pipeline, shape=None):
+    # build the data pipeline
+    if shape is not None:
+        test_pipeline[1]['img_scale'] = shape
+    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
+    test_pipeline = [LoadImage()] + test_pipeline[1:]
+    test_pipeline = Compose(test_pipeline)
+    # prepare data
+    data = dict(img=img_path)
+    data = test_pipeline(data)
+    imgs = data['img']
+    img_metas = [i.data for i in data['img_metas']]
+
+    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
+
+    return mm_inputs
+
+
+def _update_input_img(img_list, img_meta_list):
+    # update img and its meta list
+    N, C, H, W = img_list[0].shape
+    img_meta = img_meta_list[0][0]
+    new_img_meta_list = [[{
+        'img_shape': (H, W, C),
+        'ori_shape': (H, W, C),
+        'pad_shape': (H, W, C),
+        'filename': img_meta['filename'],
+        'scale_factor': 1.,
+        'flip': False,
+    } for _ in range(N)]]
+
+    return img_list, new_img_meta_list
+
+
 def pytorch2onnx(model,
-                 input_shape,
+                 mm_inputs,
                  opset_version=11,
                  show=False,
                  output_file='tmp.onnx',
-                 verify=False):
+                 verify=False,
+                 dynamic_export=False):
     """Export Pytorch model to ONNX model and verify the outputs are same
     between Pytorch and ONNX.
 
     Args:
         model (nn.Module): Pytorch model we want to export.
-        input_shape (tuple): Use this input shape to construct
-            the corresponding dummy input and execute the model.
+        mm_inputs (dict): Contains the input tensors and img_metas information.
         opset_version (int): The onnx op version. Default: 11.
         show (bool): Whether to print the computation graph. Default: False.
         output_file (string): The path to where we store the output ONNX model.
             Default: `tmp.onnx`.
         verify (bool): Whether to compare the outputs between Pytorch and ONNX.
             Default: False.
+        dynamic_export (bool): Whether to export ONNX with dynamic axes.
+            Default: False.
     """
     model.cpu().eval()
 
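A quick, hypothetical sanity check of `_update_input_img`: it rebuilds the meta list so the recorded shapes match the current tensor and replicates one meta dict per batch item:

```python
import torch

img_list = [torch.rand(2, 3, 192, 320)]
img_meta_list = [[{'filename': 'demo.png'}]]  # minimal meta stub
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
assert img_meta_list[0][0]['img_shape'] == (192, 320, 3)
assert len(img_meta_list[0]) == 2  # one meta dict per batch item
```
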
@@ -94,28 +134,45 @@ def pytorch2onnx(model,
     else:
         num_classes = model.decode_head.num_classes
 
-    mm_inputs = _demo_mm_inputs(input_shape, num_classes)
-
     imgs = mm_inputs.pop('imgs')
     img_metas = mm_inputs.pop('img_metas')
+    ori_shape = img_metas[0]['ori_shape']
 
     img_list = [img[None, :] for img in imgs]
     img_meta_list = [[img_meta] for img_meta in img_metas]
+    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
 
     # replace original forward function
     origin_forward = model.forward
     model.forward = partial(
         model.forward, img_metas=img_meta_list, return_loss=False)
+    dynamic_axes = None
+    if dynamic_export:
+        dynamic_axes = {
+            'input': {
+                0: 'batch',
+                2: 'height',
+                3: 'width'
+            },
+            'output': {
+                1: 'batch',
+                2: 'height',
+                3: 'width'
+            }
+        }
 
     register_extra_symbolics(opset_version)
     with torch.no_grad():
         torch.onnx.export(
             model, (img_list, ),
             output_file,
+            input_names=['input'],
+            output_names=['output'],
             export_params=True,
-            keep_initializers_as_inputs=True,
+            keep_initializers_as_inputs=False,
             verbose=show,
-            opset_version=opset_version)
+            opset_version=opset_version,
+            dynamic_axes=dynamic_axes)
         print(f'Successfully exported ONNX model: {output_file}')
     model.forward = origin_forward
 
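Once exported with `--dynamic-export`, the same ONNX file should accept other batch sizes and resolutions. A hedged smoke test with ONNX Runtime, assuming the export above produced `tmp.onnx`:

```python
import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession('tmp.onnx')
for shape in [(1, 3, 256, 256), (2, 3, 384, 512)]:  # vary batch and size
    dummy = np.random.rand(*shape).astype(np.float32)
    out = sess.run(None, {'input': dummy})[0]
    print(shape, '->', out.shape)
```
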
@@ -125,9 +182,28 @@ def pytorch2onnx(model,
         onnx_model = onnx.load(output_file)
         onnx.checker.check_model(onnx_model)
 
+        if dynamic_export:
+            # scale image for dynamic shape test
+            img_list = [
+                nn.functional.interpolate(_, scale_factor=1.5)
+                for _ in img_list
+            ]
+            # concatenate flipped images for batch test
+            flip_img_list = [_.flip(-1) for _ in img_list]
+            img_list = [
+                torch.cat((ori_img, flip_img), 0)
+                for ori_img, flip_img in zip(img_list, flip_img_list)
+            ]
+
+            # update img_meta
+            img_list, img_meta_list = _update_input_img(
+                img_list, img_meta_list)
+
         # check the numerical value
         # get pytorch output
-        pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]
+        with torch.no_grad():
+            pytorch_result = model(img_list, img_meta_list, return_loss=False)
+            pytorch_result = np.stack(pytorch_result, 0)
 
         # get onnx output
         input_all = [node.name for node in onnx_model.graph.input]
@@ -138,10 +214,42 @@ def pytorch2onnx(model,
         assert (len(net_feed_input) == 1)
         sess = rt.InferenceSession(output_file)
         onnx_result = sess.run(
-            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
-        if not np.allclose(pytorch_result, onnx_result):
-            raise ValueError(
-                'The outputs are different between Pytorch and ONNX')
+            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
+        # show segmentation results
+        if show:
+            import cv2
+            import os.path as osp
+            img = img_meta_list[0][0]['filename']
+            if not osp.exists(img):
+                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
+                img = img.detach().numpy().astype(np.uint8)
+            # resize onnx_result to ori_shape
+            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
+                                      (ori_shape[1], ori_shape[0]))
+            show_result_pyplot(
+                model,
+                img, (onnx_result_, ),
+                palette=model.PALETTE,
+                block=False,
+                title='ONNXRuntime',
+                opacity=0.5)
+
+            # resize pytorch_result to ori_shape
+            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
+                                         (ori_shape[1], ori_shape[0]))
+            show_result_pyplot(
+                model,
+                img, (pytorch_result_, ),
+                title='PyTorch',
+                palette=model.PALETTE,
+                opacity=0.5)
+        # compare results
+        np.testing.assert_allclose(
+            pytorch_result.astype(np.float32) / num_classes,
+            onnx_result.astype(np.float32) / num_classes,
+            rtol=1e-5,
+            atol=1e-5,
+            err_msg='The outputs are different between Pytorch and ONNX')
         print('The outputs are same between Pytorch and ONNX')
 
 
@@ -149,7 +257,12 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('--checkpoint', help='checkpoint file', default=None)
-    parser.add_argument('--show', action='store_true', help='show onnx graph')
+    parser.add_argument(
+        '--input-img', type=str, help='Images for input', default=None)
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        help='show onnx graph and segmentation results')
     parser.add_argument(
         '--verify', action='store_true', help='verify the onnx model')
     parser.add_argument('--output-file', type=str, default='tmp.onnx')
@@ -160,6 +273,20 @@ def parse_args():
         nargs='+',
         default=[256, 256],
         help='input image size')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='Override some settings in the used config; the key-value pair '
+        'in xxx=yyy format will be merged into the config file. If the value '
+        'to be overwritten is a list, it should be like key="[a,b]" or '
+        'key=a,b. It also allows nested list/tuple values, e.g. '
+        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
+        'and that no whitespace is allowed.')
+    parser.add_argument(
+        '--dynamic-export',
+        action='store_true',
+        help='Whether to export ONNX with dynamic axes.')
     args = parser.parse_args()
     return args
 
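`DictAction` collects the `key=value` pairs into a flat dict with dotted keys, and `Config.merge_from_dict` applies them to the nested config, which is how `--cfg-options model.test_cfg.mode="whole"` takes effect. A hedged illustration on a toy config:

```python
import mmcv

cfg = mmcv.Config(dict(model=dict(test_cfg=dict(mode='slide'))))
cfg.merge_from_dict({'model.test_cfg.mode': 'whole'})
print(cfg.model.test_cfg.mode)  # -> whole
```
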
@@ -178,6 +305,8 @@ if __name__ == '__main__':
         raise ValueError('invalid input shape')
 
     cfg = mmcv.Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
     cfg.model.pretrained = None
 
     # build the model and load checkpoint
@@ -188,13 +317,28 @@ if __name__ == '__main__':
     segmentor = _convert_batchnorm(segmentor)
 
     if args.checkpoint:
-        load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
+        checkpoint = load_checkpoint(
+            segmentor, args.checkpoint, map_location='cpu')
+        segmentor.CLASSES = checkpoint['meta']['CLASSES']
+        segmentor.PALETTE = checkpoint['meta']['PALETTE']
+
+    # read the input image or create a dummy input
+    if args.input_img is not None:
+        mm_inputs = _prepare_input_img(args.input_img, cfg.data.test.pipeline,
+                                       (input_shape[3], input_shape[2]))
+    else:
+        if isinstance(segmentor.decode_head, nn.ModuleList):
+            num_classes = segmentor.decode_head[-1].num_classes
+        else:
+            num_classes = segmentor.decode_head.num_classes
+        mm_inputs = _demo_mm_inputs(input_shape, num_classes)
 
-    # conver model to onnx file
+    # convert model to onnx file
     pytorch2onnx(
         segmentor,
-        input_shape,
+        mm_inputs,
         opset_version=args.opset_version,
         show=args.show,
         output_file=args.output_file,
-        verify=args.verify)
+        verify=args.verify,
+        dynamic_export=args.dynamic_export)
-- 
GitLab