From 9e6008d7205f84dc1eefdbbd74651d6eb7041830 Mon Sep 17 00:00:00 2001 From: SuperDong <16302010007@fudan.edu.cn> Date: Mon, 3 Jun 2024 19:52:51 +0800 Subject: [PATCH] fix tensor_to_int in avg_pool2d --- frontend/guard_tracker.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index 6020a65..17274bf 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -267,6 +267,18 @@ def record_function(self, func = torch._C._set_grad_enabled kwargs = {} pargs, pkwargs = self.as_node_args_kwargs(args, kwargs) + if func == torch.nn.functional.avg_pool2d: + # avg_pool2d only supports integer or tuple(with two int values) as inputs + if isinstance(pargs[1], tuple): + for i in pargs[1]: + if isinstance(i, torch.fx.Node): + raise ValueError("cannot convert tensor in avg_pool2d") + if isinstance(args[1], tuple): + for i in args[1]: + if torch.is_tensor(i): + raise ValueError("cannot convert tensor in avg_pool2d") + elif torch.is_tensor(args[1]): + raise ValueError("cannot convert tensor in avg_pool2d") if func in fx_graph_inplace_functions: scalar = None node = None