From 45fae72de5d3bf933504348daba5c848f752d4a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= <xiexinch@outlook.com> Date: Fri, 10 Mar 2023 19:25:47 +0800 Subject: [PATCH] [Feature] Support calculating FLOPs of segmentors (#2706) ## Motivation fix compute flops problems ## Modification Please briefly describe what modification is made in this PR. --- tools/analysis_tools/get_flops.py | 108 ++++++++++++++++++++++++------ 1 file changed, 86 insertions(+), 22 deletions(-) diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py index 1e8f188e..66b2d52f 100644 --- a/tools/analysis_tools/get_flops.py +++ b/tools/analysis_tools/get_flops.py @@ -1,10 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import tempfile +from pathlib import Path -from mmcv.cnn import get_model_complexity_info -from mmengine import Config +import torch +from mmengine import Config, DictAction +from mmengine.logging import MMLogger +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope -from mmseg.models import build_segmentor +from mmseg.models import BaseSegmentor +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample + +try: + from mmengine.analysis import get_model_complexity_info + from mmengine.analysis.print_helper import _format_size +except ImportError: + raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.') def parse_args(): @@ -17,13 +30,33 @@ def parse_args(): nargs='+', default=[2048, 1024], help='input image size') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') args = parser.parse_args() return args -def main(): +def inference(args: argparse.Namespace, logger: MMLogger) -> dict: + config_name = Path(args.config) - args = parse_args() + if not config_name.exists(): + logger.error(f'Config file {config_name} does not exist') + + cfg: Config = Config.fromfile(config_name) + cfg.work_dir = tempfile.TemporaryDirectory().name + cfg.log_level = 'WARN' + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + init_default_scope(cfg.get('scope', 'mmseg')) if len(args.shape) == 1: input_shape = (3, args.shape[0], args.shape[0]) @@ -31,29 +64,60 @@ def main(): input_shape = (3, ) + tuple(args.shape) else: raise ValueError('invalid input shape') + result = {} - cfg = Config.fromfile(args.config) - cfg.model.pretrained = None - model = build_segmentor( - cfg.model, - train_cfg=cfg.get('train_cfg'), - test_cfg=cfg.get('test_cfg')).cuda() + model: BaseSegmentor = MODELS.build(cfg.model) + if hasattr(model, 'auxiliary_head'): + model.auxiliary_head = None + if torch.cuda.is_available(): + model.cuda() + model = revert_sync_batchnorm(model) + result['ori_shape'] = input_shape[-2:] + result['pad_shape'] = input_shape[-2:] + data_batch = { + 'inputs': [torch.rand(input_shape)], + 'data_samples': [SegDataSample(metainfo=result)] + } + data = model.data_preprocessor(data_batch) model.eval() + if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']: + # TODO: Support MaskFormer and Mask2Former + raise NotImplementedError('MaskFormer and Mask2Former are not ' + 'supported yet.') + outputs = get_model_complexity_info( + model, + input_shape, + inputs=data['inputs'], + show_table=False, + show_arch=False) + result['flops'] = _format_size(outputs['flops']) + result['params'] = _format_size(outputs['params']) + result['compute_type'] = 'direct: randomly generate a picture' + return result - if hasattr(model, 'forward_dummy'): - model.forward = model.forward_dummy - else: - raise NotImplementedError( - 'FLOPs counter is currently not currently supported with {}'. - format(model.__class__.__name__)) - flops, params = get_model_complexity_info(model, input_shape) +def main(): + + args = parse_args() + logger = MMLogger.get_instance(name='MMLogger') + + result = inference(args, logger) split_line = '=' * 30 - print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( - split_line, input_shape, flops, params)) + ori_shape = result['ori_shape'] + pad_shape = result['pad_shape'] + flops = result['flops'] + params = result['params'] + compute_type = result['compute_type'] + + if pad_shape != ori_shape: + print(f'{split_line}\nUse size divisor set input shape ' + f'from {ori_shape} to {pad_shape}') + print(f'{split_line}\nCompute type: {compute_type}\n' + f'Input shape: {pad_shape}\nFlops: {flops}\n' + f'Params: {params}\n{split_line}') print('!!!Please be cautious if you use the results in papers. ' - 'You may need to check if all ops are supported and verify that the ' - 'flops computation is correct.') + 'You may need to check if all ops are supported and verify ' + 'that the flops computation is correct.') if __name__ == '__main__': -- GitLab