
Commit

Merge pull request #310 from aws-samples/instantiate_model_on_rank_0_only

Instantiate model on CPU on rank=0 only to prevent CPU OOM
KeitaW authored May 5, 2024
2 parents 47fec1d + c93102b commit 62bbdd0
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion 3.test_cases/10.FSDP/train.py
@@ -155,7 +155,13 @@ def main(args):
     logger.info(
         "Creating Model"
     )
-    model = AutoModelForCausalLM.from_config(model_config)
+    # Instantiate model on CPU on rank=0 only to prevent CPU OOM
+    # (e.g. 70B * 4 bytes * 8 processes > 2T RAM available on P5)
+    if global_rank == 0:
+        model = AutoModelForCausalLM.from_config(model_config)
+    else:
+        with torch.device("meta"):
+            model = AutoModelForCausalLM.from_config(model_config)
 
     num_params = compute_num_params(model)
     if global_rank == 0:
@@ -191,6 +197,8 @@ def main(args):
         device_id=torch.cuda.current_device(),
         use_orig_params=False,
         sharding_strategy=sharding_strategy,
+        param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
+        if global_rank != 0 else None,
     )
 
     if global_rank == 0:
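For context, here is a minimal, self-contained sketch of the pattern this commit applies: materialize real weights on CPU only on rank 0, build meta-device parameter shells on every other rank, and let FSDP allocate and fill the shards on GPU. This is not the repo's exact train.py; the model id, ShardingStrategy.FULL_SHARD, and sync_module_states=True are assumptions for illustration and are not part of the hunks shown above.

```python
# Sketch of rank-0 CPU init + meta-device init for all other ranks under FSDP.
# Assumes torch.distributed is already initialized (e.g. via torchrun).
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from transformers import AutoConfig, AutoModelForCausalLM


def build_fsdp_model(model_id: str = "meta-llama/Llama-2-7b-hf"):  # model id is hypothetical
    global_rank = dist.get_rank()
    model_config = AutoConfig.from_pretrained(model_id)

    if global_rank == 0:
        # Only rank 0 materializes real weights on CPU, so host RAM holds one copy
        # instead of one copy per local process.
        model = AutoModelForCausalLM.from_config(model_config)
    else:
        # Other ranks create parameter shells on the meta device (no storage allocated).
        with torch.device("meta"):
            model = AutoModelForCausalLM.from_config(model_config)

    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        use_orig_params=False,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # assumed for this sketch
        sync_module_states=True,  # assumed: broadcast rank-0 weights to the other ranks
        param_init_fn=(
            # On non-zero ranks, give each meta-device module empty CUDA storage;
            # FSDP then fills it from the broadcast rank-0 parameters.
            (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
            if global_rank != 0
            else None
        ),
    )
    return model
```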
