diff --git a/export.py b/export.py
index 179952b49e2..608a3edd1c1 100644
--- a/export.py
+++ b/export.py
@@ -483,6 +483,7 @@ def load_checkpoint(
     dnn=False,
     half = False,
     recipe=None,
+    recipe_args=None,
     resume=None,
     rank=-1,
     one_shot=False,
@@ -543,6 +544,7 @@ def load_checkpoint(
         model.model if val_type else model,
         checkpoint_recipe,
         train_recipe,
+        recipe_args=recipe_args,
         train_mode=train_type,
         epoch=ckpt['epoch'],
         one_shot=one_shot,
diff --git a/train.py b/train.py
index 98d0a007729..90c512629d3 100644
--- a/train.py
+++ b/train.py
@@ -127,6 +127,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         hyp=hyp,
         nc=nc,
         recipe=opt.recipe,
+        recipe_args=opt.recipe_args,
         resume=opt.resume,
         rank=LOCAL_RANK,
         one_shot=opt.one_shot,
@@ -141,6 +142,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         model,
         None,
         opt.recipe,
+        recipe_args=opt.recipe_args,
         train_mode=True,
         steps_per_epoch=opt.max_train_steps,
         one_shot=opt.one_shot,
@@ -588,6 +590,9 @@ def parse_opt(known=False, skip_parse=False):
     parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
     parser.add_argument('--recipe', type=str, default=None, help='Path to a sparsification recipe, '
                                                                  'see https://github.com/neuralmagic/sparseml for more information')
+    parser.add_argument("--recipe-args", type=str, default=None, help='A json string, csv key=value string, or dictionary '
+                                                                      'containing arguments to override the root arguments '
+                                                                      'within the recipe such as learning rate or num epochs')
     parser.add_argument('--disable-ema', action='store_true', help='Disable EMA model updates (enabled by default)')
     parser.add_argument("--max-train-steps", type=int, default=-1, help="Set the maximum number of training steps per epoch. "
                                                                         "if negative,"
diff --git a/utils/sparse.py b/utils/sparse.py
index 02b59447dd3..93fa6913a5f 100644
--- a/utils/sparse.py
+++ b/utils/sparse.py
@@ -56,6 +56,7 @@ def __init__(
         model,
         checkpoint_recipe,
         train_recipe,
+        recipe_args=None,
         train_mode=False,
         epoch=-1,
         steps_per_epoch=-1,
@@ -64,7 +65,7 @@
         self.enabled = bool(train_recipe)
         self.model = model.module if is_parallel(model) else model
         self.checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint_recipe) if checkpoint_recipe else None
-        self.manager = ScheduledModifierManager.from_yaml(train_recipe) if train_recipe else None
+        self.manager = ScheduledModifierManager.from_yaml(train_recipe, recipe_variables=recipe_args) if train_recipe else None
         self.logger = None
         self.start_epoch = None
         self.steps_per_epoch = steps_per_epoch