From f3f443ff719bf4f17b0e17182b8261bb68b8b1dc Mon Sep 17 00:00:00 2001 From: Jerry Jiarui XU <xvjiarui0826@gmail.com> Date: Fri, 25 Sep 2020 19:56:10 +0800 Subject: [PATCH] [Enhance] Migrate to MMCV DepthwiseSeparableConv (#158) * Add D16-MG124 models * Use MMCV DepthSepConv * add OHEM * add warmup * fixed test * fixed test * change to bs 16 * revert config * add models * seperate --- .dev/clean_models.py | 125 ----------------- .dev/generate_table.py | 152 --------------------- .dev/modelzoo_json2md.py | 58 -------- .dev/upload_modelzoo.py | 44 ++++++ mmseg/core/utils/__init__.py | 3 +- mmseg/core/utils/dist_utils.py | 49 ------- mmseg/models/backbones/fast_scnn.py | 5 +- mmseg/models/decode_heads/sep_aspp_head.py | 4 +- mmseg/models/decode_heads/sep_fcn_head.py | 3 +- mmseg/ops/__init__.py | 3 +- mmseg/ops/separable_conv_module.py | 88 ------------ setup.cfg | 2 +- tests/test_models/test_heads.py | 4 +- tests/test_ops/test_sep_conv_module.py | 71 ---------- 14 files changed, 55 insertions(+), 556 deletions(-) delete mode 100644 .dev/clean_models.py delete mode 100644 .dev/generate_table.py delete mode 100644 .dev/modelzoo_json2md.py create mode 100644 .dev/upload_modelzoo.py delete mode 100644 mmseg/core/utils/dist_utils.py delete mode 100644 mmseg/ops/separable_conv_module.py delete mode 100644 tests/test_ops/test_sep_conv_module.py diff --git a/.dev/clean_models.py b/.dev/clean_models.py deleted file mode 100644 index c9ac2acb..00000000 --- a/.dev/clean_models.py +++ /dev/null @@ -1,125 +0,0 @@ -import argparse -import glob -import json -import os -import os.path as osp - -import mmcv - -# build schedule look-up table to automatically find the final model -SCHEDULES_LUT = { - '20ki': 20000, - '40ki': 40000, - '60ki': 60000, - '80ki': 80000, - '160ki': 160000 -} -RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc'] - - -def get_final_iter(config): - iter_num = SCHEDULES_LUT[config.split('_')[-2]] - return iter_num - - -def get_final_results(log_json_path, iter_num): - result_dict = dict() - with open(log_json_path, 'r') as f: - for line in f.readlines(): - log_line = json.loads(line) - if 'mode' not in log_line.keys(): - continue - - if log_line['mode'] == 'train' and log_line['iter'] == iter_num: - result_dict['memory'] = log_line['memory'] - - if log_line['iter'] == iter_num: - result_dict.update({ - key: log_line[key] - for key in RESULTS_LUT if key in log_line - }) - return result_dict - - -def parse_args(): - parser = argparse.ArgumentParser(description='Gather benchmarked models') - parser.add_argument( - 'root', - type=str, - help='root path of benchmarked models to be gathered') - parser.add_argument( - 'config', - type=str, - help='root path of benchmarked configs to be gathered') - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - models_root = args.root - config_name = args.config - - # find all models in the root directory to be gathered - raw_configs = list(mmcv.scandir(config_name, '.py', recursive=True)) - - # filter configs that is not trained in the experiments dir - used_configs = [] - for raw_config in raw_configs: - work_dir = osp.splitext(osp.basename(raw_config))[0] - if osp.exists(osp.join(models_root, work_dir)): - used_configs.append(work_dir) - print(f'Find {len(used_configs)} models to be gathered') - - # find final_ckpt and log file for trained each config - # and parse the best performance - model_infos = [] - for used_config in used_configs: - exp_dir = osp.join(models_root, used_config) - # check whether the exps is finished - final_iter = get_final_iter(used_config) - final_model = 'iter_{}.pth'.format(final_iter) - model_path = osp.join(exp_dir, final_model) - - # skip if the model is still training - if not osp.exists(model_path): - print(f'{used_config} not finished yet') - continue - - # get logs - log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0] - log_txt_path = glob.glob(osp.join(exp_dir, '*.log'))[0] - model_performance = get_final_results(log_json_path, final_iter) - - if model_performance is None: - print(f'{used_config} does not have performance') - continue - - model_time = osp.split(log_txt_path)[-1].split('.')[0] - model_infos.append( - dict( - config=used_config, - results=model_performance, - iters=final_iter, - model_time=model_time, - log_json_path=osp.split(log_json_path)[-1])) - - # publish model for each checkpoint - for model in model_infos: - - model_name = osp.split(model['config'])[-1].split('.')[0] - - model_name += '_' + model['model_time'] - for checkpoints in mmcv.scandir( - osp.join(models_root, model['config']), suffix='.pth'): - if checkpoints.endswith(f"iter_{model['iters']}.pth" - ) or checkpoints.endswith('latest.pth'): - continue - print('removing {}'.format( - osp.join(models_root, model['config'], checkpoints))) - os.remove(osp.join(models_root, model['config'], checkpoints)) - - -if __name__ == '__main__': - main() diff --git a/.dev/generate_table.py b/.dev/generate_table.py deleted file mode 100644 index 25142cae..00000000 --- a/.dev/generate_table.py +++ /dev/null @@ -1,152 +0,0 @@ -import argparse -import csv -import glob -import json -import os.path as osp -from collections import OrderedDict - -import mmcv - -# build schedule look-up table to automatically find the final model -RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc'] - - -def get_final_iter(config): - iter_num = config.split('_')[-2] - assert iter_num.endswith('ki') - return int(iter_num[:-2]) * 1000 - - -def get_final_results(log_json_path, iter_num): - result_dict = dict() - with open(log_json_path, 'r') as f: - for line in f.readlines(): - log_line = json.loads(line) - if 'mode' not in log_line.keys(): - continue - - if log_line['mode'] == 'train' and log_line[ - 'iter'] == iter_num - 50: - result_dict['memory'] = log_line['memory'] - - if log_line['iter'] == iter_num: - result_dict.update({ - key: log_line[key] * 100 - for key in RESULTS_LUT if key in log_line - }) - return result_dict - - -def get_total_time(log_json_path, iter_num): - - def convert(seconds): - hour = seconds // 3600 - seconds %= 3600 - minutes = seconds // 60 - seconds %= 60 - - return f'{hour:d}:{minutes:2d}:{seconds:2d}' - - time_dict = dict() - with open(log_json_path, 'r') as f: - last_iter = 0 - total_sec = 0 - for line in f.readlines(): - log_line = json.loads(line) - if 'mode' not in log_line.keys(): - continue - - if log_line['mode'] == 'train': - cur_iter = log_line['iter'] - total_sec += (cur_iter - last_iter) * log_line['time'] - last_iter = cur_iter - time_dict['time'] = convert(int(total_sec)) - - return time_dict - - -def parse_args(): - parser = argparse.ArgumentParser(description='Gather benchmarked models') - parser.add_argument( - 'root', - type=str, - help='root path of benchmarked models to be gathered') - parser.add_argument( - 'config', - type=str, - help='root path of benchmarked configs to be gathered') - parser.add_argument( - 'out', type=str, help='output path of gathered models to be stored') - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - models_root = args.root - models_out = args.out - config_name = args.config - mmcv.mkdir_or_exist(models_out) - - # find all models in the root directory to be gathered - raw_configs = list(mmcv.scandir(config_name, '.py', recursive=True)) - - # filter configs that is not trained in the experiments dir - exp_dirs = [] - for raw_config in raw_configs: - work_dir = osp.splitext(osp.basename(raw_config))[0] - if osp.exists(osp.join(models_root, work_dir)): - exp_dirs.append(work_dir) - print(f'Find {len(exp_dirs)} models to be gathered') - - # find final_ckpt and log file for trained each config - # and parse the best performance - model_infos = [] - for work_dir in exp_dirs: - exp_dir = osp.join(models_root, work_dir) - # check whether the exps is finished - final_iter = get_final_iter(work_dir) - final_model = 'iter_{}.pth'.format(final_iter) - model_path = osp.join(exp_dir, final_model) - - # skip if the model is still training - if not osp.exists(model_path): - print(f'{model_path} not finished yet') - continue - - # get logs - log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0] - model_performance = get_final_results(log_json_path, final_iter) - - if model_performance is None: - continue - - head = work_dir.split('_')[0] - backbone = work_dir.split('_')[1] - crop_size = work_dir.split('_')[-3] - dataset = work_dir.split('_')[-1] - model_info = OrderedDict( - head=head, - backbone=backbone, - crop_size=crop_size, - dataset=dataset, - iters=f'{final_iter//1000}ki') - model_info.update(model_performance) - model_time = get_total_time(log_json_path, final_iter) - model_info.update(model_time) - model_info['config'] = work_dir - model_infos.append(model_info) - - with open( - osp.join(models_out, 'models_table.csv'), 'w', - newline='') as csvfile: - writer = csv.writer( - csvfile, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL) - writer.writerow(model_infos[0].keys()) - for model_info in model_infos: - writer.writerow(model_info.values()) - - -if __name__ == '__main__': - main() diff --git a/.dev/modelzoo_json2md.py b/.dev/modelzoo_json2md.py deleted file mode 100644 index 7cb44bff..00000000 --- a/.dev/modelzoo_json2md.py +++ /dev/null @@ -1,58 +0,0 @@ -import argparse -import os -import os.path as osp - -import mmcv -from pytablewriter import Align, MarkdownTableWriter - - -def parse_args(): - parser = argparse.ArgumentParser(description='Gather benchmarked models') - parser.add_argument('table_cache', type=str, help='table_cache input') - parser.add_argument('out', type=str, help='output path md') - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - table_cache = mmcv.load(args.table_cache) - output_dir = args.out - - writer = MarkdownTableWriter() - writer.headers = [ - 'Method', 'Backbone', 'Crop Size', 'Lr schd', 'Mem (GB)', - 'Inf time (fps)', 'mIoU', 'mIoU(ms+flip)', 'download' - ] - writer.margin = 1 - writer.align_list = [Align.CENTER] * len(writer.headers) - dataset_maps = { - 'cityscapes': 'Cityscapes', - 'ade20k': 'ADE20K', - 'voc12aug': 'Pascal VOC 2012 + Aug' - } - for directory in table_cache: - for dataset in table_cache[directory]: - table = table_cache[directory][dataset][0] - writer.table_name = dataset_maps[dataset] - writer.value_matrix = table - for i in range(len(table)): - if table[i][-4] != '-': - table[i][-4] = f'{table[i][-4]:.2f}' - mmcv.mkdir_or_exist(osp.join(output_dir, directory)) - writer.dump( - osp.join(output_dir, directory, f'README_{dataset}.md')) - with open(osp.join(output_dir, directory, 'README.md'), 'w') as dst_f: - for dataset in dataset_maps: - dataset_md_file = osp.join(output_dir, directory, - f'README_{dataset}.md') - with open(dataset_md_file) as src_f: - for line in src_f: - dst_f.write(line) - dst_f.write('\n') - os.remove(dataset_md_file) - - -if __name__ == '__main__': - main() diff --git a/.dev/upload_modelzoo.py b/.dev/upload_modelzoo.py new file mode 100644 index 00000000..bd78bc41 --- /dev/null +++ b/.dev/upload_modelzoo.py @@ -0,0 +1,44 @@ +import argparse +import os +import os.path as osp + +import oss2 + +ACCESS_KEY_ID = os.getenv('OSS_ACCESS_KEY_ID', None) +ACCESS_KEY_SECRET = os.getenv('OSS_ACCESS_KEY_SECRET', None) +BUCKET_NAME = 'openmmlab' +ENDPOINT = 'https://oss-accelerate.aliyuncs.com' + + +def parse_args(): + parser = argparse.ArgumentParser(description='Upload models to OSS') + parser.add_argument('model_zoo', type=str, help='model_zoo input') + parser.add_argument( + '--dst-folder', + type=str, + default='mmsegmentation/v0.5', + help='destination folder') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + model_zoo = args.model_zoo + dst_folder = args.dst_folder + bucket = oss2.Bucket( + oss2.Auth(ACCESS_KEY_ID, ACCESS_KEY_SECRET), ENDPOINT, BUCKET_NAME) + + for root, dirs, files in os.walk(model_zoo): + for file in files: + file_path = osp.relpath(osp.join(root, file), model_zoo) + print(f'Uploading {file_path}') + + oss2.resumable_upload(bucket, osp.join(dst_folder, file_path), + osp.join(model_zoo, file_path)) + bucket.put_object_acl( + osp.join(dst_folder, file_path), oss2.OBJECT_ACL_PUBLIC_READ) + + +if __name__ == '__main__': + main() diff --git a/mmseg/core/utils/__init__.py b/mmseg/core/utils/__init__.py index 79d62f02..f2678b32 100644 --- a/mmseg/core/utils/__init__.py +++ b/mmseg/core/utils/__init__.py @@ -1,4 +1,3 @@ -from .dist_utils import allreduce_grads from .misc import add_prefix -__all__ = ['add_prefix', 'allreduce_grads'] +__all__ = ['add_prefix'] diff --git a/mmseg/core/utils/dist_utils.py b/mmseg/core/utils/dist_utils.py deleted file mode 100644 index 25219a79..00000000 --- a/mmseg/core/utils/dist_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections import OrderedDict - -import torch.distributed as dist -from torch._utils import (_flatten_dense_tensors, _take_tensors, - _unflatten_dense_tensors) - - -def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): - if bucket_size_mb > 0: - bucket_size_bytes = bucket_size_mb * 1024 * 1024 - buckets = _take_tensors(tensors, bucket_size_bytes) - else: - buckets = OrderedDict() - for tensor in tensors: - tp = tensor.type() - if tp not in buckets: - buckets[tp] = [] - buckets[tp].append(tensor) - buckets = buckets.values() - - for bucket in buckets: - flat_tensors = _flatten_dense_tensors(bucket) - dist.all_reduce(flat_tensors) - flat_tensors.div_(world_size) - for tensor, synced in zip( - bucket, _unflatten_dense_tensors(flat_tensors, bucket)): - tensor.copy_(synced) - - -def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): - """Allreduce gradients. - - Args: - params (list[torch.Parameters]): List of parameters of a model - coalesce (bool, optional): Whether allreduce parameters as a whole. - Defaults to True. - bucket_size_mb (int, optional): Size of bucket, the unit is MB. - Defaults to -1. - """ - grads = [ - param.grad.data for param in params - if param.requires_grad and param.grad is not None - ] - world_size = dist.get_world_size() - if coalesce: - _allreduce_coalesced(grads, world_size, bucket_size_mb) - else: - for tensor in grads: - dist.all_reduce(tensor.div_(world_size)) diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index dcb24214..4aaec221 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -1,10 +1,11 @@ import torch import torch.nn as nn -from mmcv.cnn import ConvModule, constant_init, kaiming_init +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init, + kaiming_init) from torch.nn.modules.batchnorm import _BatchNorm from mmseg.models.decode_heads.psp_head import PPM -from mmseg.ops import DepthwiseSeparableConvModule, resize +from mmseg.ops import resize from mmseg.utils import InvertedResidual from ..builder import BACKBONES diff --git a/mmseg/models/decode_heads/sep_aspp_head.py b/mmseg/models/decode_heads/sep_aspp_head.py index 71881890..50bd52bc 100644 --- a/mmseg/models/decode_heads/sep_aspp_head.py +++ b/mmseg/models/decode_heads/sep_aspp_head.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn -from mmcv.cnn import ConvModule +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule -from mmseg.ops import DepthwiseSeparableConvModule, resize +from mmseg.ops import resize from ..builder import HEADS from .aspp_head import ASPPHead, ASPPModule diff --git a/mmseg/models/decode_heads/sep_fcn_head.py b/mmseg/models/decode_heads/sep_fcn_head.py index 18779512..a636f702 100644 --- a/mmseg/models/decode_heads/sep_fcn_head.py +++ b/mmseg/models/decode_heads/sep_fcn_head.py @@ -1,4 +1,5 @@ -from mmseg.ops import DepthwiseSeparableConvModule +from mmcv.cnn import DepthwiseSeparableConvModule + from ..builder import HEADS from .fcn_head import FCNHead diff --git a/mmseg/ops/__init__.py b/mmseg/ops/__init__.py index 7a0b930c..bec51c75 100644 --- a/mmseg/ops/__init__.py +++ b/mmseg/ops/__init__.py @@ -1,5 +1,4 @@ from .encoding import Encoding -from .separable_conv_module import DepthwiseSeparableConvModule from .wrappers import Upsample, resize -__all__ = ['Upsample', 'resize', 'DepthwiseSeparableConvModule', 'Encoding'] +__all__ = ['Upsample', 'resize', 'Encoding'] diff --git a/mmseg/ops/separable_conv_module.py b/mmseg/ops/separable_conv_module.py deleted file mode 100644 index 4e5922cc..00000000 --- a/mmseg/ops/separable_conv_module.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch.nn as nn -from mmcv.cnn import ConvModule - - -class DepthwiseSeparableConvModule(nn.Module): - """Depthwise separable convolution module. - - See https://arxiv.org/pdf/1704.04861.pdf for details. - - This module can replace a ConvModule with the conv block replaced by two - conv block: depthwise conv block and pointwise conv block. The depthwise - conv block contains depthwise-conv/norm/activation layers. The pointwise - conv block contains pointwise-conv/norm/activation layers. It should be - noted that there will be norm/activation layer in the depthwise conv block - if `norm_cfg` and `act_cfg` are specified. - - Args: - in_channels (int): Same as nn.Conv2d. - out_channels (int): Same as nn.Conv2d. - kernel_size (int or tuple[int]): Same as nn.Conv2d. - stride (int or tuple[int]): Same as nn.Conv2d. Default: 1. - padding (int or tuple[int]): Same as nn.Conv2d. Default: 0. - dilation (int or tuple[int]): Same as nn.Conv2d. Default: 1. - norm_cfg (dict): Default norm config for both depthwise ConvModule and - pointwise ConvModule. Default: None. - act_cfg (dict): Default activation config for both depthwise ConvModule - and pointwise ConvModule. Default: dict(type='ReLU'). - dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is - 'default', it will be the same as `norm_cfg`. Default: 'default'. - dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is - 'default', it will be the same as `act_cfg`. Default: 'default'. - pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is - 'default', it will be the same as `norm_cfg`. Default: 'default'. - pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is - 'default', it will be the same as `act_cfg`. Default: 'default'. - kwargs (optional): Other shared arguments for depthwise and pointwise - ConvModule. See ConvModule for ref. - """ - - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - norm_cfg=None, - act_cfg=dict(type='ReLU'), - dw_norm_cfg='default', - dw_act_cfg='default', - pw_norm_cfg='default', - pw_act_cfg='default', - **kwargs): - super(DepthwiseSeparableConvModule, self).__init__() - assert 'groups' not in kwargs, 'groups should not be specified' - - # if norm/activation config of depthwise/pointwise ConvModule is not - # specified, use default config. - dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg - dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg - pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg - pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg - - # depthwise convolution - self.depthwise_conv = ConvModule( - in_channels, - in_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=in_channels, - norm_cfg=dw_norm_cfg, - act_cfg=dw_act_cfg, - **kwargs) - - self.pointwise_conv = ConvModule( - in_channels, - out_channels, - 1, - norm_cfg=pw_norm_cfg, - act_cfg=pw_act_cfg, - **kwargs) - - def forward(self, x): - x = self.depthwise_conv(x) - x = self.pointwise_conv(x) - return x diff --git a/setup.cfg b/setup.cfg index 21aad54e..cb533f4b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmseg -known_third_party = PIL,cityscapesscripts,detail,matplotlib,mmcv,numpy,onnxruntime,pytablewriter,pytest,scipy,torch +known_third_party = PIL,cityscapesscripts,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,torch no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 02460cbc..8e60a915 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -2,7 +2,7 @@ from unittest.mock import patch import pytest import torch -from mmcv.cnn import ConvModule +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule from mmcv.utils import ConfigDict from mmcv.utils.parrots_wrapper import SyncBatchNorm @@ -557,7 +557,6 @@ def test_sep_fcn_head(): output = head(x) assert output.shape == (2, head.num_classes, 32, 32) assert not head.concat_input - from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule assert isinstance(head.convs[0], DepthwiseSeparableConvModule) assert isinstance(head.convs[1], DepthwiseSeparableConvModule) assert head.conv_seg.kernel_size == (1, 1) @@ -573,7 +572,6 @@ def test_sep_fcn_head(): output = head(x) assert output.shape == (3, head.num_classes, 32, 32) assert head.concat_input - from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule assert isinstance(head.convs[0], DepthwiseSeparableConvModule) assert isinstance(head.convs[1], DepthwiseSeparableConvModule) diff --git a/tests/test_ops/test_sep_conv_module.py b/tests/test_ops/test_sep_conv_module.py deleted file mode 100644 index 4eb65011..00000000 --- a/tests/test_ops/test_sep_conv_module.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -from mmseg.ops import DepthwiseSeparableConvModule - - -def test_depthwise_separable_conv(): - with pytest.raises(AssertionError): - # conv_cfg must be a dict or None - DepthwiseSeparableConvModule(4, 8, 2, groups=2) - - # test default config - conv = DepthwiseSeparableConvModule(3, 8, 2) - assert conv.depthwise_conv.conv.groups == 3 - assert conv.pointwise_conv.conv.kernel_size == (1, 1) - assert not conv.depthwise_conv.with_norm - assert not conv.pointwise_conv.with_norm - assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU' - assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU' - x = torch.rand(1, 3, 256, 256) - output = conv(x) - assert output.shape == (1, 8, 255, 255) - - # test - conv = DepthwiseSeparableConvModule(3, 8, 2, dw_norm_cfg=dict(type='BN')) - assert conv.depthwise_conv.norm_name == 'bn' - assert not conv.pointwise_conv.with_norm - x = torch.rand(1, 3, 256, 256) - output = conv(x) - assert output.shape == (1, 8, 255, 255) - - conv = DepthwiseSeparableConvModule(3, 8, 2, pw_norm_cfg=dict(type='BN')) - assert not conv.depthwise_conv.with_norm - assert conv.pointwise_conv.norm_name == 'bn' - x = torch.rand(1, 3, 256, 256) - output = conv(x) - assert output.shape == (1, 8, 255, 255) - - # add test for ['norm', 'conv', 'act'] - conv = DepthwiseSeparableConvModule(3, 8, 2, order=('norm', 'conv', 'act')) - x = torch.rand(1, 3, 256, 256) - output = conv(x) - assert output.shape == (1, 8, 255, 255) - - conv = DepthwiseSeparableConvModule( - 3, 8, 3, padding=1, with_spectral_norm=True) - assert hasattr(conv.depthwise_conv.conv, 'weight_orig') - assert hasattr(conv.pointwise_conv.conv, 'weight_orig') - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - conv = DepthwiseSeparableConvModule( - 3, 8, 3, padding=1, padding_mode='reflect') - assert isinstance(conv.depthwise_conv.padding_layer, nn.ReflectionPad2d) - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - conv = DepthwiseSeparableConvModule( - 3, 8, 3, padding=1, dw_act_cfg=dict(type='LeakyReLU')) - assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU' - assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU' - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - conv = DepthwiseSeparableConvModule( - 3, 8, 3, padding=1, pw_act_cfg=dict(type='LeakyReLU')) - assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU' - assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU' - output = conv(x) - assert output.shape == (1, 8, 256, 256) -- GitLab