Skip to content

Commit

Permalink
Adjust DCP API usage
Browse files Browse the repository at this point in the history
  • Loading branch information
mori360 committed Dec 13, 2024
1 parent b0590d4 commit d913398
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torchtune/training/checkpointing/_checkpoint_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _save_checkpoint_sync(
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
model_state_dict = training.gather_cpu_state_dict(
- model.state_dict(),
+ model,
self._is_rank_zero,
device=self._device,
)
Expand All @@ -208,6 +208,7 @@ def _save_checkpoint_sync(
if no_dist:
if not self._optimizer_in_bwd:
optim_state_dict = training.get_full_optimizer_state_dict(
+ model,
optimizer,
self._is_rank_zero,
device=self._device,
Expand All @@ -217,7 +218,7 @@ def _save_checkpoint_sync(
optim_state_dict[
param
] = training.get_full_optimizer_state_dict(
- opt, self._is_rank_zero, device=self._device
+ model, opt, self._is_rank_zero, device=self._device
)
else:
optim_state_dict = optimizer.state_dict()
Expand Down

0 comments on commit d913398

Please sign in to comment.