diff --git a/utils/sparse.py b/utils/sparse.py index b03d44af1d41..d5cfb77b67e5 100644 --- a/utils/sparse.py +++ b/utils/sparse.py @@ -75,18 +75,20 @@ def __init__( self.apply_checkpoint_structure(train_mode, epoch, one_shot) def state_dict(self, final_epoch): - if self.enabled or self.checkpoint_manager: - compose_recipes = self.checkpoint_manager and self.enabled and final_epoch - return { - 'checkpoint_recipe': str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)) - if compose_recipes else str(self.checkpoint_manager), - 'train_recipe': str(self.manager) if not final_epoch else None - } + if self.enabled and final_epoch: + checkpoint_recipe = ( + str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)) + if self.checkpoint_manager else str(self.manager) + ) + train_recipe = None else: - return { - 'checkpoint_recipe': None, - 'train_recipe': None - } + checkpoint_recipe = str(self.checkpoint_manager) if self.checkpoint_manager else None + train_recipe = str(self.manager) if self.manager else None + + return { + 'checkpoint_recipe': checkpoint_recipe, + 'train_recipe': train_recipe + } def apply_checkpoint_structure(self, train_mode, epoch, one_shot=False): if self.checkpoint_manager: