Proper prefix handling in EarlyFusion sd hooks #2291
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2291
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit d470857 with merge base 890deab.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2291      +/-   ##
==========================================
+ Coverage    8.97%   66.64%   +57.66%
==========================================
  Files         305      353      +48
  Lines       18166    20651    +2485
==========================================
+ Hits         1631    13762   +12131
+ Misses      16535     6889    -9646
☔ View full report in Codecov by Sentry.
Good find with the prefix argument. Do we need to update DeepFusion as well?
          """
          Keep tok_embeddings inside of decoder state_dict

          [!Note] This update changes the order of the OrderedDict
          """
          for n, p in module.tok_embeddings.named_parameters():
-             state_dict[f"decoder.tok_embeddings.{n}"] = p
-             del state_dict[f"tok_embeddings.{n}"]
+             state_dict[f"{prefix}decoder.tok_embeddings.{n}"] = p
Will prefix include the trailing period?
Yep, see here
Thanks to @mkrainin for pointing this out and making the initial fix.
Our `EarlyFusion` state dict hooks don't work when an `EarlyFusion` module is wrapped as part of a larger model (as is the case with e.g. DDP). This is because state dict keys in the hooks have global key names, so we can't just use e.g. `del state_dict["decoder.tok_embeddings.weight"]` since `decoder` may not be a top-level module.

The fix is to use the `prefix` arg to state dict hooks, which gets recursively prepended while traversing the module tree (ref).

Test plan

Added a unit test wrapping `EarlyFusion` into an `nn.Sequential` so that the state dict keys have integer prefixes. This plus the existing test of state dict hooks pass after these changes.
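
For readers who want to see the prefix mechanism in isolation, here is a minimal sketch in plain Python. It is a toy model of the traversal, not PyTorch or torchtune code: `ToyModule`, `leaf`, `make_fusion`, `wrap`, and the two hook functions are hypothetical names invented for this illustration, and the hooks take only `(module, state_dict, prefix)`, whereas real PyTorch state dict hooks also receive a `local_metadata` argument. The sketch assumes, as in PyTorch, that each child is visited with `prefix + name + "."`, so the prefix a hook sees is either empty or ends with a period.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical stand-in for nn.Module: parameters, children, state dict hooks."""

    def __init__(self):
        self.params = OrderedDict()    # parameter name -> value
        self.children = OrderedDict()  # child name -> ToyModule
        self.hooks = []                # callables taking (module, state_dict, prefix)

    def state_dict(self, destination=None, prefix=""):
        if destination is None:
            destination = OrderedDict()
        for name, value in self.params.items():
            destination[prefix + name] = value
        for name, child in self.children.items():
            # Each level appends "<child name>.", so a non-empty prefix ends with a period.
            child.state_dict(destination, prefix + name + ".")
        for hook in self.hooks:
            # Hooks run after the children are collected and see the global key names.
            hook(self, destination, prefix)
        return destination


def leaf(**params):
    """Build a module that only holds parameters."""
    module = ToyModule()
    module.params.update(params)
    return module


def make_fusion(hook):
    """Toy EarlyFusion: a decoder plus a sibling tok_embeddings that the hook relocates."""
    fusion = ToyModule()
    fusion.children["decoder"] = leaf(output=1.0)
    fusion.children["tok_embeddings"] = leaf(weight=0.0)
    fusion.hooks.append(hook)
    return fusion


def hook_without_prefix(module, state_dict, prefix):
    # Pre-fix behavior: assumes the fusion module sits at the top level.
    for name, value in module.children["tok_embeddings"].params.items():
        state_dict[f"decoder.tok_embeddings.{name}"] = value
        del state_dict[f"tok_embeddings.{name}"]


def hook_with_prefix(module, state_dict, prefix):
    # Post-fix behavior: every key is built relative to the prefix passed to the hook.
    for name, value in module.children["tok_embeddings"].params.items():
        state_dict[f"{prefix}decoder.tok_embeddings.{name}"] = value
        del state_dict[f"{prefix}tok_embeddings.{name}"]


def wrap(inner):
    """Toy nn.Sequential: registers the wrapped module under the name "0"."""
    outer = ToyModule()
    outer.children["0"] = inner
    return outer


wrapped_keys = list(wrap(make_fusion(hook_with_prefix)).state_dict())
print(wrapped_keys)  # ['0.decoder.output', '0.decoder.tok_embeddings.weight']

try:
    wrap(make_fusion(hook_without_prefix)).state_dict()
except KeyError as err:
    # The un-prefixed key does not exist once the module is nested.
    print("hook without prefix failed on key:", err)
```

Unwrapped, both hooks produce the same keys because the prefix is the empty string. The difference only appears once the module is nested, which is the situation the new unit test creates with `nn.Sequential`.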