Migrate distributed state dict API #2138

Open · wants to merge 20 commits into main
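At a glance, the hunks below migrate four torchtune training helpers to operate on the FSDP-wrapped model directly. A sketch of the before/after calls, inferred from this diff (placeholder names, not documented signatures):

```python
# Sketch of the API changes in this PR, inferred from the hunks below.
# Assumes `training` is torchtune.training and `model` is the sharded module;
# all other names are illustrative placeholders.

# 1) gather_cpu_state_dict takes the model itself (was: model.state_dict())
#    and, in the LoRA recipes, a new adapter_weights_only flag.
cpu_state_dict = training.gather_cpu_state_dict(
    model,
    is_rank_zero,
    device=device,
    adapter_weights_only=save_adapter_weights_only,
)

# 2) The optimizer helpers gain the model as a new first argument.
opt_sd = training.get_full_optimizer_state_dict(
    model, optimizer, is_rank_zero, device=device
)
training.load_from_full_optimizer_state_dict(model, optimizer, opt_sd, device)

# 3) load_from_full_model_state_dict drops its is_rank_zero argument.
training.load_from_full_model_state_dict(
    model, full_state_dict, device, strict=True, cpu_offload=cpu_offload
)
```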
6 changes: 3 additions & 3 deletions recipes/dev/early_exit_finetune_distributed.py
@@ -556,7 +556,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
- self._is_rank_zero,
Contributor:

Do we also need to update _setup_optimizer in this recipe?

Contributor Author:

_setup_optimizer does not call load_from_full_model_state_dict, so it was not updated as part of removing self._is_rank_zero.
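(For reference, recipes whose _setup_optimizer does restore optimizer state pick up the new leading model argument; a sketch mirroring the full_finetune_distributed.py hunk further down:)

```python
# Sketch only: mirrors the _setup_optimizer change in full_finetune_distributed.py
# below; shown here to illustrate the new first argument.
training.load_from_full_optimizer_state_dict(
    self._model,  # new in this PR
    optimizer,
    opt_state_dict,
    self._device,
)
```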

strict=True,
cpu_offload=fsdp_cpu_offload,
)
@@ -757,7 +756,7 @@ def save_checkpoint(
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
- self._model.state_dict(),
+ self._model,
self._is_rank_zero,
device=self._device,
)
@@ -773,6 +772,7 @@
log.info("Getting optimizer state dict...")
if not self._optimizer_in_bwd:
opt_state_dict = training.get_full_optimizer_state_dict(
+ self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
@@ -781,7 +781,7 @@
opt_state_dict = {}
for param, opt in self._optim_ckpt_wrapper.optim_map.items():
opt_state_dict[param] = training.get_full_optimizer_state_dict(
- opt, self._is_rank_zero, device=self._device
+ self._model, opt, self._is_rank_zero, device=self._device
)
if self._is_rank_zero:
log.info(
3 changes: 2 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -547,7 +547,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
- self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
@@ -602,6 +601,7 @@ def _setup_optimizer(
for param in opt_state_dict.keys():
try:
training.load_from_full_optimizer_state_dict(
+ self._model,
self._optim_ckpt_wrapper.state_dict()[param],
opt_state_dict[param],
self._device,
@@ -617,6 +617,7 @@
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
+ self._model,
optimizer,
opt_state_dict,
self._device,
7 changes: 3 additions & 4 deletions recipes/knowledge_distillation_distributed.py
@@ -461,7 +461,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
- self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -486,7 +485,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
- self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
for m in model.modules():
@@ -574,7 +572,6 @@ def _setup_teacher_model(
model,
model_state_dict,
self._device,
- self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
@@ -611,6 +608,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
+ self._model,
optimizer,
opt_state_dict,
self._device,
@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
- self._model.state_dict(),
+ self._model,
Contributor:

I just realized we are doing things differently here than in the other recipes. It seems to me we could move the call to get_adapter_state_dict up before calling gather_cpu_state_dict; then you could make the same changes you did in e.g. lora_finetune_distributed.py (remove the call to get_adapter_state_dict and instead just pass trainable_only=self._save_adapter_weights_only to gather_cpu_state_dict). Let me know if that makes sense to you.
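A minimal sketch of that suggestion, assuming the adapter_weights_only keyword used by gather_cpu_state_dict elsewhere in this PR (trainable_only in the comment may be an earlier name for it):

```python
# Hypothetical rewrite per the suggestion above: drop the manual
# get_adapter_state_dict call and let gather_cpu_state_dict filter
# adapter weights itself, as lora_finetune_distributed.py now does.
cpu_state_dict = training.gather_cpu_state_dict(
    self._model,
    self._is_rank_zero,
    device=self._device,
    adapter_weights_only=self._save_adapter_weights_only,
)
```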

self._is_rank_zero,
device=self._device,
)

if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
+ self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
11 changes: 4 additions & 7 deletions recipes/lora_dpo_distributed.py
@@ -385,7 +385,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
- self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -410,7 +409,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
- self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
is_dora = False
@@ -458,6 +456,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
+ self._model,
optimizer,
opt_state_dict,
self._device,
@@ -546,17 +545,15 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
- state_dict = self._model.state_dict()
- if self._save_adapter_weights_only:
- state_dict = get_adapter_state_dict(state_dict, device=None)
-
cpu_state_dict = training.gather_cpu_state_dict(
- state_dict,
+ self._model,
self._is_rank_zero,
device=self._device,
+ adapter_weights_only=self._save_adapter_weights_only,
)
if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
+ self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
11 changes: 4 additions & 7 deletions recipes/lora_finetune_distributed.py
@@ -480,7 +480,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
- self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -505,7 +504,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
- self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
for m in model.modules():
@@ -549,6 +547,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
+ self._model,
optimizer,
opt_state_dict,
self._device,
@@ -656,14 +655,11 @@ def save_checkpoint(

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
- state_dict = self._model.state_dict()
- if self._save_adapter_weights_only:
- state_dict = get_adapter_state_dict(state_dict, device=None)
-
cpu_state_dict = training.gather_cpu_state_dict(
- state_dict,
+ self._model,
self._is_rank_zero,
device=self._device,
+ adapter_weights_only=self._save_adapter_weights_only,
)
utils.log_rank_zero(
log,
@@ -673,6 +669,7 @@ def save_checkpoint(
if intermediate_checkpoint:
utils.log_rank_zero(log, "Retrieving optimizer state dict...")
opt_state_dict = training.get_full_optimizer_state_dict(
+ self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
2 changes: 0 additions & 2 deletions recipes/lora_finetune_distributed_multi_dataset.py
@@ -473,7 +473,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
- self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -500,7 +499,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
- self._is_rank_zero,
Contributor:

(Commenting here, though this applies further down in the file:) is there a reason you didn't also update save_checkpoint in this recipe? (We don't yet have a test for it, so it probably wasn't caught by CI.)

Contributor Author:

This PR only removes self._is_rank_zero from training.load_from_full_model_state_dict, which is not called in save_checkpoint. I will add tests for lora_finetune_distributed_multi_dataset.py and early_exit_finetune_distributed.py later.
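For reference, the before/after of the one call this PR does change in this recipe (a sketch):

```python
# before: rank-zero flag passed explicitly
training.load_from_full_model_state_dict(
    model, base_model_state_dict, self._device, self._is_rank_zero,
    cpu_offload=fsdp_cpu_offload,
)
# after: is_rank_zero is removed from the signature
training.load_from_full_model_state_dict(
    model, base_model_state_dict, self._device,
    cpu_offload=fsdp_cpu_offload,
)
```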

cpu_offload=fsdp_cpu_offload,
)
for m in model.modules():
8 changes: 5 additions & 3 deletions recipes/qat_distributed.py
@@ -508,7 +508,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
- self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
@@ -562,6 +561,7 @@ def _setup_optimizer(
for param in opt_state_dict.keys():
try:
training.load_from_full_optimizer_state_dict(
+ self._model,
self._optim_ckpt_wrapper.state_dict()[param],
opt_state_dict[param],
self._device,
@@ -577,6 +577,7 @@
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
+ self._model,
optimizer,
opt_state_dict,
self._device,
@@ -667,7 +668,7 @@ def save_checkpoint(
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
- self._model.state_dict(),
+ self._model,
self._is_rank_zero,
device=self._device,
)
@@ -682,6 +683,7 @@
utils.log_rank_zero(log, "Getting optimizer state dict...")
if not self._optimizer_in_bwd:
opt_state_dict = training.get_full_optimizer_state_dict(
+ self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
@@ -690,7 +692,7 @@
opt_state_dict = {}
for param, opt in self._optim_ckpt_wrapper.optim_map.items():
opt_state_dict[param] = training.get_full_optimizer_state_dict(
- opt, self._is_rank_zero, device=self._device
+ self._model, opt, self._is_rank_zero, device=self._device
)
utils.log_rank_zero(
log,
11 changes: 4 additions & 7 deletions recipes/qat_lora_finetune_distributed.py
@@ -525,7 +525,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
- self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -550,7 +549,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
- self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
validate_missing_and_unexpected_for_lora(
@@ -589,6 +587,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
+ self._model,
optimizer,
opt_state_dict,
self._device,
@@ -699,14 +698,11 @@ def save_checkpoint(

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
- state_dict = self._model.state_dict()
- if self._save_adapter_weights_only:
- state_dict = get_adapter_state_dict(state_dict, device=None)
-
cpu_state_dict = training.gather_cpu_state_dict(
- state_dict,
+ self._model,
self._is_rank_zero,
device=self._device,
+ adapter_weights_only=self._save_adapter_weights_only,
)
if self._is_rank_zero:
log.info(
@@ -717,6 +713,7 @@
if self._is_rank_zero:
log.info("Retrieving optimizer state dict...")
opt_state_dict = training.get_full_optimizer_state_dict(
+ self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
2 changes: 0 additions & 2 deletions tests/torchtune/modules/peft/test_dora.py
@@ -347,7 +347,6 @@ def _test_dora_distributed_init(self, load_dora_weights):
ffn,
adapter_state_dict,
device,
- is_rank_zero,
)
if is_rank_zero:
for dora_linear in [ffn.w1, ffn.w2, ffn.w3]:
@@ -377,7 +376,6 @@ def _test_dora_distributed_init(self, load_dora_weights):
ffn,
base_model_state_dict,
device,
- is_rank_zero,
)

# After this, everything should be off meta device
13 changes: 5 additions & 8 deletions tests/torchtune/training/test_distributed.py
@@ -169,10 +169,9 @@ def test_lora_state_dict(self):
fsdp_optim_to_save.zero_grad()
expected_model_sd = base_model.state_dict()
expected_optim_sd = base_optim.state_dict()
- model_full_sd = training.gather_cpu_state_dict(
- fsdp_model_to_save.state_dict(), is_rank_zero
- )
+ model_full_sd = training.gather_cpu_state_dict(fsdp_model_to_save, is_rank_zero)
optim_full_sd = training.get_full_optimizer_state_dict(
+ fsdp_model_to_save,
fsdp_optim_to_save,
is_rank_zero,
)
@@ -222,12 +221,12 @@ def test_lora_state_dict(self):
fsdp_model_to_load,
copy.deepcopy(base_model.state_dict()),
torch.device("cuda"),
- is_rank_zero,
)
fsdp_optim_to_load = torch.optim.Adam(
fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01
)
training.load_from_full_optimizer_state_dict(
+ fsdp_model_to_load,
fsdp_optim_to_load,
# mimic mmap=True where every rank see full SD
copy.deepcopy(self._broadcast_full_state_dict(optim_full_sd)),
@@ -324,9 +323,7 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
fsdp_model_to_save(inp)

expected_model_sd = {k: v.cpu() for k, v in base_model.state_dict().items()}
- model_full_sd = training.gather_cpu_state_dict(
- fsdp_model_to_save.state_dict(), is_rank_zero
- )
+ model_full_sd = training.gather_cpu_state_dict(fsdp_model_to_save, is_rank_zero)
if is_rank_zero:
self.assertEqual(set(model_full_sd.keys()), set(expected_model_sd.keys()))
for key, value in model_full_sd.items():
Expand Down Expand Up @@ -357,7 +354,7 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
fully_shard(m)
fully_shard(fsdp_model_to_load)
training.load_from_full_model_state_dict(
fsdp_model_to_load, expected_model_sd, torch.device("cuda"), is_rank_zero
fsdp_model_to_load, expected_model_sd, torch.device("cuda")
)
fsdp_model_to_load(inp)
sharded_model_sd = fsdp_model_to_load.state_dict()