Migrate distributed state dict API #2138
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2138
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3d0d26f with merge base 002b17c. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2138       +/-   ##
===========================================
+ Coverage    9.33%   65.26%    +55.93%
===========================================
  Files         289      334        +45
  Lines       16959    19192      +2233
===========================================
+ Hits         1583    12526     +10943
+ Misses      15376     6666      -8710
```

☔ View full report in Codecov by Sentry.
…ept 2 device type and optimize memory (#142845)

For the distributed state dict API [migration](pytorch/torchtune#2138), this makes the following changes:
1. `load_from_full_model_state_dict` in TorchTune calls `set_model_state_dict` with options controlling cpu_offload. Add cpu_offload handling to `_load_model_state_dict` to process state on CPU when the config is True.
2. Change the device check: lora_finetune might have 2 device types, so accept that as valid.
3. Some changes to optimize memory performance:
   3.1 Use `.detach().clone()` instead of a view directly.
   3.2 If local_state is not meta, copy `full_tensor[slices]` into `ret.to_local()`.
4. Add related unit tests.

Memory performance when called from TorchTune with llama2/7B_full:
1. cpu_offload = True (screenshot: https://github.com/user-attachments/assets/429261f5-1107-4592-b295-de3944a2614b)
2. cpu_offload = False (screenshot: https://github.com/user-attachments/assets/40bf281a-236a-4218-826b-b1192a10c806)

Pull Request resolved: #142845
Approved by: https://github.com/fegin
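For context, a minimal sketch of what passing cpu_offload through this path can look like. The wrapper function below is illustrative only, not TorchTune's actual implementation; `set_model_state_dict` and `StateDictOptions` are the `torch.distributed.checkpoint.state_dict` APIs.

```python
from typing import Any, Dict

import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)


def load_full_state_dict_sketch(
    model: nn.Module,
    full_sd: Dict[str, Any],
    cpu_offload: bool = False,
    strict: bool = False,
):
    """Illustrative sketch: load a full (unsharded) state dict into a sharded
    model through the distributed state dict API."""
    options = StateDictOptions(
        full_state_dict=True,     # the incoming dict is a full, unsharded state dict
        cpu_offload=cpu_offload,  # per the linked PyTorch PR, honored on the load path to save GPU memory
        strict=strict,
    )
    return set_model_state_dict(model, full_sd, options=options)
```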
```python
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
```
How can we get view support for NF4?
cc @andrewor14
Thank you for the review. We currently skip the NF4 tensor part and plan to support NF4 in the next quarter.
Looks like there's already view support for NF4Tensor? What's the error you're getting?
also cc @drisspg @weifengpy
I brought this up with @ebsmothers and @gau-nernst in Discord. We don't think we need to do anything else here; it should be safe to just switch to `from_local`.
Thank you for the comments. Shall I switch to `from_local` in this PR, or handle it together with the other NF4 tensor support in the next PR?
Yes, if possible it'd be great to move to `from_local` here, assuming everything works. IMO the more we can clean this function up the better; as is, it has gotten quite unwieldy.
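For reference, a minimal sketch of what the `from_local` version of the sharding step could look like. This is illustrative only; it mirrors the manual chunking quoted above and would still need NF4 view support for the quantized path.

```python
import torch
from torch.distributed._tensor import DTensor


def shard_full_tensor_sketch(full_tensor: torch.Tensor, sharded_meta_param: DTensor) -> DTensor:
    """Illustrative sketch: wrap this rank's chunk of `full_tensor` as a DTensor
    with the same mesh/placements as an existing (meta) sharded parameter."""
    mesh = sharded_meta_param.device_mesh
    world_size = mesh.size(0)
    rank = mesh.get_group(0).rank()
    chunk = torch.chunk(full_tensor, world_size, dim=0)[rank]
    sharded_param = full_tensor.new_zeros(chunk.size())
    sharded_param[: chunk.size(0)].copy_(chunk)
    # from_local replaces the manual DTensor(local_tensor=..., spec=DTensorSpec(...)) construction
    return DTensor.from_local(
        sharded_param,
        device_mesh=mesh,
        placements=sharded_meta_param.placements,
        run_check=False,
        shape=sharded_meta_param.size(),
        stride=sharded_meta_param.stride(),
    )
```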
Thanks for your patience! Left a bunch of comments, please let me know if anything is unclear. One request is to also manually run the `lora_finetune_distributed_multi_dataset.py` and `early_exit_finetune_distributed.py` recipes, as they do not currently have tests in our CI. Happy to provide any pointers here if you need.
```diff
@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
         cpu_state_dict = training.gather_cpu_state_dict(
-            self._model.state_dict(),
+            self._model,
```
I just realized we are doing things differently here than in the other recipes... It seems to me like we could move the call to `get_adapter_state_dict` up before calling `gather_cpu_state_dict`; then you could make the same changes you did in e.g. `lora_finetune_distributed.py` (remove the call to `get_adapter_state_dict` and instead just pass `trainable_only=self._save_adapter_weights_only` to `gather_cpu_state_dict`). Lmk if that makes sense to you.
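A toy, single-process illustration of that flow (not the real `gather_cpu_state_dict`, which gathers DTensor shards across ranks; names here are only for illustration): filter during the gather rather than filtering the result afterwards.

```python
from typing import Any, Dict

import torch.nn as nn


def gather_cpu_state_dict_toy(model: nn.Module, trainable_only: bool = False) -> Dict[str, Any]:
    """Toy sketch: optionally keep only trainable (adapter) params while gathering."""
    out: Dict[str, Any] = {}
    for name, param in model.named_parameters():
        if trainable_only and not param.requires_grad:
            continue  # skip frozen base weights when only adapter weights are saved
        out[name] = param.detach().cpu()
    return out


# Roughly equivalent to passing trainable_only=self._save_adapter_weights_only
# instead of calling a separate get_adapter_state_dict on the gathered result.
model = nn.Linear(4, 4)
adapter_only_sd = gather_cpu_state_dict_toy(model, trainable_only=True)
```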
```python
    self._is_rank_zero,
    device=self._device,
    trainable_only=self._save_adapter_weights_only,
```
Also, one other thing we will have to be aware of: in general it may not always be the case that trainable params == adapter params. This holds true today, but especially for multimodal models we need to be careful, because some people may want to e.g. do LoRA finetuning on the image encoder and full finetuning on the text decoder. This was disabled in #2150, but we may want to add it back later, and in that case this would be misleading. So I think `trainable_only` is potentially a misnomer, and it may be best to rename it to `adapter_weights_only` or something like that.
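A sketch of the suggested rename; this signature is inferred from the call sites in this PR and is not the actual torchtune definition.

```python
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


def gather_cpu_state_dict(
    model: nn.Module,
    is_rank_zero: bool,
    device: Optional[torch.device] = None,
    adapter_weights_only: bool = False,  # renamed from trainable_only
) -> Dict[str, Any]:
    """Signature sketch only; the rename avoids implying trainable == adapter."""
    raise NotImplementedError
```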
```diff
@@ -500,7 +499,6 @@ def _setup_model(
             model,
             base_model_state_dict,
             self._device,
-            self._is_rank_zero,
```
(Commenting here for further down in the file, but) is there a reason you didn't also update `save_checkpoint` in this recipe? (We don't yet have a test for it, so it probably didn't get caught by CI.)
This PR only removes `self._is_rank_zero` from `training.load_from_full_model_state_dict`, which is not called in `save_checkpoint`. I will add tests for `lora_finetune_distributed_multi_dataset.py` and `early_exit_finetune_distributed.py` later.
```diff
@@ -556,7 +556,6 @@ def _setup_model(
             model,
             model_state_dict,
             self._device,
-            self._is_rank_zero,
```
Do we also need to update `_setup_optimizer` in this recipe?
`_setup_optimizer` does not call `load_from_full_model_state_dict`, so I did not update it with the removal of `self._is_rank_zero`.
```python
    Args:
        sharded_sd (Dict[str, DTensor]): Sharded state dict of DTensors
        model (FSDPModule): Model to generate fqn for cpu_state_dict
```
Nit, but I don't think most people know what "fqn" means; might write something more descriptive here.
```python
) -> Dict[str, Any]:
    """
    Converting sharded state dict into a full state dict on CPU
    Returning non-empty result only on rank0 to avoid peaking CPU memory
    TODO: add support for NF4Tensor
```
Can we add more details here so that it's clear to someone who's reading the code? Something like "If the model does not contain any NF4 tensors, we directly use distributed state dict APIs. Otherwise, we need to manually gather any NF4 tensors until all-gather is supported in the NF4Tensor subclass"
```python
for param_name, full_tensor in full_sd.items():
    sharded_meta_param = meta_sharded_sd.get(param_name)
    full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
    if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
        sharded_meta_param._local_tensor, NF4Tensor
    ):
        block_size = sharded_meta_param._local_tensor.block_size
        scaler_block_size = (
            sharded_meta_param._local_tensor.scaler_block_size
        )
        full_tensor = to_nf4(
            full_tensor,
            block_size=block_size,
            scaler_block_size=scaler_block_size,
        )
        # replicating logic from `_fsdp_param.py` `_init_sharded_param`
        # otherwise `distribute_tensor(DTensor(local=NF4))`
        # requires dispatching `c10d.scatter_`
        # long-term solution is `swap_tensor`
        mesh = sharded_meta_param.device_mesh
        if mesh.ndim > 1:
            raise NotImplementedError(
                f"only support 1D FSDP but got {mesh.ndim=}"
            )
        shard_mesh_dim = 0
        shard_world_size = mesh.size(shard_mesh_dim)
        shard_rank = cast(
            torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
        ).rank()
        chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[
            shard_rank
        ]
        sharded_param = full_tensor.new_zeros(chunk.size())
        sharded_param[: chunk.size(0)].copy_(chunk)

        # TODO: change to from_local API (need to add view support for NF4)
        sharded_tensor = DTensor(
            local_tensor=sharded_param,
            spec=DTensorSpec(
                mesh=sharded_meta_param.device_mesh,
                placements=sharded_meta_param.placements,
                tensor_meta=TensorMeta(
                    shape=sharded_meta_param.size(),
                    dtype=sharded_meta_param.dtype,
                    stride=sharded_meta_param.stride(),
                ),
            ),
            requires_grad=sharded_meta_param.requires_grad,
        )

    elif not hasattr(sharded_meta_param, "device_mesh"):
        # In cases where parts of the model aren't sharded, some parameters will be plain tensors
        sharded_tensor = full_tensor
    else:
        sharded_tensor = distribute_tensor(
            full_tensor,
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
    if cpu_offload:
        sharded_tensor = sharded_tensor.cpu()
    sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
```
Is this a duplicate of L274-L335? I think this function is already complicated enough; if we can just use a single if/else branch to consolidate these, that'd be preferable.
Switch the structure here to `if _USE_DISTRIBUTED_STATE_DICT_API and not has_nf4` and `else`.
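A minimal sketch of that consolidated shape. The `_USE_DISTRIBUTED_STATE_DICT_API` flag name comes from the comment above; the NF4 detection helper is an assumption.

```python
import torch.nn as nn
from torchao.dtypes.nf4tensor import NF4Tensor


def model_contains_nf4(model: nn.Module) -> bool:
    """Assumed helper: detect NF4-backed parameters, including DTensor-wrapped ones."""
    for param in model.parameters():
        local = getattr(param, "_local_tensor", param)
        if isinstance(local, NF4Tensor):
            return True
    return False


# Sketch of the consolidated control flow:
#
#     if _USE_DISTRIBUTED_STATE_DICT_API and not model_contains_nf4(model):
#         # single call into torch.distributed.checkpoint.state_dict
#         set_model_state_dict(model, full_sd, options=options)
#     else:
#         # manual per-parameter shard/convert path shown earlier (handles NF4)
#         ...
```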
```python
# There are some changes at `set_model_state_dict` to adjust multiple devices from local_state in TorchTune,
# keep version check until PyTorch changes are on stable.
```
I don't understand this comment
Context
What is the purpose of this PR?
Migrate distributed state dict APIs from torch.distributed.

Changelog
What are the changes made in this PR?
Switch to the distributed state dict APIs from torch.distributed; a sketch of the optimizer-side mapping is shown after this list.
- `load_from_full_model_state_dict` <- `set_model_state_dict`
- `gather_cpu_state_dict` <- `get_model_state_dict`
- `load_from_full_optimizer_state_dict` <- `set_optimizer_state_dict`
- `get_full_optimizer_state_dict` <- `get_optimizer_state_dict`

To align the inputs, add a model input to `get_full_optimizer_state_dict` and `load_from_full_optimizer_state_dict`, and change the `sharded_sd` input of `gather_cpu_state_dict` to `model`.

TODO: NF4 tensors are kept the same; they remain future work.
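As a rough illustration of the optimizer-side mapping: a sketch, not torchtune's actual wrappers. `get_optimizer_state_dict`, `set_optimizer_state_dict`, and `StateDictOptions` are the `torch.distributed.checkpoint.state_dict` APIs, and the model argument is what the new torchtune signatures add.

```python
from typing import Any, Dict

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)


def get_full_optimizer_state_dict_sketch(
    model: nn.Module, opt: torch.optim.Optimizer
) -> Dict[str, Any]:
    # The model argument lets the API key optimizer state by parameter FQN,
    # which is why the torchtune wrappers gained a model input in this PR.
    return get_optimizer_state_dict(
        model,
        opt,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True),
    )


def load_from_full_optimizer_state_dict_sketch(
    model: nn.Module, opt: torch.optim.Optimizer, full_osd: Dict[str, Any]
) -> None:
    set_optimizer_state_dict(
        model,
        opt,
        optim_state_dict=full_osd,
        options=StateDictOptions(full_state_dict=True),
    )
```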
Test plan
- `pytest tests/torchtune/training/test_distributed.py`
- `pytest tests -m integration_test`

We compared runs with the previous API and the new API; the losses are the same for both initial loading and resuming from a checkpoint. We also collected memory traces, and the results show that the new API does not increase peak memory compared with the current one.