From 2e5260b58a0a0aa7d1247c15d86611ba0f629896 Mon Sep 17 00:00:00 2001 From: robin Han <drcut@users.noreply.github.com> Date: Thu, 3 Sep 2020 19:59:13 +0800 Subject: [PATCH] Onnx upsample (#100) * add customized Upsample which can convert to ONNX * support multiply decode head for hrnet * support size for Upsample --- mmseg/models/backbones/hrnet.py | 4 ++-- mmseg/ops/__init__.py | 4 ++-- mmseg/ops/wrappers.py | 32 ++++++++++++++++++++++++++++++-- tools/pytorch2onnx.py | 16 +++++++++++----- 4 files changed, 45 insertions(+), 11 deletions(-) diff --git a/mmseg/models/backbones/hrnet.py b/mmseg/models/backbones/hrnet.py index e4247ba6..33f3ba86 100644 --- a/mmseg/models/backbones/hrnet.py +++ b/mmseg/models/backbones/hrnet.py @@ -4,7 +4,7 @@ from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, from mmcv.runner import load_checkpoint from mmcv.utils.parrots_wrapper import _BatchNorm -from mmseg.ops import resize +from mmseg.ops import Upsample, resize from mmseg.utils import get_root_logger from ..builder import BACKBONES from .resnet import BasicBlock, Bottleneck @@ -141,7 +141,7 @@ class HRModule(nn.Module): bias=False), build_norm_layer(self.norm_cfg, in_channels[i])[1], # we set align_corners=False for HRNet - nn.Upsample( + Upsample( scale_factor=2**(j - i), mode='bilinear', align_corners=False))) diff --git a/mmseg/ops/__init__.py b/mmseg/ops/__init__.py index 54b0d0b7..7a0b930c 100644 --- a/mmseg/ops/__init__.py +++ b/mmseg/ops/__init__.py @@ -1,5 +1,5 @@ from .encoding import Encoding from .separable_conv_module import DepthwiseSeparableConvModule -from .wrappers import resize +from .wrappers import Upsample, resize -__all__ = ['resize', 'DepthwiseSeparableConvModule', 'Encoding'] +__all__ = ['Upsample', 'resize', 'DepthwiseSeparableConvModule', 'Encoding'] diff --git a/mmseg/ops/wrappers.py b/mmseg/ops/wrappers.py index 0b319767..a6d75527 100644 --- a/mmseg/ops/wrappers.py +++ b/mmseg/ops/wrappers.py @@ -1,5 +1,7 @@ import warnings +import torch +import torch.nn as nn import torch.nn.functional as F @@ -11,8 +13,8 @@ def resize(input, warning=True): if warning: if size is not None and align_corners: - input_h, input_w = input.shape[2:] - output_h, output_w = size + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) if output_h > input_h or output_w > output_h: if ((output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) and (output_h - 1) % (input_h - 1) @@ -22,4 +24,30 @@ def resize(input, 'the output would more aligned if ' f'input size {(input_h, input_w)} is `x+1` and ' f'out size {(output_h, output_w)} is `nx+1`') + if isinstance(size, torch.Size): + size = tuple(int(x) for x in size) return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__(self, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + super(Upsample, self).__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index df84eeb9..b2453667 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -5,6 +5,7 @@ import mmcv import numpy as np import onnxruntime as rt import torch +from torch import nn import torch._C import torch.serialization from mmcv.onnx import register_extra_symbolics @@ -88,7 +89,10 @@ def pytorch2onnx(model, """ model.cpu().eval() - num_classes = model.decode_head.num_classes + if isinstance(model.decode_head, nn.ModuleList): + num_classes = model.decode_head[-1].num_classes + else: + num_classes = model.decode_head.num_classes mm_inputs = _demo_mm_inputs(input_shape, num_classes) @@ -142,7 +146,7 @@ def pytorch2onnx(model, def parse_args(): - parser = argparse.ArgumentParser(description='Convert MMDet to ONNX') + parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX') parser.add_argument('config', help='test config file path') parser.add_argument('--checkpoint', help='checkpoint file', default=None) parser.add_argument('--show', action='store_true', help='show onnx graph') @@ -182,11 +186,13 @@ if __name__ == '__main__': # convert SyncBN to BN segmentor = _convert_batchnorm(segmentor) - num_classes = segmentor.decode_head.num_classes + if isinstance(segmentor.decode_head, nn.ModuleList): + num_classes = segmentor.decode_head[-1].num_classes + else: + num_classes = segmentor.decode_head.num_classes if args.checkpoint: - checkpoint = load_checkpoint( - segmentor, args.checkpoint, map_location='cpu') + load_checkpoint(segmentor, args.checkpoint, map_location='cpu') # conver model to onnx file pytorch2onnx( -- GitLab