From 07cc26ae5a16ad7f0814ad0f0f43586f67fe3ee3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= <xinchen.xie@qq.com>
Date: Sun, 25 Apr 2021 12:22:09 +0800
Subject: [PATCH] add upsample neck (#512)

* init

* upsample v1.0

* fix errors

* change to in_channels list

* add unittest, docstring, norm/act config and rename

Co-authored-by: xiexinch <test767803@foxmail.com>
---
 mmseg/models/necks/__init__.py                |  3 +-
 mmseg/models/necks/multilevel_neck.py         | 70 +++++++++++++++++++
 .../test_necks/test_multilevel_neck.py        | 28 ++++++++
 3 files changed, 100 insertions(+), 1 deletion(-)
 create mode 100644 mmseg/models/necks/multilevel_neck.py
 create mode 100644 tests/test_models/test_necks/test_multilevel_neck.py

diff --git a/mmseg/models/necks/__init__.py b/mmseg/models/necks/__init__.py
index 0093021e..9b9d3d5b 100644
--- a/mmseg/models/necks/__init__.py
+++ b/mmseg/models/necks/__init__.py
@@ -1,3 +1,4 @@
 from .fpn import FPN
+from .multilevel_neck import MultiLevelNeck
 
-__all__ = ['FPN']
+__all__ = ['FPN', 'MultiLevelNeck']
diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py
new file mode 100644
index 00000000..7e13813b
--- /dev/null
+++ b/mmseg/models/necks/multilevel_neck.py
@@ -0,0 +1,70 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class MultiLevelNeck(nn.Module):
+    """MultiLevelNeck.
+
+    A neck structure connect vit backbone and decoder_heads.
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        scales (List[int]): Scale factors for each input feature map.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer in ConvModule.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 scales=[0.5, 1, 2, 4],
+                 norm_cfg=None,
+                 act_cfg=None):
+        super(MultiLevelNeck, self).__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.scales = scales
+        self.num_outs = len(scales)
+        self.lateral_convs = nn.ModuleList()
+        self.convs = nn.ModuleList()
+        for in_channel in in_channels:
+            self.lateral_convs.append(
+                ConvModule(
+                    in_channel,
+                    out_channels,
+                    kernel_size=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        for _ in range(self.num_outs):
+            self.convs.append(
+                ConvModule(
+                    out_channels,
+                    out_channels,
+                    kernel_size=3,
+                    padding=1,
+                    stride=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+    def forward(self, inputs):
+        assert len(inputs) == len(self.in_channels)
+        print(inputs[0].shape)
+        inputs = [
+            lateral_conv(inputs[i])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+        # for len(inputs) not equal to self.num_outs
+        if len(inputs) == 1:
+            inputs = [inputs[0] for _ in range(self.num_outs)]
+        outs = []
+        for i in range(self.num_outs):
+            x_resize = F.interpolate(
+                inputs[i], scale_factor=self.scales[i], mode='bilinear')
+            outs.append(self.convs[i](x_resize))
+        return tuple(outs)
diff --git a/tests/test_models/test_necks/test_multilevel_neck.py b/tests/test_models/test_necks/test_multilevel_neck.py
new file mode 100644
index 00000000..8fb2fc92
--- /dev/null
+++ b/tests/test_models/test_necks/test_multilevel_neck.py
@@ -0,0 +1,28 @@
+import torch
+
+from mmseg.models import MultiLevelNeck
+
+
+def test_multilevel_neck():
+
+    # Test multi feature maps
+    in_channels = [256, 512, 1024, 2048]
+    inputs = [torch.randn(1, c, 14, 14) for i, c in enumerate(in_channels)]
+
+    neck = MultiLevelNeck(in_channels, 256)
+    outputs = neck(inputs)
+    assert outputs[0].shape == torch.Size([1, 256, 7, 7])
+    assert outputs[1].shape == torch.Size([1, 256, 14, 14])
+    assert outputs[2].shape == torch.Size([1, 256, 28, 28])
+    assert outputs[3].shape == torch.Size([1, 256, 56, 56])
+
+    # Test one feature map
+    in_channels = [768]
+    inputs = [torch.randn(1, 768, 14, 14)]
+
+    neck = MultiLevelNeck(in_channels, 256)
+    outputs = neck(inputs)
+    assert outputs[0].shape == torch.Size([1, 256, 7, 7])
+    assert outputs[1].shape == torch.Size([1, 256, 14, 14])
+    assert outputs[2].shape == torch.Size([1, 256, 28, 28])
+    assert outputs[3].shape == torch.Size([1, 256, 56, 56])
-- 
GitLab