Proper prefix handling in EarlyFusion sd hooks #2291

Merged
merged 1 commit into from
Jan 22, 2025

ebsmothers
Contributor

Thanks to @mkrainin for pointing this out and making the initial fix.

Our EarlyFusion state dict hooks don't work when an EarlyFusion module is wrapped inside a larger model (as is the case with e.g. DDP). This is because the state dict passed to the hooks uses fully qualified key names, so we can't just use e.g. del state_dict["decoder.tok_embeddings.weight"], since decoder may not be a top-level module.

The fix is to use the prefix arg to state dict hooks, which gets recursively prepended while traversing the module tree (ref).
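
To make the failure mode concrete, here is a minimal torch-free sketch of how a prefix accumulates while a state dict is collected recursively. `ToyModule`, `move_embeddings_hook`, and `make_fusion` are hypothetical stand-ins written for illustration, not PyTorch's or torchtune's actual implementation; only the idea of passing `prefix` down the tree and into the hook mirrors the behavior described above.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical stand-in for nn.Module; it only models prefix handling."""

    def __init__(self, params=None, children=None, hook=None):
        self.params = params or {}
        self.children = children or {}
        self.hook = hook

    def state_dict(self, destination=None, prefix=""):
        if destination is None:
            destination = OrderedDict()
        for name, value in self.params.items():
            destination[prefix + name] = value
        for name, child in self.children.items():
            # The child's prefix is the parent's prefix plus "<name>."
            child.state_dict(destination, prefix + name + ".")
        if self.hook is not None:
            # The hook sees the shared dict (global keys) and this module's prefix.
            self.hook(self, destination, prefix)
        return destination


def move_embeddings_hook(module, state_dict, prefix):
    # Prefix-aware rename, so it works at any depth in the module tree.
    old_key = f"{prefix}tok_embeddings.weight"
    state_dict[f"{prefix}decoder.tok_embeddings.weight"] = state_dict.pop(old_key)


def make_fusion():
    return ToyModule(
        children={
            "decoder": ToyModule(params={"output.weight": 1}),
            "tok_embeddings": ToyModule(params={"weight": 2}),
        },
        hook=move_embeddings_hook,
    )


top_level_keys = list(make_fusion().state_dict())
wrapped_keys = list(ToyModule(children={"0": make_fusion()}).state_dict())
print(top_level_keys)
print(wrapped_keys)
```

In this toy model, a hook that hard-coded `"tok_embeddings.weight"` would raise a `KeyError` in the wrapped case, because the key is actually `"0.tok_embeddings.weight"`.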

Test plan

Added a unit test that wraps EarlyFusion in an nn.Sequential so that the state dict keys have integer prefixes. The new test and the existing state dict hook test both pass after these changes:

pytest tests/torchtune/modules/model_fusion/test_early_fusion.py
...
================== 12 passed in 0.18s =============
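
For readers who want to reproduce the idea outside the torchtune test suite, the following is a rough sketch of what such a check could look like. `TinyFusion` is a hypothetical stand-in, not the real EarlyFusion class or the actual unit test from this PR; it relies on PyTorch's private `_register_state_dict_hook` and `_register_load_state_dict_pre_hook` methods, whose behavior may vary between PyTorch versions.

```python
import torch
from torch import nn


class TinyFusion(nn.Module):
    """Hypothetical stand-in for EarlyFusion (not the torchtune class)."""

    def __init__(self, vocab_size=8, dim=4):
        super().__init__()
        self.decoder = nn.Linear(dim, dim)
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self._register_state_dict_hook(TinyFusion._state_dict_hook)
        self._register_load_state_dict_pre_hook(
            TinyFusion._load_state_dict_hook, with_module=True
        )

    @staticmethod
    def _state_dict_hook(module, state_dict, prefix, *args, **kwargs):
        # Save tok_embeddings under "decoder.", relative to this module's prefix.
        for n, _ in module.tok_embeddings.named_parameters():
            state_dict[f"{prefix}decoder.tok_embeddings.{n}"] = state_dict.pop(
                f"{prefix}tok_embeddings.{n}"
            )

    @staticmethod
    def _load_state_dict_hook(module, state_dict, prefix, *args, **kwargs):
        # Undo the rename before the weights are loaded.
        for n, _ in module.tok_embeddings.named_parameters():
            state_dict[f"{prefix}tok_embeddings.{n}"] = state_dict.pop(
                f"{prefix}decoder.tok_embeddings.{n}"
            )


# Wrapping in nn.Sequential gives every key an integer prefix ("0.").
wrapped = nn.Sequential(TinyFusion())
keys = sorted(wrapped.state_dict().keys())
print(keys)

# Round trip: loading into a fresh wrapped copy succeeds under strict loading.
fresh = nn.Sequential(TinyFusion())
fresh.load_state_dict(wrapped.state_dict())
```

The sketch moves the existing state dict entry with `pop` rather than assigning the parameter directly; that is a simplification made for this example only.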

@ebsmothers ebsmothers requested a review from RdoubleA January 22, 2025 19:17

pytorch-bot bot commented Jan 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2291

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d470857 with merge base 890deab:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 22, 2025
@codecov-commenter

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 66.64%. Comparing base (97e857f) to head (d470857).
Report is 10 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##            main    #2291       +/-   ##
==========================================
+ Coverage   8.97%   66.64%   +57.66%     
==========================================
  Files        305      353       +48     
  Lines      18166    20651     +2485     
==========================================
+ Hits        1631    13762    +12131     
+ Misses     16535     6889     -9646     


Contributor

@RdoubleA RdoubleA left a comment

Good find with the prefix argument. Do we need to update DeepFusion as well?

    """
    Keep tok_embeddings inside of decoder state_dict
    [!Note] This update changes the order of the OrderedDict
    """
    for n, p in module.tok_embeddings.named_parameters():
-       state_dict[f"decoder.tok_embeddings.{n}"] = p
-       del state_dict[f"tok_embeddings.{n}"]
+       state_dict[f"{prefix}decoder.tok_embeddings.{n}"] = p
Contributor

Will prefix include the period?

Contributor Author

Yep, see here
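
A quick way to check the trailing period yourself is to record the prefix that a state dict hook receives at different nesting depths. This is only a sketch: `record_prefix` and `seen` are names made up for this example, and `_register_state_dict_hook` is a private PyTorch method.

```python
from torch import nn

seen = []


def record_prefix(module, state_dict, prefix, local_metadata):
    # Only record the prefix; leave the state dict untouched.
    seen.append(prefix)


inner = nn.Linear(2, 2)
inner._register_state_dict_hook(record_prefix)

inner.state_dict()  # top level: empty prefix
nn.Sequential(inner).state_dict()  # one level of wrapping
nn.ModuleDict({"block": nn.Sequential(inner)}).state_dict()  # two levels

print(seen)  # ['', '0.', 'block.0.']
```

Because each non-empty prefix already ends in a period, keys can be built as `f"{prefix}decoder.tok_embeddings.{n}"` without adding a separator.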

@ebsmothers
Contributor Author

@RdoubleA DeepFusion should be good. State dict hooks are not defined at the model level there, but at least individual layers seem to already be using the prefix argument in a similar fashion. E.g.

@ebsmothers ebsmothers merged commit d7afc40 into pytorch:main Jan 22, 2025
17 checks passed
4 participants