From 2c85b93dc2838776779d44431652170baf47df2f Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sun, 26 Nov 2023 09:34:05 +0900 Subject: [PATCH] Fix imprecise bias calculation in linear hook --- ptflops/pytorch_ops.py | 6 ++++-- tests/common_test.py | 9 +++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/ptflops/pytorch_ops.py b/ptflops/pytorch_ops.py index 9f079da..3801070 100644 --- a/ptflops/pytorch_ops.py +++ b/ptflops/pytorch_ops.py @@ -34,9 +34,11 @@ def linear_flops_counter_hook(module, input, output): input = input[0] # pytorch checks dimensions, so here we don't care much output_last_dim = output.shape[-1] + input_last_dim = input.shape[-1] + pre_last_dims_prod = np.prod(input.shape[0:-1], dtype=np.int64) bias_flops = output_last_dim if module.bias is not None else 0 - module.__flops__ += int(np.prod(input.shape, dtype=np.int64) * - output_last_dim + bias_flops) + module.__flops__ += int((input_last_dim * output_last_dim + bias_flops) + * pre_last_dims_prod) def pool_flops_counter_hook(module, input, output): diff --git a/tests/common_test.py b/tests/common_test.py index a426e16..9a89def 100644 --- a/tests/common_test.py +++ b/tests/common_test.py @@ -28,6 +28,15 @@ def test_fc(self): assert params == 3 * 2 + 2 assert int(macs) == 8 + def test_fc_multidim(self): + net = nn.Sequential(nn.Linear(3, 2, bias=True)) + macs, params = get_model_complexity_info(net, (4, 5, 3), + as_strings=False, + print_per_layer_stat=False) + + assert params == (3 * 2 + 2) + assert int(macs) == (3 * 2 + 2) * 4 * 5 + def test_input_constructor_tensor(self): net = nn.Sequential(nn.Linear(3, 2, bias=True))