Skip to content

Commit

Permalink
Fix --resume keyword and recipe loading (#33)
Browse files: browse the repository at this point in the history
  • Loading branch information
KSGulin authored Apr 14, 2022
1 parent b6974c9 commit 2aae700
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
6 changes: 3 additions & 3 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,9 @@ def load_checkpoint(
# load sparseml recipe for applying pruning and quantization
checkpoint_recipe = train_recipe = None
if resume:
train_recipe = ckpt['recipe'] if ('recipe' in ckpt) else None
elif ckpt['recipe'] or recipe:
train_recipe, checkpoint_recipe = recipe, ckpt['recipe']
train_recipe = ckpt.get('recipe')
elif recipe or ckpt.get('recipe'):
train_recipe, checkpoint_recipe = recipe, ckpt.get('recipe')

sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, checkpoint_recipe, train_recipe)
exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume
Expand Down
14 changes: 9 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Save run settings
if not evolve:
with open(save_dir / 'hyp.yaml', 'w') as f:
yaml.dump(hyp, f, sort_keys=False)
yaml.safe_dump(hyp, f, sort_keys=False)
with open(save_dir / 'opt.yaml', 'w') as f:
yaml.dump(vars(opt), f, sort_keys=False)
yaml.safe_dump(vars(opt), f, sort_keys=False)

# Loggers
data_dict = None
Expand Down Expand Up @@ -492,7 +492,11 @@ def parse_opt(known=False):
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
parser.add_argument('--resume',
nargs='?',
const=True,
default=False,
help='resume most recent training. When true, ignores --recipe arg and re-uses saved recipe (if exists)')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--noval', action='store_true', help='only validate final epoch')
parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
Expand Down Expand Up @@ -542,7 +546,7 @@ def main(opt, callbacks=Callbacks()):
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace
opt = argparse.Namespace(**yaml.safe_load(f)) # replace
opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
LOGGER.info(f'Resuming training from {ckpt}')
else:
Expand All @@ -553,7 +557,7 @@ def main(opt, callbacks=Callbacks()):
if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve
opt.project = str(ROOT / 'runs/evolve')
opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run

# DDP mode
device = select_device(opt.device, batch_size=opt.batch_size)
Expand Down
3 changes: 0 additions & 3 deletions utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ def state_dict(self):
}

def apply_checkpoint_structure(self):
if not self.enabled:
return

if self.checkpoint_manager:
self.checkpoint_manager.apply_structure(self.model, math.inf)

Expand Down

0 comments on commit 2aae700

Please sign in to comment.