fix: allow loading of nested states in Fabric.load
[wip]
#20210
+30
−16
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
This PR aims to allow loading nested states via
Fabric.load
(fixes #20208). Nested states can add some structure when multiple models are trained simultaneously (e.g. in GANs or in reinforcement learning).Currently, it is my understanding that the following methods need to be modified
Fabric.load
: In this method, some entries of thestate
dictionaries are replaced with those loaded from the checkpoint. Until now, only fabric-internal objects were not replaced (since the respective state dicts had been loaded previously). We need to apply this replacement procedures recursively to sub-dictionaries (instead of replacing them).Strategy._convert_stateful_objects_in_state
: This method again has to recursively walk into sub-dictionaries and turn stateful objects into their state-dicts.Strategy.load_checkpoint
and anything overriding it: Again, we need to recursively walk into the sub-dictionaries and possibly load the state-dicts of stateful objects.Strategy.save_checkpoint
: Perhaps depending on the strategy there is also some relevant code here (not the case for the abstract strategy class).Any strategy that overrides
load_checkpoint
will also have to be updated.I am not sure whether these are all the changes required to make this work, but preliminary tests from my side look okay. If someone more familiar with the codebase could give some feedback on whether this covers everything, I'd be grateful.
TODO:
Fabric.load
Strategy._convert_stateful_objects_in_state
andStrategy.load_checkpoint
load_checkpoint
andsave_checkpoint
methods of strategies overriding theStrategy
methodsBefore submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--20210.org.readthedocs.build/en/20210/