diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py index bd665ec07..f745cd294 100644 --- a/tools/analysis_tools/get_flops.py +++ b/tools/analysis_tools/get_flops.py @@ -3,9 +3,10 @@ import numpy as np import torch -from mmcv import Config, DictAction +from mmengine.config import Config, DictAction from mmrotate.registry import MODELS +from mmrotate.utils import register_all_modules try: from mmcv.cnn import get_model_complexity_info @@ -43,7 +44,7 @@ def parse_args(): def main(): - + register_all_modules() args = parse_args() if len(args.shape) == 1: @@ -64,21 +65,11 @@ def main(): if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) - model = MODELS.build( - cfg.model, - train_cfg=cfg.get('train_cfg'), - test_cfg=cfg.get('test_cfg')) + model = MODELS.build(cfg.model) if torch.cuda.is_available(): model.cuda() model.eval() - if hasattr(model, 'forward_dummy'): - model.forward = model.forward_dummy - else: - raise NotImplementedError( - 'FLOPs counter is currently not currently supported with {}'. - format(model.__class__.__name__)) - flops, params = get_model_complexity_info(model, input_shape) split_line = '=' * 30