
Migrate distributed state dict API #2138

Draft · wants to merge 19 commits into main

Conversation

@mori360 mori360 commented Dec 10, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Migrate distributed state dict APIs from torch.distributed.

Changelog

What are the changes made in this PR?

Switch to distributed state dict APIs from torch.distributed.

  • load_from_full_model_state_dict <- set_model_state_dict
  • gather_cpu_state_dict <- get_model_state_dict
  • load_from_full_optimizer_state_dict <- set_optimizer_state_dict
  • get_full_optimizer_state_dict <- get_optimizer_state_dict

To align the call signatures, a `model` argument is added to `get_full_optimizer_state_dict` and `load_from_full_optimizer_state_dict`, and the `sharded_sd` input of `gather_cpu_state_dict` is replaced by `model` (see the sketch below).
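For reference, a minimal sketch of how these wrappers could map onto the `torch.distributed.checkpoint.state_dict` APIs. The function and option names are PyTorch's; the bodies are an assumption about the wiring, not torchtune's actual implementation.

```python
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)


def load_from_full_model_state_dict_sketch(model: nn.Module, full_sd: dict, cpu_offload: bool = False):
    # Rank 0 holds the full state dict; other ranks receive their shards via broadcast.
    opts = StateDictOptions(
        full_state_dict=True, broadcast_from_rank0=True, cpu_offload=cpu_offload
    )
    return set_model_state_dict(model, model_state_dict=full_sd, options=opts)


def gather_cpu_state_dict_sketch(model: nn.Module) -> dict:
    # All-gather the sharded parameters into a full state dict offloaded to CPU.
    opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
    return get_model_state_dict(model, options=opts)


def get_full_optimizer_state_dict_sketch(model: nn.Module, optim) -> dict:
    # The model argument added in this PR is needed because the distributed API
    # maps optimizer state to fully qualified parameter names via the model.
    opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
    return get_optimizer_state_dict(model, optim, options=opts)


def load_from_full_optimizer_state_dict_sketch(model: nn.Module, optim, full_osd: dict):
    opts = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    return set_optimizer_state_dict(model, optim, optim_state_dict=full_osd, options=opts)
```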

TODO:
NF4Tensor handling is kept unchanged; full NF4 support remains future work.

Test plan

pytest tests/torchtune/training/test_distributed.py
pytest tests -m integration_test

We compared runs with the previous API and the new API; the losses match for both initial loading and resuming from a checkpoint.

We also captured memory traces; the results show that the new API does not increase peak memory compared with the current one.
[Screenshot: memory trace comparison, 2025-01-02]


pytorch-bot bot commented Dec 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2138

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3d0d26f with merge base 002b17c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 10, 2024
@joecummings joecummings added the distributed Anything related to distributed env (multi-GPU, multi-node) label Dec 10, 2024
@codecov-commenter

Codecov Report

Attention: Patch coverage is 3.38983% with 57 lines in your changes missing coverage. Please review.

Project coverage is 65.26%. Comparing base (f2bd4bc) to head (8b575be).
Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/training/_distributed.py 3.50% 55 Missing ⚠️
tests/torchtune/training/test_distributed.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            main    #2138       +/-   ##
==========================================
+ Coverage   9.33%   65.26%   +55.93%     
==========================================
  Files        289      334       +45     
  Lines      16959    19192     +2233     
==========================================
+ Hits        1583    12526    +10943     
+ Misses     15376     6666     -8710     

☔ View full report in Codecov by Sentry.

@mori360 mori360 changed the title Mitigate distributed state dict API Migrate distributed state dict API Dec 18, 2024
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Dec 19, 2024
…ept 2 device type and optimize memory (#142845)

For the distributed state dict API [migration](pytorch/torchtune#2138), make the following changes:
1. `load_from_full_model_state_dict` in TorchTune calls `set_model_state_dict` with an option controlling cpu_offload. Add cpu_offload handling to `_load_model_state_dict` so tensors are moved to CPU when the option is True.
2. Change the device check: lora_finetune may have 2 device types, so accept that as valid.
3. Some changes to optimize memory performance (illustrated in the sketch after this commit summary):
3.1 use `.detach().clone()` instead of a view directly
3.2 if local_state is not meta, copy `full_tensor[slices]` into `ret.to_local()`
4. add related unit tests

Memory performance when called from TorchTune with llama2/7B_full:
1. cpu_offload = True
[Screenshot: memory trace, cpu_offload = True]

2. cpu_offload = False
[Screenshot: memory trace, cpu_offload = False]

Pull Request resolved: #142845
Approved by: https://github.com/fegin
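Point 3.1 in the commit message above is about tensor ownership rather than any new API; here is a small, self-contained illustration of the difference (tensor names are illustrative, not taken from the patch):

```python
import torch

full_tensor = torch.randn(8, 4)      # stand-in for an all-gathered full parameter
shard_world_size, shard_rank = 2, 0  # hypothetical 2-way sharding

# A view shares storage with `full_tensor`, so holding onto the shard keeps the
# entire full tensor alive and the memory saving from sharding is lost.
chunk_view = torch.chunk(full_tensor, shard_world_size, dim=0)[shard_rank]

# `.detach().clone()` materializes only the shard, so `full_tensor` can be freed
# as soon as it is no longer referenced.
chunk_owned = chunk_view.detach().clone()
```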
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
Contributor

How can we get view support for NF4?

cc @andrewor14

Contributor Author

Thank you for the review. We currently skip the NF4 tensor part and plan to support NF4 in the next quarter.

Contributor

Looks like there's already view support for NF4Tensor? What's the error you're getting?

also cc @drisspg @weifengpy

Collaborator

I brought this up with @ebsmothers and @gau-nernst in Discord. We thought that we didn't need to do anything else here; it should be safe to just switch to from_local.

Contributor Author

Thank you for the comments. Shall I switch to from_local in this PR, or handle it together with the other NF4 tensor support in the next PR?

Contributor

Yes, if possible it'd be great to move to from_local here, assuming everything works. IMO the more we can clean this function up the better; as is, it has gotten quite unwieldy.
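For context, a rough sketch of what switching to `from_local` would look like, reusing the names from the snippet above (this assumes `NF4Tensor` supports the view ops `from_local` performs internally, which is exactly the open question in this thread):

```python
from torch.distributed.tensor import DTensor

# Hypothetical replacement for the manual DTensor(...) construction:
sharded_tensor = DTensor.from_local(
    sharded_param,
    device_mesh=sharded_meta_param.device_mesh,
    placements=sharded_meta_param.placements,
    run_check=False,
)
```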

@mori360 mori360 marked this pull request as ready for review December 20, 2024 23:57
@mori360 mori360 requested a review from joecummings December 20, 2024 23:58
Contributor

@ebsmothers ebsmothers left a comment

Thanks for your patience! Left a bunch of comments, please let me know if anything is unclear. One request is to also manually run lora_finetune_distributed_multi_dataset.py and early_exit_finetune_distributed.py recipes as they do not currently have tests in our CI. Happy to provide any pointers here if you need.

@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._model,
Contributor

I just realized we are doing things differently here than in the other recipes. It seems to me like we could move the call to get_adapter_state_dict up before calling gather_cpu_state_dict; then you could make the same changes you did in e.g. lora_finetune_distributed.py (remove the call to get_adapter_state_dict and instead just pass trainable_only=self._save_adapter_weights_only to gather_cpu_state_dict). Lmk if that makes sense to you

self._is_rank_zero,
device=self._device,
trainable_only=self._save_adapter_weights_only,
Contributor

Also, one other thing we will have to be aware of is that in general it may not always be the case that trainable params == adapter params. This holds true today, but especially for multimodal models we need to be careful, because some people may want to e.g. do LoRA finetuning on the image encoder and full finetuning on the text decoder. This was disabled in #2150, but we may want to add it back later, and in that case this would be misleading. So I think trainable_only is potentially a misnomer and it may be best to rename it to adapter_weights_only or something like that.
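A tiny illustration of the distinction (a hypothetical module and a naive name-based adapter check, just to show why the two sets can diverge):

```python
import torch.nn as nn

# Hypothetical mixed setup: full finetuning on one submodule, LoRA on another.
model = nn.ModuleDict(
    {
        "decoder": nn.Linear(16, 16),        # fully finetuned: trainable, not an adapter
        "encoder_lora_a": nn.Linear(16, 4),  # LoRA adapter: trainable and an adapter
    }
)

trainable = {n for n, p in model.named_parameters() if p.requires_grad}
adapter = {n for n in trainable if "lora" in n}

# `trainable` is a strict superset of `adapter`, so "trainable only" is not the
# same thing as "adapter weights only".
assert adapter < trainable
```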

@@ -500,7 +499,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
Contributor

(Commenting here for further down in the file but) is there a reason you didn't also update save_checkpoint in this recipe? (We don't yet have a test for it so probably didn't get caught by CI)

Contributor Author

This PR only removes self._is_rank_zero from training.load_from_full_model_state_dict, which is not called in save_checkpoint. We will add tests for lora_finetune_distributed_multi_dataset.py and early_exit_finetune_distributed.py later.

@@ -556,7 +556,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
self._is_rank_zero,
Contributor

Do we also need to update _setup_optimizer in this recipe?

Contributor Author

_setup_optimizer does not call load_from_full_model_state_dict, so it was not updated as part of removing self._is_rank_zero.

Args:
sharded_sd (Dict[str, DTensor]): Sharded state dict of DTensors
model (FSDPModule): Model to generate fqn for cpu_state_dict
Contributor

nit but I don't think most people know what "fqn" means, might write something more descriptive here

) -> Dict[str, Any]:
"""
Converting sharded state dict into a full state dict on CPU
Returning non-empty result only on rank0 to avoid peaking CPU memory
TODO: add support for NF4Tensor
Contributor

Can we add more details here so that it's clear to someone who's reading the code? Something like "If the model does not contain any NF4 tensors, we directly use distributed state dict APIs. Otherwise, we need to manually gather any NF4 tensors until all-gather is supported in the NF4Tensor subclass"

torchtune/training/_distributed.py (resolved review thread)
Comment on lines +194 to +257
for param_name, full_tensor in full_sd.items():
    sharded_meta_param = meta_sharded_sd.get(param_name)
    full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
    if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
        sharded_meta_param._local_tensor, NF4Tensor
    ):
        block_size = sharded_meta_param._local_tensor.block_size
        scaler_block_size = (
            sharded_meta_param._local_tensor.scaler_block_size
        )
        full_tensor = to_nf4(
            full_tensor,
            block_size=block_size,
            scaler_block_size=scaler_block_size,
        )
        # replicating logic from `_fsdp_param.py` `_init_sharded_param`
        # otherwise `distribute_tensor(DTensor(local=NF4))`
        # requires dispatching `c10d.scatter_`
        # long-term solution is `swap_tensor`
        mesh = sharded_meta_param.device_mesh
        if mesh.ndim > 1:
            raise NotImplementedError(
                f"only support 1D FSDP but got {mesh.ndim=}"
            )
        shard_mesh_dim = 0
        shard_world_size = mesh.size(shard_mesh_dim)
        shard_rank = cast(
            torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
        ).rank()
        chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[
            shard_rank
        ]
        sharded_param = full_tensor.new_zeros(chunk.size())
        sharded_param[: chunk.size(0)].copy_(chunk)

        # TODO: change to from_local API (need to add view support for NF4)
        sharded_tensor = DTensor(
            local_tensor=sharded_param,
            spec=DTensorSpec(
                mesh=sharded_meta_param.device_mesh,
                placements=sharded_meta_param.placements,
                tensor_meta=TensorMeta(
                    shape=sharded_meta_param.size(),
                    dtype=sharded_meta_param.dtype,
                    stride=sharded_meta_param.stride(),
                ),
            ),
            requires_grad=sharded_meta_param.requires_grad,
        )
    elif not hasattr(sharded_meta_param, "device_mesh"):
        # In cases where parts of the model aren't sharded, some parameters will be plain tensors
        sharded_tensor = full_tensor
    else:
        sharded_tensor = distribute_tensor(
            full_tensor,
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
    if cpu_offload:
        sharded_tensor = sharded_tensor.cpu()
    sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
Contributor

Is this a duplicate of L274-L335? I think this function is already complicated enough, if we can just use a single if/else branch to consolidate these that'd be preferable

Contributor Author

Switched the structure here to `if _USE_DISTRIBUTED_STATE_DICT_API and not has_nf4` / `else`.
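Roughly, the restructured control flow could look like the sketch below. `_USE_DISTRIBUTED_STATE_DICT_API` and `has_nf4` are the names mentioned in this thread; the fallback helper is a placeholder for the existing manual NF4 path shown above, so treat this as an assumption about the shape of the code, not the final implementation:

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
from torchao.dtypes.nf4tensor import NF4Tensor

_USE_DISTRIBUTED_STATE_DICT_API = True  # version gate; how it is computed is assumed


def _manual_nf4_load(model, full_sd, device, cpu_offload, strict):
    # placeholder for the existing manual NF4 sharding path shown earlier in this thread
    raise NotImplementedError


def load_from_full_model_state_dict_sketch(model, full_sd, device, cpu_offload=False, strict=False):
    # NF4-quantized parameters cannot go through the distributed state dict API yet,
    # so detect them and fall back to the manual shard-and-load path in that case.
    has_nf4 = any(
        isinstance(getattr(p, "_local_tensor", None), NF4Tensor)
        for p in model.parameters()
    )
    if _USE_DISTRIBUTED_STATE_DICT_API and not has_nf4:
        options = StateDictOptions(
            full_state_dict=True, broadcast_from_rank0=True, cpu_offload=cpu_offload
        )
        return set_model_state_dict(model, model_state_dict=full_sd, options=options)
    return _manual_nf4_load(model, full_sd, device, cpu_offload, strict)
```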

Comment on lines +182 to +183
# There are some changes at `set_model_state_dict` to adjust multiple devices from local_state in TorchTune,
# keep version check until PyTorch changes are on stable.
Contributor

I don't understand this comment

@mori360 mori360 marked this pull request as draft January 6, 2025 18:02