From 7390c772eb5c74967dd2a9874091ac4db1cc3127 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Fri, 14 Feb 2025 12:36:30 -0800
Subject: [PATCH] fix mp policy forwarding in fsdp2 (#970)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/970

Reviewed By: galrotem, anshulverma

Differential Revision: D69669442

fbshipit-source-id: fff1e475ab1a31fc3291ee249c21292af1fc0561
---
 tests/utils/test_prepare_module_gpu.py | 2 +-
 torchtnt/utils/prepare_module.py       | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/utils/test_prepare_module_gpu.py b/tests/utils/test_prepare_module_gpu.py
index ba3583f6e2..384b846f87 100644
--- a/tests/utils/test_prepare_module_gpu.py
+++ b/tests/utils/test_prepare_module_gpu.py
@@ -348,7 +348,7 @@ def _test_prepare_fsdp2_shard_all() -> None:
         module = SimpleModule()
         device = torch.device("cuda")
 
-        strategy = FSDP2Strategy(modules_to_shard="all")
+        strategy = FSDP2Strategy(modules_to_shard="all", mp_policy=torch.bfloat16)
         prepare_fsdp2(module, device, strategy)
 
         for submodule in module.modules():
diff --git a/torchtnt/utils/prepare_module.py b/torchtnt/utils/prepare_module.py
index ca40f87b1b..a2b286cd8e 100644
--- a/torchtnt/utils/prepare_module.py
+++ b/torchtnt/utils/prepare_module.py
@@ -371,9 +371,9 @@ def prepare_fsdp2(
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
     if (mp_policy := strategy.mp_policy) is not None:
         if isinstance(mp_policy, MixedPrecisionPolicy):
-            fsdp_kwargs["mixed_precision"] = mp_policy
+            fsdp_kwargs["mp_policy"] = mp_policy
         else:
-            fsdp_kwargs["mixed_precision"] = MixedPrecisionPolicy(
+            fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
                 param_dtype=mp_policy,
                 reduce_dtype=mp_policy,
                 output_dtype=mp_policy,
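
Usage note (editorial, not part of the patch): the fix renames the kwarg that
prepare_fsdp2 forwards to FSDP2's fully_shard from "mixed_precision" to
"mp_policy", so a user-supplied mixed precision setting is no longer dropped.
A minimal sketch of the resulting behavior, assuming torch.distributed is
initialized and a CUDA device is available; the Linear module below is an
illustrative stand-in for a real model.

# Sketch only: FSDP2Strategy.mp_policy may be a MixedPrecisionPolicy or a
# bare dtype. With this fix, a bare dtype is expanded into
# MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=...)
# and passed to fully_shard under the "mp_policy" kwarg.
import torch
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2

module = torch.nn.Linear(8, 8)  # stand-in model
device = torch.device("cuda")

# Passing a plain dtype, as in the updated test, now reaches fully_shard:
strategy = FSDP2Strategy(modules_to_shard="all", mp_policy=torch.bfloat16)
prepare_fsdp2(module, device, strategy)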