Skip to content

Commit 994faeb

Browse files
committed
Finally fix transposed conv
1 parent 93311b2 commit 994faeb

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

ptflops/pytorch_ops.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,7 @@ def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops
8585
bias_flops = 0
8686

8787
if conv_module.bias is not None:
88-
89-
bias_flops = out_channels * active_elements_count
88+
bias_flops = batch_size * int(np.prod(list(output.shape[1:]), dtype=np.int64))
9089

9190
overall_flops = overall_conv_flops + bias_flops
9291

tests/common_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
3333
assert params == 3 * 3 * 2 * 3 + 2
3434
assert macs == 2759904
3535

36+
@pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
37+
def test_conv_t(self, default_input_image_size, backend: FLOPS_BACKEND):
38+
net = nn.ConvTranspose2d(3, 2, 3, stride=(2, 2), bias=True)
39+
macs, params = get_model_complexity_info(net, default_input_image_size,
40+
as_strings=False,
41+
print_per_layer_stat=False,
42+
backend=backend)
43+
44+
assert params == 3 * 3 * 2 * 3 + 2
45+
assert macs == 3112706
46+
3647
@pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
3748
def test_fc(self, backend: FLOPS_BACKEND):
3849
net = nn.Sequential(nn.Linear(3, 2, bias=True))

0 commit comments

Comments
 (0)