Skip to content
Snippets Groups Projects
Unverified Commit ed9d42a2 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

add GN support for flops computation (#1850)

* add GN support for flops computation

* remove useless lines

* modify the flops computation for gn
parent 9729ca54
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmdet
known_third_party = Cython,albumentations,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
known_third_party = Cython,albumentations,asynctest,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
......@@ -33,19 +33,6 @@ from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
_AvgPoolNd, _MaxPoolNd)
CONV_TYPES = (_ConvNd, )
DECONV_TYPES = (_ConvTransposeMixin, )
LINEAR_TYPES = (nn.Linear, )
POOLING_TYPES = (_AvgPoolNd, _MaxPoolNd, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd)
RELU_TYPES = (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)
BN_TYPES = (_BatchNorm, )
UPSAMPLE_TYPES = (nn.Upsample, )
SUPPORTED_TYPES = (
CONV_TYPES + DECONV_TYPES + LINEAR_TYPES + POOLING_TYPES + RELU_TYPES +
BN_TYPES + UPSAMPLE_TYPES)
def get_model_complexity_info(model,
input_res,
......@@ -249,10 +236,10 @@ def remove_flops_mask(module):
def is_supported_instance(module):
if isinstance(module, SUPPORTED_TYPES):
return True
else:
return False
for mod in hook_mapping:
if issubclass(type(module), mod):
return True
return False
def empty_flops_counter_hook(module, input, output):
......@@ -285,7 +272,6 @@ def pool_flops_counter_hook(module, input, output):
def bn_flops_counter_hook(module, input, output):
module.affine
input = input[0]
batch_flops = np.prod(input.shape)
......@@ -294,6 +280,17 @@ def bn_flops_counter_hook(module, input, output):
module.__flops__ += int(batch_flops)
def gn_flops_counter_hook(module, input, output):
elems = np.prod(input[0].shape)
# there is no precise FLOPs estimation of computing mean and variance,
# and we just set it 2 * elems: half muladds for computing
# means and half for computing vars
batch_flops = 3 * elems
if module.affine:
batch_flops += elems
module.__flops__ += int(batch_flops)
def deconv_flops_counter_hook(conv_module, input, output):
# Can have multiple inputs, getting the first one
input = input[0]
......@@ -359,6 +356,32 @@ def conv_flops_counter_hook(conv_module, input, output):
conv_module.__flops__ += int(overall_flops)
hook_mapping = {
# conv
_ConvNd: conv_flops_counter_hook,
# deconv
_ConvTransposeMixin: deconv_flops_counter_hook,
# fc
nn.Linear: linear_flops_counter_hook,
# pooling
_AvgPoolNd: pool_flops_counter_hook,
_MaxPoolNd: pool_flops_counter_hook,
_AdaptiveAvgPoolNd: pool_flops_counter_hook,
_AdaptiveMaxPoolNd: pool_flops_counter_hook,
# activation
nn.ReLU: relu_flops_counter_hook,
nn.PReLU: relu_flops_counter_hook,
nn.ELU: relu_flops_counter_hook,
nn.LeakyReLU: relu_flops_counter_hook,
nn.ReLU6: relu_flops_counter_hook,
# normalization
_BatchNorm: bn_flops_counter_hook,
nn.GroupNorm: gn_flops_counter_hook,
# upsample
nn.Upsample: upsample_flops_counter_hook,
}
def batch_counter_hook(module, input, output):
batch_size = 1
if len(input) > 0:
......@@ -372,7 +395,6 @@ def batch_counter_hook(module, input, output):
def add_batch_counter_variables_or_reset(module):
module.__batch_counter__ = 0
......@@ -400,22 +422,11 @@ def add_flops_counter_hook_function(module):
if hasattr(module, '__flops_handle__'):
return
if isinstance(module, CONV_TYPES):
handle = module.register_forward_hook(conv_flops_counter_hook)
elif isinstance(module, RELU_TYPES):
handle = module.register_forward_hook(relu_flops_counter_hook)
elif isinstance(module, LINEAR_TYPES):
handle = module.register_forward_hook(linear_flops_counter_hook)
elif isinstance(module, POOLING_TYPES):
handle = module.register_forward_hook(pool_flops_counter_hook)
elif isinstance(module, BN_TYPES):
handle = module.register_forward_hook(bn_flops_counter_hook)
elif isinstance(module, UPSAMPLE_TYPES):
handle = module.register_forward_hook(upsample_flops_counter_hook)
elif isinstance(module, DECONV_TYPES):
handle = module.register_forward_hook(deconv_flops_counter_hook)
else:
handle = module.register_forward_hook(empty_flops_counter_hook)
for mod_type, counter_hook in hook_mapping.items():
if issubclass(type(module), mod_type):
handle = module.register_forward_hook(counter_hook)
break
module.__flops_handle__ = handle
......
......@@ -46,6 +46,9 @@ def main():
split_line = '=' * 30
print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
split_line, input_shape, flops, params))
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment