Skip to content

Commit fb3aa21

Browse files
authored
Merge pull request #150 from FrzMtrsprt/fix-convt
Fix ConvTranspose
2 parents 733005e + aba9753 commit fb3aa21

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

ptflops/pytorch_ops.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* this file. If not visit https://opensource.org/licenses/MIT
77
'''
88

9+
from functools import partial
10+
911
import numpy as np
1012
import torch
1113
import torch.nn as nn
@@ -55,12 +57,11 @@ def bn_flops_counter_hook(module, input, output):
5557
module.__flops__ += int(batch_flops)
5658

5759

58-
def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0):
60+
def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0, transpose=False):
5961
# Can have multiple inputs, getting the first one
6062
input = input[0]
6163

6264
batch_size = input.shape[0]
63-
output_dims = list(output.shape[2:])
6465

6566
kernel_dims = list(conv_module.kernel_size)
6667
in_channels = conv_module.in_channels
@@ -71,7 +72,12 @@ def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops
7172
conv_per_position_flops = int(np.prod(kernel_dims, dtype=np.int64)) * \
7273
(in_channels * filters_per_channel + extra_per_position_flops)
7374

74-
active_elements_count = batch_size * int(np.prod(output_dims, dtype=np.int64))
75+
if transpose:
76+
input_dims = list(input.shape[2:])
77+
active_elements_count = batch_size * int(np.prod(input_dims, dtype=np.int64))
78+
else:
79+
output_dims = list(output.shape[2:])
80+
active_elements_count = batch_size * int(np.prod(output_dims, dtype=np.int64))
7581

7682
overall_conv_flops = conv_per_position_flops * active_elements_count
7783

@@ -301,9 +307,9 @@ def timm_attention_counter_hook(attention_module, input, output):
301307
# Upscale
302308
nn.Upsample: upsample_flops_counter_hook,
303309
# Deconvolution
304-
nn.ConvTranspose1d: conv_flops_counter_hook,
305-
nn.ConvTranspose2d: conv_flops_counter_hook,
306-
nn.ConvTranspose3d: conv_flops_counter_hook,
310+
nn.ConvTranspose1d: partial(conv_flops_counter_hook, transpose=True),
311+
nn.ConvTranspose2d: partial(conv_flops_counter_hook, transpose=True),
312+
nn.ConvTranspose3d: partial(conv_flops_counter_hook, transpose=True),
307313
# RNN
308314
nn.RNN: rnn_flops_counter_hook,
309315
nn.GRU: rnn_flops_counter_hook,

0 commit comments

Comments
 (0)