fix mp policy forwarding in fsdp2 (#970)
Summary: Pull Request resolved: #970

Reviewed By: galrotem, anshulverma

Differential Revision: D69669442

fbshipit-source-id: fff1e475ab1a31fc3291ee249c21292af1fc0561
JKSenthil authored and facebook-github-bot committed Feb 14, 2025
1 parent 5bc1702 commit 7390c77
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/utils/test_prepare_module_gpu.py
@@ -348,7 +348,7 @@ def _test_prepare_fsdp2_shard_all() -> None:

     module = SimpleModule()
     device = torch.device("cuda")
-    strategy = FSDP2Strategy(modules_to_shard="all")
+    strategy = FSDP2Strategy(modules_to_shard="all", mp_policy=torch.bfloat16)
     prepare_fsdp2(module, device, strategy)
 
     for submodule in module.modules():
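For context, a minimal usage sketch of the two configuration paths the updated test exercises. This is not part of the commit; the import paths and the SimpleModule fixture are assumptions based on the files touched here, and it presumes a CUDA device plus an initialized process group.

# Sketch only: mp_policy may be given as a plain dtype or a full policy.
# Import paths are assumptions based on this repository's layout.
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2

module = SimpleModule()  # stand-in for any nn.Module
device = torch.device("cuda")

# Shorthand form: a single dtype is expanded to param/reduce/output dtypes.
strategy = FSDP2Strategy(modules_to_shard="all", mp_policy=torch.bfloat16)

# Explicit form: pass a MixedPrecisionPolicy to control each dtype separately.
strategy = FSDP2Strategy(
    modules_to_shard="all",
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    ),
)

prepare_fsdp2(module, device, strategy)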
4 changes: 2 additions & 2 deletions torchtnt/utils/prepare_module.py
@@ -371,9 +371,9 @@ def prepare_fsdp2(
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
if (mp_policy := strategy.mp_policy) is not None:
if isinstance(mp_policy, MixedPrecisionPolicy):
fsdp_kwargs["mixed_precision"] = mp_policy
fsdp_kwargs["mp_policy"] = mp_policy
else:
fsdp_kwargs["mixed_precision"] = MixedPrecisionPolicy(
fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
param_dtype=mp_policy,
reduce_dtype=mp_policy,
output_dtype=mp_policy,
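Why the key is renamed, as a hedged sketch: prepare_fsdp2 forwards fsdp_kwargs to PyTorch's FSDP2 fully_shard, whose keyword-only mixed precision parameter is named mp_policy, so the old "mixed_precision" key never matched its signature. The snippet below is an illustration of that assumption, not code from this diff; the import path reflects recent PyTorch releases.

# Sketch only: the corrected key lines up with fully_shard's keyword-only
# mp_policy parameter. Older PyTorch exposes the same API under
# torch.distributed._composable.fsdp. Assumes torch.distributed has been
# initialized (e.g. via init_process_group) and a CUDA device is available.
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

module = nn.Linear(8, 8).cuda()  # hypothetical module, for illustration
fsdp_kwargs = {
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        output_dtype=torch.bfloat16,
    ),
}
# prepare_fsdp2 ultimately does the equivalent of:
fully_shard(module, **fsdp_kwargs)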
