Skip to content

Commit

Permalink
pass cli tests
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm committed Jan 28, 2025
1 parent 44c67d7 commit 7bb2e9a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
18 changes: 17 additions & 1 deletion src/llmcompressor/transformers/finetune/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,11 @@ def oneshot(**kwargs):
CLI entrypoint for running oneshot calibration
"""
# TODO: Get rid of training args when Oneshot refactor comes in
model_args, data_args, recipe_args, training_args, _ = parse_args(
model_args, data_args, recipe_args, training_args, output_dir = parse_args(
include_training_args=True, **kwargs
)
training_args.do_oneshot = True
training_args.output_dir = output_dir
main(model_args, data_args, recipe_args, training_args)


Expand Down Expand Up @@ -162,6 +163,20 @@ def parse_args(include_training_args: bool = False, **kwargs):
parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments))

if not kwargs:

def _get_output_dir_from_argv() -> Optional[str]:
import sys

output_dir = None
if "--output_dir" in sys.argv:
index = sys.argv.index("--output_dir")
sys.argv.pop(index)
if index < len(sys.argv): # Check if value exists afer the flag
output_dir = sys.argv.pop(index)

return output_dir

output_dir = _get_output_dir_from_argv() or output_dir
parsed_args = parser.parse_args_into_dataclasses()
else:
parsed_args = parser.parse_dict(kwargs)
Expand Down Expand Up @@ -478,6 +493,7 @@ def main(
stage_runner.predict()

# save if model was provided as a string or custom output_dir was set

if isinstance(model_args.model, str) or (
training_args.output_dir
!= TrainingArguments.__dataclass_fields__["output_dir"].default
Expand Down
3 changes: 3 additions & 0 deletions tests/llmcompressor/transformers/oneshot/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def test_one_shot_cli(self):

if len(self.additional_args) > 0:
cmd.extend(self.additional_args)

print(" ".join(cmd))

res = run_cli_command(cmd)
self.assertEqual(res.returncode, 0)
print(res.stdout)
Expand Down

0 comments on commit 7bb2e9a

Please sign in to comment.