Skip to content

Commit

Permalink
Pass in strict flag to allow non-strict state_dict loading
Browse files Browse the repository at this point in the history
Summary:
This allows users to ignore keys in the state_dict that aren't part of
the given module, or part of the module that aren't in the state_dict.

The default value for the flag (true) keeps the status quo and what the pytorch
interface uses.

Differential Revision: D53066198
  • Loading branch information
schwarzmx authored and facebook-github-bot committed Jan 25, 2024
1 parent 466a0cd commit 77da550
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/framework/callbacks/test_torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def state_dict(self) -> Dict[str, Any]:
self.state_dict_call_count += 1
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True) -> None:
self.load_state_dict_call_count += 1
return None

Expand Down

0 comments on commit 77da550

Please sign in to comment.