Skip to content
Snippets Groups Projects
Commit 7a32d610 authored by linfangjian.vendor's avatar linfangjian.vendor Committed by zhengmiao
Browse files

[Refactor] Refactor all registries

parent b2abe157
No related branches found
No related tags found
No related merge requests found
Showing
with 41 additions and 46 deletions
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
build_optimizer_constructor)
from .builder import build_optimizer, build_optimizer_constructor
from .evaluation import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
__all__ = [
'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor'
]
__all__ = ['build_optimizer', 'build_optimizer_constructor']
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
from mmcv.utils import Registry, build_from_cfg
OPTIMIZER_BUILDERS = Registry(
'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
from mmseg.registry import OPTIMIZER_CONSTRUCTORS
def build_optimizer_constructor(cfg):
constructor_type = cfg.get('type')
if constructor_type in OPTIMIZER_BUILDERS:
return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
if constructor_type in OPTIMIZER_CONSTRUCTORS:
return OPTIMIZER_CONSTRUCTORS.build(cfg)
else:
raise KeyError(f'{constructor_type} is not registered '
'in the optimizer builder registry.')
......
......@@ -2,10 +2,11 @@
import json
import warnings
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
from mmengine.dist import get_dist_info
from mmengine.optim import DefaultOptimizerConstructor
from mmseg.registry import OPTIMIZER_CONSTRUCTORS
from mmseg.utils import get_root_logger
from ..builder import OPTIMIZER_BUILDERS
def get_layer_id_for_convnext(var_name, max_layer_id):
......@@ -99,7 +100,7 @@ def get_layer_id_for_vit(var_name, max_layer_id):
return max_layer_id - 1
@OPTIMIZER_BUILDERS.register_module()
@OPTIMIZER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
"""Different learning rates are set for different layers of backbone.
......@@ -185,7 +186,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
params.extend(parameter_groups.values())
@OPTIMIZER_BUILDERS.register_module()
@OPTIMIZER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
"""Different learning rates are set for different layers of backbone.
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg
import warnings
PIXEL_SAMPLERS = Registry('pixel sampler')
from mmseg.registry import TASK_UTILS
PIXEL_SAMPLERS = TASK_UTILS
def build_pixel_sampler(cfg, **default_args):
"""Build pixel sampler for segmentation map."""
return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
warnings.warn(
'``build_pixel_sampler`` would be deprecated soon, please use '
'``mmseg.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
......@@ -5,7 +5,7 @@ import mmcv
import numpy as np
from PIL import Image
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset
......
......@@ -8,9 +8,10 @@ import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg, digit_version
from mmcv.utils import digit_version
from torch.utils.data import DataLoader
from mmseg.registry import DATASETS, TRANSFORMS
from .samplers import DistributedSampler
if platform.system() != 'Windows':
......@@ -22,8 +23,7 @@ if platform.system() != 'Windows':
soft_limit = min(max(4096, base_soft_limit), hard_limit)
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
PIPELINES = TRANSFORMS
def _concat_dataset(cfg, default_args=None):
......@@ -82,7 +82,7 @@ def build_dataset(cfg, default_args=None):
cfg.get('split', None), (list, tuple)):
dataset = _concat_dataset(cfg, default_args)
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)
dataset = DATASETS.build(cfg, default_args=default_args)
return dataset
......
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset
......
......@@ -6,7 +6,7 @@ import numpy as np
from mmcv.utils import print_log
from PIL import Image
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset
......
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset
......
......@@ -10,8 +10,8 @@ from prettytable import PrettyTable
from torch.utils.data import Dataset
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.registry import DATASETS
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose, LoadAnnotations
......
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .cityscapes import CityscapesDataset
......
......@@ -6,10 +6,10 @@ from itertools import chain
import mmcv
import numpy as np
from mmcv.utils import build_from_cfg, print_log
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from .builder import DATASETS, PIPELINES
from mmseg.registry import DATASETS, TRANSFORMS
from .cityscapes import CityscapesDataset
......@@ -225,7 +225,7 @@ class MultiImageMixDataset:
for transform in pipeline:
if isinstance(transform, dict):
self.pipeline_types.append(transform['type'])
transform = build_from_cfg(transform, PIPELINES)
transform = TRANSFORMS.build(transform)
self.pipeline.append(transform)
else:
raise TypeError('pipeline must be a dict')
......
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset
......
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset
......
......@@ -3,8 +3,8 @@
import mmcv
from mmcv.utils import print_log
from mmseg.registry import DATASETS
from ..utils import get_root_logger
from .builder import DATASETS
from .custom import CustomDataset
......
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset
......
......@@ -5,7 +5,7 @@ import mmcv
import numpy as np
from PIL import Image
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset
......
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .cityscapes import CityscapesDataset
......
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset
......
# Copyright (c) OpenMMLab. All rights reserved.
import collections
from mmcv.utils import build_from_cfg
from mmseg.registry import TRANSFORMS
from ..builder import PIPELINES
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class Compose(object):
"""Compose multiple transforms sequentially.
......@@ -20,7 +18,7 @@ class Compose(object):
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
transform = TRANSFORMS.build(transform)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment