diff --git a/tests/framework/test_loop_utils.py b/tests/framework/test_loop_utils.py index c7ad1febac..613d037626 100644 --- a/tests/framework/test_loop_utils.py +++ b/tests/framework/test_loop_utils.py @@ -105,7 +105,7 @@ def forward(self, x): return x loss_fn = nn.CrossEntropyLoss() - module = torch.export.export(M(), (torch.rand(4, 4),)).module() + module = torch.export.export(M(), (torch.rand(4, 4),), strict=True).module() tracked_modules: Dict[str, torch.nn.Module] = { "module": module,