From 5bf0c285e74ce8e40bf1e5f4c6a4b1ec70b188c4 Mon Sep 17 00:00:00 2001
From: Pavel Belevich
Date: Fri, 10 May 2024 22:13:42 +0000
Subject: [PATCH] FSDP with meta device requires sync_module_states=True

---
 3.test_cases/10.FSDP/train.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/3.test_cases/10.FSDP/train.py b/3.test_cases/10.FSDP/train.py
index 002a08c2..3c724e76 100644
--- a/3.test_cases/10.FSDP/train.py
+++ b/3.test_cases/10.FSDP/train.py
@@ -161,6 +161,9 @@ def main(args):
         model = AutoModelForCausalLM.from_config(model_config)
     else:
         with torch.device("meta"):
+            # Instantiating the model on the `meta` device doesn't consume CPU memory,
+            # but it requires specifying `param_init_fn=...`
+            # and `sync_module_states=True` in the FSDP constructor.
             model = AutoModelForCausalLM.from_config(model_config)
 
     num_params = compute_num_params(model)
@@ -197,6 +200,7 @@ def main(args):
         device_id=torch.cuda.current_device(),
         use_orig_params=False,
         sharding_strategy=sharding_strategy,
+        sync_module_states=True,
         param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)) if global_rank != 0 else None,
     )
 
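
Note (not part of the patch): below is a minimal, self-contained sketch of the
meta-device + FSDP initialization pattern this patch enables. It is illustrative
only. The `global_rank == 0` branch condition, the `gpt2` config, and the
torchrun-style LOCAL_RANK environment variable are assumptions consistent with
the diff's context (the patch itself gates `param_init_fn` on `global_rank != 0`).

    # Illustrative sketch only: rank 0 materializes real weights and acts as
    # the broadcast source; all other ranks build the model on the `meta`
    # device, allocate empty GPU storage via `param_init_fn`, and receive
    # rank 0's weights through `sync_module_states=True`.
    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import ShardingStrategy
    from transformers import AutoConfig, AutoModelForCausalLM

    dist.init_process_group("nccl")  # assumes a torchrun launch (RANK/WORLD_SIZE set)
    global_rank = dist.get_rank()
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model_config = AutoConfig.from_pretrained("gpt2")  # hypothetical config for this sketch

    if global_rank == 0:
        # Rank 0 builds the model with real, initialized weights.
        model = AutoModelForCausalLM.from_config(model_config)
    else:
        # Other ranks build on `meta`: no CPU memory is consumed for weights,
        # but the parameters are placeholders that must be materialized later.
        with torch.device("meta"):
            model = AutoModelForCausalLM.from_config(model_config)

    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        use_orig_params=False,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # Broadcast parameters and buffers from rank 0 so every rank starts
        # from identical weights.
        sync_module_states=True,
        # Non-zero ranks replace meta tensors with uninitialized GPU storage;
        # rank 0 keeps its initialized weights (param_init_fn=None).
        param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
        if global_rank != 0
        else None,
    )

Without `sync_module_states=True`, ranks that materialized parameters with
`to_empty` would train on uninitialized memory, so the broadcast from rank 0 is
what makes the meta-device construction correct.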