Skip to content

Commit

Permalink
Add recipe args support (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
KSGulin authored Sep 29, 2022
1 parent 34b6226 commit 64538d7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 1 deletion.
2 changes: 2 additions & 0 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ def load_checkpoint(
dnn=False,
half = False,
recipe=None,
recipe_args=None,
resume=None,
rank=-1,
one_shot=False,
Expand Down Expand Up @@ -543,6 +544,7 @@ def load_checkpoint(
model.model if val_type else model,
checkpoint_recipe,
train_recipe,
recipe_args=recipe_args,
train_mode=train_type,
epoch=ckpt['epoch'],
one_shot=one_shot,
Expand Down
5 changes: 5 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
hyp=hyp,
nc=nc,
recipe=opt.recipe,
recipe_args = opt.recipe_args,
resume=opt.resume,
rank=LOCAL_RANK,
one_shot=opt.one_shot,
Expand All @@ -141,6 +142,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
model,
None,
opt.recipe,
recipe_args=opt.recipe_args,
train_mode=True,
steps_per_epoch=opt.max_train_steps,
one_shot=opt.one_shot,
Expand Down Expand Up @@ -588,6 +590,9 @@ def parse_opt(known=False, skip_parse=False):
parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
parser.add_argument('--recipe', type=str, default=None, help='Path to a sparsification recipe, '
'see https://github.com/neuralmagic/sparseml for more information')
parser.add_argument("--recipe-args", type=str, default=None, help = 'A json string, csv key=value string, or dictionary '
'containing arguments to override the root arguments '
'within the recipe such as learning rate or num epochs')
parser.add_argument('--disable-ema', action='store_true', help='Disable EMA model updates (enabled by default)')

parser.add_argument("--max-train-steps", type=int, default=-1, help="Set the maximum number of training steps per epoch. if negative,"
Expand Down
3 changes: 2 additions & 1 deletion utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
model,
checkpoint_recipe,
train_recipe,
recipe_args = None,
train_mode=False,
epoch=-1,
steps_per_epoch=-1,
Expand All @@ -64,7 +65,7 @@ def __init__(
self.enabled = bool(train_recipe)
self.model = model.module if is_parallel(model) else model
self.checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint_recipe) if checkpoint_recipe else None
self.manager = ScheduledModifierManager.from_yaml(train_recipe) if train_recipe else None
self.manager = ScheduledModifierManager.from_yaml(train_recipe, recipe_variables=recipe_args) if train_recipe else None
self.logger = None
self.start_epoch = None
self.steps_per_epoch = steps_per_epoch
Expand Down

0 comments on commit 64538d7

Please sign in to comment.