From fb031c59c8916b27c31d93c5abb3d145a1e6aa9e Mon Sep 17 00:00:00 2001
From: Jerry Jiarui XU <xvjiarui0826@gmail.com>
Date: Tue, 27 Apr 2021 23:51:09 -0700
Subject: [PATCH] [Refactor] Use MMCV MODEL_REGISTRY (#515)

* [Refactor] Use MMCV MODEL_REGISTRY

* fixed args
---
 mmseg/models/builder.py | 48 ++++++++++++-----------------------------
 1 file changed, 14 insertions(+), 34 deletions(-)

diff --git a/mmseg/models/builder.py b/mmseg/models/builder.py
index c487dcdd..9b68ff88 100644
--- a/mmseg/models/builder.py
+++ b/mmseg/models/builder.py
@@ -1,56 +1,35 @@
 import warnings
 
-from mmcv.utils import Registry, build_from_cfg
-from torch import nn
+from mmcv.cnn import MODELS as MMCV_MODELS
+from mmcv.utils import Registry
 
-BACKBONES = Registry('backbone')
-NECKS = Registry('neck')
-HEADS = Registry('head')
-LOSSES = Registry('loss')
-SEGMENTORS = Registry('segmentor')
+MODELS = Registry('models', parent=MMCV_MODELS)
 
-
-def build(cfg, registry, default_args=None):
-    """Build a module.
-
-    Args:
-        cfg (dict, list[dict]): The config of modules, is is either a dict
-            or a list of configs.
-        registry (:obj:`Registry`): A registry the module belongs to.
-        default_args (dict, optional): Default arguments to build the module.
-            Defaults to None.
-
-    Returns:
-        nn.Module: A built nn module.
-    """
-
-    if isinstance(cfg, list):
-        modules = [
-            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
-        ]
-        return nn.Sequential(*modules)
-    else:
-        return build_from_cfg(cfg, registry, default_args)
+BACKBONES = MODELS
+NECKS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+SEGMENTORS = MODELS
 
 
 def build_backbone(cfg):
     """Build backbone."""
-    return build(cfg, BACKBONES)
+    return BACKBONES.build(cfg)
 
 
 def build_neck(cfg):
     """Build neck."""
-    return build(cfg, NECKS)
+    return NECKS.build(cfg)
 
 
 def build_head(cfg):
     """Build head."""
-    return build(cfg, HEADS)
+    return HEADS.build(cfg)
 
 
 def build_loss(cfg):
     """Build loss."""
-    return build(cfg, LOSSES)
+    return LOSSES.build(cfg)
 
 
 def build_segmentor(cfg, train_cfg=None, test_cfg=None):
@@ -63,4 +42,5 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None):
         'train_cfg specified in both outer field and model field '
     assert cfg.get('test_cfg') is None or test_cfg is None, \
         'test_cfg specified in both outer field and model field '
-    return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
+    return SEGMENTORS.build(
+        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
-- 
GitLab