disable detect anomaly if torch compile enabled (#961)
Summary: Pull Request resolved: #961

Reviewed By: diego-urgell

Differential Revision: D68336316
JKSenthil authored and facebook-github-bot committed Jan 17, 2025
1 parent 52b5568 commit 3f6261c
Showing 2 changed files with 15 additions and 1 deletion.
tests/framework/test_auto_unit.py (11 changes: 10 additions & 1 deletion)

@@ -40,7 +40,7 @@
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
-from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy
+from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
 from torchtnt.utils.progress import Progress
 from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL
 from torchtnt.utils.test_utils import skip_if_not_distributed

@@ -741,6 +741,15 @@ def test_enable_prefetch(self) -> None:
         _ = auto_unit._get_next_batch(get_dummy_train_state(), iter(data))
         self.assertIsNone(auto_unit._phase_to_next_batch[ActivePhase.TRAIN])

+    def test_detect_anomaly_disabled_with_torch_compile(self) -> None:
+        auto_unit = DummyAutoUnit(
+            module=torch.nn.Linear(2, 2),
+            detect_anomaly=True,
+            torch_compile_params=TorchCompileParams(),
+        )
+
+        self.assertIsNone(auto_unit.detect_anomaly)
+

 Batch = Tuple[torch.Tensor, torch.Tensor]

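For context, the behavior this new test pins down looks roughly as follows from a caller's point of view. This is a hedged sketch, not code from the repository: MyAutoUnit is a hypothetical subclass (AutoUnit is abstract, so its compute_loss and configure_optimizers_and_lr_scheduler hooks are stubbed with just enough to construct an instance), while the constructor arguments mirror the test above.

import torch

from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.utils.prepare_module import TorchCompileParams


class MyAutoUnit(AutoUnit):
    # Minimal stubs so the abstract base class can be instantiated.
    def compute_loss(self, state, data):
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.mse_loss(outputs, targets), outputs

    def configure_optimizers_and_lr_scheduler(self, module):
        return torch.optim.SGD(module.parameters(), lr=0.1), None


unit = MyAutoUnit(
    module=torch.nn.Linear(2, 2),
    detect_anomaly=True,  # requested by the caller...
    torch_compile_params=TorchCompileParams(),  # ...but torch.compile takes precedence
)

# The constructor logs "torch.compile is enabled, so detect_anomaly is disabled"
# and clears the attribute, which is what the test asserts:
assert unit.detect_anomaly is None
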
torchtnt/framework/auto_unit.py (5 changes: 5 additions & 0 deletions)

@@ -183,6 +183,11 @@ def __init__(
         )

         self.detect_anomaly = detect_anomaly
+        if torch_compile_params is not None:
+            # torch compile is not compatible with detect anomaly
+            # so we disable detect anomaly if torch compile is enabled
+            self.detect_anomaly = None
+            _logger.warning("torch.compile is enabled, so detect_anomaly is disabled")

         # create autocast context based on precision and device type
         self.maybe_autocast_precision = torch.autocast(
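
A note on the fix itself: anomaly detection instruments eager-mode autograd, while torch.compile traces and compiles the computation, so the two features do not compose; the commit therefore resets the flag to None rather than False. Presumably None matters because the flag is only wrapped in an anomaly-detection context when it is set, along the lines of this sketch (an assumption about the consuming code, not copied from the library):

import contextlib
from typing import Optional

import torch


def maybe_detect_anomaly(flag: Optional[bool]) -> contextlib.AbstractContextManager:
    # With flag=None (the value the commit forces whenever torch_compile_params
    # is given), no anomaly-detection context is entered at all.
    if flag is None:
        return contextlib.nullcontext()
    # torch.autograd.set_detect_anomaly doubles as a context manager.
    return torch.autograd.set_detect_anomaly(flag)


# Usage sketch inside a train step:
# with maybe_detect_anomaly(self.detect_anomaly):
#     loss.backward()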
