
Commit

horheynm committed Jan 31, 2025
1 parent 12a2167 commit 77205c0
Showing 5 changed files with 62 additions and 81 deletions.
5 changes: 1 addition & 4 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -26,9 +26,6 @@
     RecipeArguments,
     TrainingArguments,
 )
-from llmcompressor.transformers.utils.arg_parser.training_arguments import (
-    DEFAULT_OUTPUT_DIR,
-)
 from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe

@@ -264,7 +261,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
             self.train(checkpoint=checkpoint, stage=stage_name)
             checkpoint = None
 
-        if self._training_args.output_dir != DEFAULT_OUTPUT_DIR:
+        if self._training_args.output_dir:
             save_model_and_recipe(
                 model=self.trainer.model,
                 save_path=self._output_dir,
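The effect of this hunk, as a minimal sketch with illustrative values (not part of the commit): saving was previously skipped whenever output_dir still equaled the removed "./output" default, so even an explicit "./output" skipped saving; it is now keyed on whether any output directory is set at all.

# Sketch of the guard change (hypothetical values only)
DEFAULT_OUTPUT_DIR = "./output"  # the constant removed by this commit

output_dir = "./output"
assert not (output_dir != DEFAULT_OUTPUT_DIR)  # old guard: skipped saving here
assert bool(output_dir)                        # new guard: saves here

output_dir = None
assert not bool(output_dir)                    # new guard: skips when unset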
96 changes: 55 additions & 41 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -52,7 +52,6 @@
     get_shared_processor_src,
 )
 from llmcompressor.transformers.utils.arg_parser import (
-    DEFAULT_OUTPUT_DIR,
     DatasetArguments,
     ModelArguments,
     RecipeArguments,
@@ -153,65 +152,39 @@ def parse_args(include_training_args: bool = False, **kwargs):
         conflict with Accelerate library's accelerator.
     """
-    output_dir = kwargs.pop("output_dir", DEFAULT_OUTPUT_DIR)
-
-    if include_training_args:
-        parser = HfArgumentParser(
-            (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments)
-        )
-    else:
-        parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments))
+    output_dir = kwargs.pop("output_dir", None)
+    parser = HfArgumentParser(_get_dataclass_arguments(include_training_args))
 
     # parse from kwargs or cli
     if not kwargs:
-
-        def _get_output_dir_from_argv() -> Optional[str]:
-            import sys
-
-            output_dir = None
-            if "--output_dir" in sys.argv:
-                index = sys.argv.index("--output_dir")
-                sys.argv.pop(index)
-                if index < len(sys.argv):  # Check if a value exists after the flag
-                    output_dir = sys.argv.pop(index)
-
-            return output_dir
-
-        output_dir = _get_output_dir_from_argv() or output_dir
+        output_dir = _get_output_dir_from_argv()
         parsed_args = parser.parse_args_into_dataclasses()
     else:
         parsed_args = parser.parse_dict(kwargs)
 
-    # Unpack parsed arguments based on the presence of training arguments
+    # populate args, oneshot does not need training_args
     if include_training_args:
         model_args, data_args, recipe_args, training_args = parsed_args
-        if output_dir is not None:
+        if output_dir:
             training_args.output_dir = output_dir
     else:
         model_args, data_args, recipe_args = parsed_args
         training_args = None
 
-    if recipe_args.recipe_args is not None:
-        if not isinstance(recipe_args.recipe_args, dict):
-            recipe_args.recipe_args = {
-                key: value
-                for arg in recipe_args.recipe_args
-                for key, value in [arg.split("=")]
-            }
+    # populate recipe arguments
+    if recipe_args.recipe_args:
+        recipe_args.recipe_args = _unwrap_recipe_args(recipe_args.recipe_args)
 
     # Raise deprecation warnings
     if data_args.remove_columns is not None:
         warnings.warn(
-            "`remove_columns` argument is deprecated. When tokenizing datasets, all "
-            "columns which are invalid inputs to the tokenizer will be removed.",
+            (
+                "`remove_columns` is deprecated. "
+                "Invalid columns for tokenizers are removed automatically."
+            ),
             DeprecationWarning,
         )
 
-    # Silently assign tokenizer to processor
-    if model_args.tokenizer:
-        if model_args.processor:
-            raise ValueError("Cannot use both a tokenizer and processor.")
-        model_args.processor = model_args.tokenizer
-        model_args.tokenizer = None
+    _validate_model_args_tokenizer(model_args)
 
     return model_args, data_args, recipe_args, training_args, output_dir

@@ -509,5 +482,46 @@ def main(
         reset_session()
 
 
+def _validate_model_args_tokenizer(model_args):
+    """Ensure only one of tokenizer or processor is set"""
+
+    if model_args.tokenizer:
+        if model_args.processor:
+            raise ValueError("Cannot use both a tokenizer and processor.")
+        model_args.processor, model_args.tokenizer = model_args.tokenizer, None
+
+
+def _get_output_dir_from_argv() -> Optional[str]:
+    """Extract output directory from command-line arguments"""
+
+    import sys
+
+    if "--output_dir" in sys.argv:
+        index = sys.argv.index("--output_dir")
+        sys.argv.pop(index)
+        if index < len(sys.argv):
+            return sys.argv.pop(index)
+
+    return None
+
+
+def _get_dataclass_arguments(include_training_args: bool):
+    """Return the appropriate argument classes for parsing"""
+
+    dataclass_arguments = (ModelArguments, DatasetArguments, RecipeArguments)
+    if include_training_args:
+        return dataclass_arguments + (TrainingArguments,)
+
+    return dataclass_arguments
+
+
+def _unwrap_recipe_args(recipe_args):
+    """Convert recipe arguments to a dictionary if needed"""
+    if isinstance(recipe_args, dict):
+        return recipe_args
+
+    return {key: value for arg in recipe_args for key, value in [arg.split("=")]}
+
+
 if __name__ == "__main__":
     apply()
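A usage sketch of the two parsing helpers added above, restated with hypothetical values so the snippet runs standalone (the argv contents are illustrative, not from the commit):

import sys
from typing import Optional

# Restated from the diff above so this sketch is self-contained.
def _unwrap_recipe_args(recipe_args):
    if isinstance(recipe_args, dict):
        return recipe_args
    return {key: value for arg in recipe_args for key, value in [arg.split("=")]}

def _get_output_dir_from_argv() -> Optional[str]:
    if "--output_dir" in sys.argv:
        index = sys.argv.index("--output_dir")
        sys.argv.pop(index)
        if index < len(sys.argv):
            return sys.argv.pop(index)
    return None

# "key=value" strings collapse into a dict; an existing dict passes through.
assert _unwrap_recipe_args(["lr=1e-4", "epochs=2"]) == {"lr": "1e-4", "epochs": "2"}
assert _unwrap_recipe_args({"lr": "1e-4"}) == {"lr": "1e-4"}

# The flag and its value are popped so HfArgumentParser never sees them.
sys.argv = ["train.py", "--output_dir", "./results", "--num_train_epochs", "1"]
assert _get_output_dir_from_argv() == "./results"
assert sys.argv == ["train.py", "--num_train_epochs", "1"]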
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/utils/arg_parser/__init__.py
@@ -3,4 +3,4 @@
 from .data_arguments import DatasetArguments
 from .model_arguments import ModelArguments
 from .recipe_arguments import RecipeArguments
-from .training_arguments import DEFAULT_OUTPUT_DIR, TrainingArguments
+from .training_arguments import TrainingArguments
10 changes: 5 additions & 5 deletions src/llmcompressor/transformers/utils/arg_parser/training_arguments.py
@@ -3,9 +3,9 @@
 
 from transformers import TrainingArguments as HFTrainingArgs
 
-__all__ = ["TrainingArguments", "DEFAULT_OUTPUT_DIR"]
-
-DEFAULT_OUTPUT_DIR = "./output"
+__all__ = [
+    "TrainingArguments",
+]
 
 
 @dataclass
@@ -23,8 +23,8 @@ class TrainingArguments(HFTrainingArgs):
     run_stages: Optional[bool] = field(
         default=False, metadata={"help": "Whether to trigger recipe stage by stage"}
     )
-    output_dir: str = field(
-        default=DEFAULT_OUTPUT_DIR,
+    output_dir: Optional[str] = field(
+        default=None,
         metadata={
             "help": "The output directory where the model predictions and "
             "checkpoints will be written."
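A minimal sketch of the default change. A stand-in dataclass is used here because the real TrainingArguments subclasses Hugging Face's and pulls in its full __post_init__; only the changed field is modeled:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class SketchTrainingArguments:  # hypothetical stand-in, not part of the commit
    output_dir: Optional[str] = field(default=None)

args = SketchTrainingArguments()
assert args.output_dir is None  # previously defaulted to "./output"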
30 changes: 0 additions & 30 deletions src/llmcompressor/transformers/utils/arg_parser/utils.py

This file was deleted.
