Skip to content

Commit

Permalink
Fix checkpoint and train recipe loading (#78) (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
KSGulin authored Jun 29, 2022
1 parent d5807bf commit 980521b
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,20 @@ def __init__(
self.apply_checkpoint_structure(train_mode, epoch, one_shot)

def state_dict(self, final_epoch):
if self.enabled or self.checkpoint_manager:
compose_recipes = self.checkpoint_manager and self.enabled and final_epoch
return {
'checkpoint_recipe': str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager))
if compose_recipes else str(self.checkpoint_manager),
'train_recipe': str(self.manager) if not final_epoch else None
}
if self.enabled and final_epoch:
checkpoint_recipe = (
str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager))
if self.checkpoint_manager else str(self.manager)
)
train_recipe = None
else:
return {
'checkpoint_recipe': None,
'train_recipe': None
}
checkpoint_recipe = str(self.checkpoint_manager) if self.checkpoint_manager else None
train_recipe = str(self.manager) if self.manager else None

return {
'checkpoint_recipe': checkpoint_recipe,
'train_recipe': train_recipe
}

def apply_checkpoint_structure(self, train_mode, epoch, one_shot=False):
if self.checkpoint_manager:
Expand Down

0 comments on commit 980521b

Please sign in to comment.