diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 1f73da026..7b5c1c1a9 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -93,10 +93,11 @@ def oneshot(**kwargs): CLI entrypoint for running oneshot calibration """ # TODO: Get rid of training args when Oneshot refactor comes in - model_args, data_args, recipe_args, training_args, _ = parse_args( + model_args, data_args, recipe_args, training_args, output_dir = parse_args( include_training_args=True, **kwargs ) training_args.do_oneshot = True + training_args.output_dir = output_dir main(model_args, data_args, recipe_args, training_args) @@ -162,6 +163,20 @@ def parse_args(include_training_args: bool = False, **kwargs): parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments)) if not kwargs: + + def _get_output_dir_from_argv() -> Optional[str]: + import sys + + output_dir = None + if "--output_dir" in sys.argv: + index = sys.argv.index("--output_dir") + sys.argv.pop(index) + if index < len(sys.argv): # Check if value exists after the flag + output_dir = sys.argv.pop(index) + + return output_dir + + output_dir = _get_output_dir_from_argv() or output_dir parsed_args = parser.parse_args_into_dataclasses() else: parsed_args = parser.parse_dict(kwargs) @@ -478,6 +493,7 @@ def main( stage_runner.predict() # save if model was provided as a string or custom output_dir was set + if isinstance(model_args.model, str) or ( training_args.output_dir != TrainingArguments.__dataclass_fields__["output_dir"].default diff --git a/tests/llmcompressor/transformers/oneshot/test_cli.py b/tests/llmcompressor/transformers/oneshot/test_cli.py index 5780ca46f..effb7ff7f 100644 --- a/tests/llmcompressor/transformers/oneshot/test_cli.py +++ b/tests/llmcompressor/transformers/oneshot/test_cli.py @@ -48,6 +48,9 @@ def test_one_shot_cli(self): if 
len(self.additional_args) > 0: cmd.extend(self.additional_args) + + print(" ".join(cmd)) + res = run_cli_command(cmd) self.assertEqual(res.returncode, 0) print(res.stdout)