diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index 6020a65..17274bf 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -267,6 +267,18 @@ def record_function(self, func = torch._C._set_grad_enabled kwargs = {} pargs, pkwargs = self.as_node_args_kwargs(args, kwargs) + if func == torch.nn.functional.avg_pool2d: + # avg_pool2d only supports integer or tuple(with two int values) as inputs + if isinstance(pargs[1], tuple): + for i in pargs[1]: + if isinstance(i, torch.fx.Node): + raise ValueError("cannot convert tensor in avg_pool2d") + if isinstance(args[1], tuple): + for i in args[1]: + if torch.is_tensor(i): + raise ValueError("cannot convert tensor in avg_pool2d") + elif torch.is_tensor(args[1]): + raise ValueError("cannot convert tensor in avg_pool2d") if func in fx_graph_inplace_functions: scalar = None node = None