66 * this file. If not visit https://opensource.org/licenses/MIT
77'''
88
9+ from functools import partial
10+
911import numpy as np
1012import torch
1113import torch .nn as nn
@@ -55,12 +57,11 @@ def bn_flops_counter_hook(module, input, output):
5557 module .__flops__ += int (batch_flops )
5658
5759
58- def conv_flops_counter_hook (conv_module , input , output , extra_per_position_flops = 0 ):
60+ def conv_flops_counter_hook (conv_module , input , output , extra_per_position_flops = 0 , transpose = False ):
5961 # Can have multiple inputs, getting the first one
6062 input = input [0 ]
6163
6264 batch_size = input .shape [0 ]
63- output_dims = list (output .shape [2 :])
6465
6566 kernel_dims = list (conv_module .kernel_size )
6667 in_channels = conv_module .in_channels
@@ -71,7 +72,12 @@ def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops
7172 conv_per_position_flops = int (np .prod (kernel_dims , dtype = np .int64 )) * \
7273 (in_channels * filters_per_channel + extra_per_position_flops )
7374
74- active_elements_count = batch_size * int (np .prod (output_dims , dtype = np .int64 ))
75+ if transpose :
76+ input_dims = list (input .shape [2 :])
77+ active_elements_count = batch_size * int (np .prod (input_dims , dtype = np .int64 ))
78+ else :
79+ output_dims = list (output .shape [2 :])
80+ active_elements_count = batch_size * int (np .prod (output_dims , dtype = np .int64 ))
7581
7682 overall_conv_flops = conv_per_position_flops * active_elements_count
7783
@@ -301,9 +307,9 @@ def timm_attention_counter_hook(attention_module, input, output):
301307 # Upscale
302308 nn .Upsample : upsample_flops_counter_hook ,
303309 # Deconvolution
304- nn .ConvTranspose1d : conv_flops_counter_hook ,
305- nn .ConvTranspose2d : conv_flops_counter_hook ,
306- nn .ConvTranspose3d : conv_flops_counter_hook ,
310+ nn .ConvTranspose1d : partial ( conv_flops_counter_hook , transpose = True ) ,
311+ nn .ConvTranspose2d : partial ( conv_flops_counter_hook , transpose = True ) ,
312+ nn .ConvTranspose3d : partial ( conv_flops_counter_hook , transpose = True ) ,
307313 # RNN
308314 nn .RNN : rnn_flops_counter_hook ,
309315 nn .GRU : rnn_flops_counter_hook ,
0 commit comments