From 980521b73b6e1d09378220d977568ebc6eca27bc Mon Sep 17 00:00:00 2001 From: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com> Date: Wed, 29 Jun 2022 21:09:39 +0100 Subject: [PATCH] Fix checkpoint and train recipe loading (#78) (#79) --- utils/sparse.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/utils/sparse.py b/utils/sparse.py index b03d44af1d41..d5cfb77b67e5 100644 --- a/utils/sparse.py +++ b/utils/sparse.py @@ -75,18 +75,20 @@ def __init__( self.apply_checkpoint_structure(train_mode, epoch, one_shot) def state_dict(self, final_epoch): - if self.enabled or self.checkpoint_manager: - compose_recipes = self.checkpoint_manager and self.enabled and final_epoch - return { - 'checkpoint_recipe': str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)) - if compose_recipes else str(self.checkpoint_manager), - 'train_recipe': str(self.manager) if not final_epoch else None - } + if self.enabled and final_epoch: + checkpoint_recipe = ( + str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)) + if self.checkpoint_manager else str(self.manager) + ) + train_recipe = None else: - return { - 'checkpoint_recipe': None, - 'train_recipe': None - } + checkpoint_recipe = str(self.checkpoint_manager) if self.checkpoint_manager else None + train_recipe = str(self.manager) if self.manager else None + + return { + 'checkpoint_recipe': checkpoint_recipe, + 'train_recipe': train_recipe + } def apply_checkpoint_structure(self, train_mode, epoch, one_shot=False): if self.checkpoint_manager: