Skip to content

Commit

Permalink
Merge pull request #327 from aws-samples/fsdp_sync_module_states_true
Browse files Browse the repository at this point in the history
FSDP with meta device requires sync_module_states=True
  • Loading branch information
perifaws authored May 10, 2024
2 parents ecfbb50 + 5bf0c28 commit 48c339f
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions 3.test_cases/10.FSDP/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def main(args):
model = AutoModelForCausalLM.from_config(model_config)
else:
with torch.device("meta"):
# Instantiating the model on the `meta` device doesn't consume CPU memory,
# but requires specifying `param_init_fn=...` (to materialize the parameters)
# and `sync_module_states=True` (to broadcast rank 0's weights) in the FSDP constructor.
model = AutoModelForCausalLM.from_config(model_config)

num_params = compute_num_params(model)
Expand Down Expand Up @@ -197,6 +200,7 @@ def main(args):
device_id=torch.cuda.current_device(),
use_orig_params=False,
sharding_strategy=sharding_strategy,
sync_module_states=True,
param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
if global_rank != 0 else None,
)
Expand Down

0 comments on commit 48c339f

Please sign in to comment.