From fb031c59c8916b27c31d93c5abb3d145a1e6aa9e Mon Sep 17 00:00:00 2001 From: Jerry Jiarui XU <xvjiarui0826@gmail.com> Date: Tue, 27 Apr 2021 23:51:09 -0700 Subject: [PATCH] [Refactor] Use MMCV MODEL_REGISTRY (#515) * [Refactor] Use MMCV MODEL_REGISTRY * fixed args --- mmseg/models/builder.py | 48 ++++++++++++----------------------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/mmseg/models/builder.py b/mmseg/models/builder.py index c487dcdd..9b68ff88 100644 --- a/mmseg/models/builder.py +++ b/mmseg/models/builder.py @@ -1,56 +1,35 @@ import warnings -from mmcv.utils import Registry, build_from_cfg -from torch import nn +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.utils import Registry -BACKBONES = Registry('backbone') -NECKS = Registry('neck') -HEADS = Registry('head') -LOSSES = Registry('loss') -SEGMENTORS = Registry('segmentor') +MODELS = Registry('models', parent=MMCV_MODELS) - -def build(cfg, registry, default_args=None): - """Build a module. - - Args: - cfg (dict, list[dict]): The config of modules, is is either a dict - or a list of configs. - registry (:obj:`Registry`): A registry the module belongs to. - default_args (dict, optional): Default arguments to build the module. - Defaults to None. - - Returns: - nn.Module: A built nn module. - """ - - if isinstance(cfg, list): - modules = [ - build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg - ] - return nn.Sequential(*modules) - else: - return build_from_cfg(cfg, registry, default_args) +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +SEGMENTORS = MODELS def build_backbone(cfg): """Build backbone.""" - return build(cfg, BACKBONES) + return BACKBONES.build(cfg) def build_neck(cfg): """Build neck.""" - return build(cfg, NECKS) + return NECKS.build(cfg) def build_head(cfg): """Build head.""" - return build(cfg, HEADS) + return HEADS.build(cfg) def build_loss(cfg): """Build loss.""" - return build(cfg, LOSSES) + return LOSSES.build(cfg) def build_segmentor(cfg, train_cfg=None, test_cfg=None): @@ -63,4 +42,5 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None): 'train_cfg specified in both outer field and model field ' assert cfg.get('test_cfg') is None or test_cfg is None, \ 'test_cfg specified in both outer field and model field ' - return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) + return SEGMENTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) -- GitLab