Skip to content

Commit

Permalink
fix torch.nn.functional.interpolate
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 committed May 28, 2024
1 parent 41f8b3f commit 3d3ec43
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion frontend/fx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,17 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:
else:
random_number = str(random.randint(0, 1000000))
folder_name = f'tmp/fx_module_{random_number}'

for node in gm.graph.nodes:
# to avoid error like
# interpolate(Tensor input, int? size=None, float[]? scale_factor=None, str mode="nearest", bool? align_corners=None, bool? recompute_scale_factor=None, bool antialias=False) -> Tensor:
# Expected a value of type 'Optional[List[float]]' for argument 'scale_factor' but instead found type 'int'.
if node.target == torch.nn.functional.interpolate and 'scale_factor' in node.kwargs:
new_dict = {k: v for k, v in node.kwargs.items()}
new_dict['scale_factor'] = float(new_dict['scale_factor'])
node.kwargs = new_dict
print(node.kwargs)

gm.recompile()
os.makedirs(folder_name, exist_ok=True)
gm.to_folder(folder_name)

Expand Down

0 comments on commit 3d3ec43

Please sign in to comment.