Migrate distributed state dict API #2138
base: main
@@ -461,7 +461,6 @@ def _setup_model(
             model,
             lora_weights_state_dict,
             self._device,
-            self._is_rank_zero,
             cpu_offload=fsdp_cpu_offload,
         )
     else:
@@ -486,7 +485,6 @@ def _setup_model(
             model,
             base_model_state_dict,
             self._device,
-            self._is_rank_zero,
             cpu_offload=fsdp_cpu_offload,
         )
     for m in model.modules():
@@ -574,7 +572,6 @@ def _setup_teacher_model(
             model,
             model_state_dict,
             self._device,
-            self._is_rank_zero,
             strict=True,
             cpu_offload=fsdp_cpu_offload,
         )
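For reference, a minimal sketch of a migrated call site, assuming (as the removed lines above suggest) that `training.load_from_full_model_state_dict` no longer takes an `is_rank_zero` argument. The wrapper function and variable names below are hypothetical stand-ins for the recipe's own state, not the recipe code itself.

```python
from torchtune import training


def load_full_weights(model, full_state_dict, device, fsdp_cpu_offload: bool):
    """Hypothetical helper mirroring the recipe's _setup_model call site."""
    # The recipe no longer threads self._is_rank_zero through (removed above);
    # the remaining arguments match the diff, other defaults are assumptions.
    training.load_from_full_model_state_dict(
        model,            # FSDP-sharded module on this rank
        full_state_dict,  # full (unsharded) weights read from the checkpoint
        device,           # e.g. self._device in the recipe
        cpu_offload=fsdp_cpu_offload,
    )
```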
@@ -611,6 +608,7 @@ def _setup_optimizer(
     optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
     if opt_state_dict:
         training.load_from_full_optimizer_state_dict(
+            self._model,
             optimizer,
             opt_state_dict,
             self._device,
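A similar sketch for the optimizer path: per the added line above, `training.load_from_full_optimizer_state_dict` now receives the model ahead of the optimizer. The helper name and parameter list are illustrative only.

```python
from torchtune import training


def restore_optimizer_state(model, optimizer, opt_state_dict, device):
    """Hypothetical helper mirroring the recipe's _setup_optimizer call site."""
    if opt_state_dict:
        # The model is passed first (new in this PR); the optimizer, the full
        # optimizer state dict, and the device follow as before.
        training.load_from_full_optimizer_state_dict(
            model,
            optimizer,
            opt_state_dict,
            device,
        )
    return optimizer
```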
@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
     # To prevent GPU memory from spiking during checkpoint save,
     # we consolidate the full model and optim state dicts on CPU for rank 0
     cpu_state_dict = training.gather_cpu_state_dict(
-        self._model.state_dict(),
+        self._model,
         self._is_rank_zero,
         device=self._device,
     )

     if intermediate_checkpoint:
         opt_state_dict = training.get_full_optimizer_state_dict(
+            self._model,
             self._optimizer,
             self._is_rank_zero,
             device=self._device,

Review comment (on the `gather_cpu_state_dict` change): I just realized we are doing things differently here than in the other recipes... seems to me like we could move the call to …
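On the save side, the hunk above suggests the gather helpers now take the module itself: `gather_cpu_state_dict` receives the model rather than `model.state_dict()`, and `get_full_optimizer_state_dict` gains the model as a new first argument. A hedged sketch of the resulting checkpoint-save path, with hypothetical names:

```python
from torchtune import training


def build_checkpoint_state(model, optimizer, is_rank_zero, device, intermediate_checkpoint):
    """Hypothetical helper mirroring save_checkpoint in the recipe."""
    # Consolidate the full model state dict on CPU (rank 0 keeps the copy)
    # to avoid a GPU memory spike; the model is passed directly now,
    # not model.state_dict().
    cpu_state_dict = training.gather_cpu_state_dict(
        model,
        is_rank_zero,
        device=device,
    )

    opt_state_dict = None
    if intermediate_checkpoint:
        # The optimizer gather likewise takes the model as its first argument.
        opt_state_dict = training.get_full_optimizer_state_dict(
            model,
            optimizer,
            is_rank_zero,
            device=device,
        )
    return cpu_state_dict, opt_state_dict
```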
@@ -473,7 +473,6 @@ def _setup_model(
             model,
             lora_weights_state_dict,
             self._device,
-            self._is_rank_zero,
             cpu_offload=fsdp_cpu_offload,
         )
     else:
@@ -500,7 +499,6 @@ def _setup_model(
             model,
             base_model_state_dict,
             self._device,
-            self._is_rank_zero,
             cpu_offload=fsdp_cpu_offload,
         )
     for m in model.modules():

Review comment: (Commenting here for further down in the file but) is there a reason you didn't also update …
Reply: The PR here only removes …
Review comment: Do we also need to update `_setup_optimizer` in this recipe?
Reply: `_setup_optimizer` does not call `load_from_full_model_state_dict`, so it was not updated with the removal of `self._is_rank_zero`.