diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py
index 137ba53dd..71c492d1f 100644
--- a/src/llmcompressor/transformers/finetune/runner.py
+++ b/src/llmcompressor/transformers/finetune/runner.py
@@ -26,9 +26,6 @@
     RecipeArguments,
     TrainingArguments,
 )
-from llmcompressor.transformers.utils.arg_parser.training_arguments import (
-    DEFAULT_OUTPUT_DIR,
-)
 from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe
 
@@ -264,7 +261,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
                 self.train(checkpoint=checkpoint, stage=stage_name)
             checkpoint = None
 
-            if self._training_args.output_dir != DEFAULT_OUTPUT_DIR:
+            if self._training_args.output_dir:
                 save_model_and_recipe(
                     model=self.trainer.model,
                     save_path=self._output_dir,
diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index 7b5c1c1a9..3b16b6dac 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -52,7 +52,6 @@
     get_shared_processor_src,
 )
 from llmcompressor.transformers.utils.arg_parser import (
-    DEFAULT_OUTPUT_DIR,
     DatasetArguments,
     ModelArguments,
     RecipeArguments,
@@ -153,65 +152,39 @@ def parse_args(include_training_args: bool = False, **kwargs):
         conflict with Accelerate library's accelerator.
     """
 
-    output_dir = kwargs.pop("output_dir", DEFAULT_OUTPUT_DIR)
-
-    if include_training_args:
-        parser = HfArgumentParser(
-            (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments)
-        )
-    else:
-        parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments))
+    output_dir = kwargs.pop("output_dir", None)
+    parser = HfArgumentParser(_get_dataclass_arguments(include_training_args))
 
+    # parse from kwargs or cli
     if not kwargs:
-
-        def _get_output_dir_from_argv() -> Optional[str]:
-            import sys
-
-            output_dir = None
-            if "--output_dir" in sys.argv:
-                index = sys.argv.index("--output_dir")
-                sys.argv.pop(index)
-                if index < len(sys.argv):  # Check if value exists afer the flag
-                    output_dir = sys.argv.pop(index)
-
-            return output_dir
-
-        output_dir = _get_output_dir_from_argv() or output_dir
+        output_dir = _get_output_dir_from_argv()
         parsed_args = parser.parse_args_into_dataclasses()
     else:
         parsed_args = parser.parse_dict(kwargs)
 
-    # Unpack parsed arguments based on the presence of training arguments
+    # populate args, oneshot does not need training_args
     if include_training_args:
         model_args, data_args, recipe_args, training_args = parsed_args
-        if output_dir is not None:
+        if output_dir:
             training_args.output_dir = output_dir
     else:
         model_args, data_args, recipe_args = parsed_args
         training_args = None
 
-    if recipe_args.recipe_args is not None:
-        if not isinstance(recipe_args.recipe_args, dict):
-            recipe_args.recipe_args = {
-                key: value
-                for arg in recipe_args.recipe_args
-                for key, value in [arg.split("=")]
-            }
+    # populate recipe arguments
+    if recipe_args.recipe_args:
+        recipe_args.recipe_args = _unwrap_recipe_args(recipe_args.recipe_args)
 
-    # Raise deprecation warnings
     if data_args.remove_columns is not None:
         warnings.warn(
-            "`remove_columns` argument is deprecated. When tokenizing datasets, all "
-            "columns which are invalid inputs to the tokenizer will be removed.",
+            (
+                "`remove_columns` is deprecated. "
+                "Invalid columns for tokenizers are removed automatically."
+            ),
             DeprecationWarning,
         )
 
-    # Silently assign tokenizer to processor
-    if model_args.tokenizer:
-        if model_args.processor:
-            raise ValueError("Cannot use both a tokenizer and processor.")
-        model_args.processor = model_args.tokenizer
-        model_args.tokenizer = None
+    _validate_model_args_tokenizer(model_args)
 
     return model_args, data_args, recipe_args, training_args, output_dir
 
@@ -509,5 +482,46 @@ def main(
         reset_session()
 
 
+def _validate_model_args_tokenizer(model_args):
+    """Ensure only one of tokenizer or processor is set"""
+
+    if model_args.tokenizer:
+        if model_args.processor:
+            raise ValueError("Cannot use both a tokenizer and processor.")
+        model_args.processor, model_args.tokenizer = model_args.tokenizer, None
+
+
+def _get_output_dir_from_argv() -> Optional[str]:
+    """Extract output directory from command-line arguments"""
+
+    import sys
+
+    if "--output_dir" in sys.argv:
+        index = sys.argv.index("--output_dir")
+        sys.argv.pop(index)
+        if index < len(sys.argv):
+            return sys.argv.pop(index)
+
+    return None
+
+
+def _get_dataclass_arguments(include_training_args: bool):
+    """Return the appropriate argument classes for parsing"""
+
+    dataclass_arguments = (ModelArguments, DatasetArguments, RecipeArguments)
+    if include_training_args:
+        return dataclass_arguments + (TrainingArguments,)
+
+    return dataclass_arguments
+
+
+def _unwrap_recipe_args(recipe_args):
+    """Convert recipe arguments to a dictionary if needed"""
+    if isinstance(recipe_args, dict):
+        return recipe_args
+
+    return {key: value for arg in recipe_args for key, value in [arg.split("=")]}
+
+
 if __name__ == "__main__":
     apply()
diff --git a/src/llmcompressor/transformers/utils/arg_parser/__init__.py b/src/llmcompressor/transformers/utils/arg_parser/__init__.py
index cbb9224af..5973efb94 100644
--- a/src/llmcompressor/transformers/utils/arg_parser/__init__.py
+++ b/src/llmcompressor/transformers/utils/arg_parser/__init__.py
@@ -3,4 +3,4 @@
 from .data_arguments import DatasetArguments
 from .model_arguments import ModelArguments
 from .recipe_arguments import RecipeArguments
-from .training_arguments import DEFAULT_OUTPUT_DIR, TrainingArguments
+from .training_arguments import TrainingArguments
diff --git a/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py
index 7b61193b0..636424375 100644
--- a/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py
+++ b/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py
@@ -3,9 +3,9 @@
 
 from transformers import TrainingArguments as HFTrainingArgs
 
-__all__ = ["TrainingArguments", "DEFAULT_OUTPUT_DIR"]
-
-DEFAULT_OUTPUT_DIR = "./output"
+__all__ = [
+    "TrainingArguments",
+]
 
 
 @dataclass
@@ -23,8 +23,8 @@ class TrainingArguments(HFTrainingArgs):
     run_stages: Optional[bool] = field(
         default=False, metadata={"help": "Whether to trigger recipe stage by stage"}
     )
-    output_dir: str = field(
-        default=DEFAULT_OUTPUT_DIR,
+    output_dir: Optional[str] = field(
+        default=None,
         metadata={
             "help": "The output directory where the model predictions and "
             "checkpoints will be written."
diff --git a/src/llmcompressor/transformers/utils/arg_parser/utils.py b/src/llmcompressor/transformers/utils/arg_parser/utils.py
deleted file mode 100644
index 48455fa15..000000000
--- a/src/llmcompressor/transformers/utils/arg_parser/utils.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from dataclasses import fields
-from typing import Any, Dict, Union
-
-from .data_arguments import DatasetArguments
-from .model_arguments import ModelArguments
-from .recipe_arguments import RecipeArguments
-from .training_arguments import TrainingArguments
-
-__all__ = [
-    "get_dataclass_as_dict",
-]
-
-
-def get_dataclass_as_dict(
-    dataclass_instance: Union[
-        "ModelArguments", "RecipeArguments", "DatasetArguments", "TrainingArguments"
-    ],
-    dataclass_class: Union[
-        "ModelArguments", "RecipeArguments", "DatasetArguments", "TrainingArguments"
-    ],
-) -> Dict[str, Any]:
-    """
-    Get the dataclass instance attributes as a dict, neglicting the inherited class.
-    Ex. dataclass_class=TrainingArguments will ignore HFTrainignArguments
-
-    """
-    return {
-        field.name: getattr(dataclass_instance, field.name)
-        for field in fields(dataclass_class)
-    }
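Not part of the patch: a minimal standalone sketch of the two behavioral pieces of this refactor, the "key=value" recipe-argument unwrapping and the --output_dir extraction from argv, mirroring the helpers added to text_generation.py above. The argument names and values in the demo section are invented for illustration only.

# Illustrative sketch only; mirrors the helpers introduced in text_generation.py above.
# The example argument values below are made up for demonstration.
import sys
from typing import Optional


def _unwrap_recipe_args(recipe_args):
    """Pass dicts through unchanged; turn a list of "key=value" strings into a dict."""
    if isinstance(recipe_args, dict):
        return recipe_args
    return {key: value for arg in recipe_args for key, value in [arg.split("=")]}


def _get_output_dir_from_argv() -> Optional[str]:
    """Pop --output_dir and its value out of sys.argv, returning the value if present."""
    if "--output_dir" in sys.argv:
        index = sys.argv.index("--output_dir")
        sys.argv.pop(index)
        if index < len(sys.argv):
            return sys.argv.pop(index)
    return None


if __name__ == "__main__":
    # -> {"dampening_frac": "0.01", "targets": "Linear"}; values stay strings
    print(_unwrap_recipe_args(["dampening_frac=0.01", "targets=Linear"]))

    sys.argv = ["text_generation.py", "--output_dir", "./out", "--num_train_epochs", "1"]
    print(_get_output_dir_from_argv())  # ./out
    print(sys.argv)  # ["text_generation.py", "--num_train_epochs", "1"]; flag and value removed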