diff --git a/examples/trl_mixin/ex_trl_distillation.py b/examples/trl_mixin/ex_trl_distillation.py
index ff3ddf000..d1e392e75 100644
--- a/examples/trl_mixin/ex_trl_distillation.py
+++ b/examples/trl_mixin/ex_trl_distillation.py
@@ -1,9 +1,9 @@
 from sft_trainer import SFTTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

-from llmcompressor.transformers import (
-    DataTrainingArguments,
-    TextGenerationDataset,
+from llmcompressor.transformers import TextGenerationDataset
+from llmcompressor.transformers.utils.arg_parser import (
+    DatasetArguments,
     TrainingArguments,
 )

@@ -21,7 +21,7 @@
 tokenizer = AutoTokenizer.from_pretrained(model_path)

 # Load gsm8k using SparseML dataset tools
-data_args = DataTrainingArguments(
+data_args = DatasetArguments(
     dataset="gsm8k", dataset_config_name="main", max_seq_length=512
 )
 dataset_manager = TextGenerationDataset.load_from_registry(
diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py
index fa8e434d4..30c97df7a 100644
--- a/src/llmcompressor/transformers/finetune/data/base.py
+++ b/src/llmcompressor/transformers/finetune/data/base.py
@@ -8,12 +8,12 @@
 from datasets.formatting.formatting import LazyRow
 from loguru import logger

-from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
     LABELS_MASK_VALUE,
     get_custom_datasets_from_path,
     get_raw_dataset,
 )
+from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 from llmcompressor.transformers.utils.preprocessing_functions import (
     PreprocessingFunctionRegistry,
 )
@@ -41,7 +41,7 @@ class TextGenerationDataset(RegistryMixin):

     def __init__(
         self,
-        data_args: DataTrainingArguments,
+        data_args: DatasetArguments,
         split: str,
         processor: Processor,
     ):
diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py
index e50d4d0c6..bf3feeee7 100644
--- a/src/llmcompressor/transformers/finetune/data/c4.py
+++ b/src/llmcompressor/transformers/finetune/data/c4.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="c4")
@@ -18,7 +18,7 @@ class C4Dataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "allenai/c4"
         data_args.text_column = "text"
diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
index 06ad3ecfa..506f760d0 100644
--- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
+++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="cnn_dailymail")
@@ -20,7 +20,7 @@ class CNNDailyMailDataset(TextGenerationDataset):

     SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "cnn_dailymail"
         data_args.dataset_config_name = "3.0.0"
diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py
index 932bfa54c..ca3caec03 100644
--- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py
+++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="evolcodealpaca")
@@ -25,7 +25,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
         "\n\n### Response:\n"
     )

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "theblackcat102/evol-codealpaca-v1"
         data_args.text_column = "text"
diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py
index f19b053e1..4528c5340 100644
--- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py
+++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py
@@ -7,7 +7,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="flickr", alias="flickr30k")
@@ -31,7 +31,7 @@ class Flickr30K(TextGenerationDataset):
         "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
     )

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "lmms-lab/flickr30k"

diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py
index beae5dfec..8ee26145d 100644
--- a/src/llmcompressor/transformers/finetune/data/gsm8k.py
+++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="gsm8k")
@@ -20,7 +20,7 @@ class GSM8KDataset(TextGenerationDataset):

     GSM_TEMPLATE = "Question: {question}\nAnswer:"

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "gsm8k"
         data_args.text_column = "text"
diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py
index 3b25986ca..0dbf064e5 100644
--- a/src/llmcompressor/transformers/finetune/data/open_platypus.py
+++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="open_platypus")
@@ -28,7 +28,7 @@ class OpenPlatypusDataset(TextGenerationDataset):
         "instruction}\n\n### Response:\n",
     }

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "garage-bAInd/Open-Platypus"
         data_args.text_column = "text"
diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py
index c7f0bbac1..db0be0599 100644
--- a/src/llmcompressor/transformers/finetune/data/ptb.py
+++ b/src/llmcompressor/transformers/finetune/data/ptb.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="ptb")
@@ -18,7 +18,7 @@ class PtbDataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "ptb_text_only"
         data_args.text_column = "sentence"
diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
index 62c012e83..f914ae5d4 100644
--- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
+++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
@@ -7,7 +7,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="ultrachat_200k")
@@ -33,7 +33,7 @@ class UltraChatDataset(TextGenerationDataset):
         "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
     )

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "HuggingFaceH4/ultrachat_200k"
         data_args.text_column = "messages"
diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py
index a559399d8..5e58c3c94 100644
--- a/src/llmcompressor/transformers/finetune/data/wikitext.py
+++ b/src/llmcompressor/transformers/finetune/data/wikitext.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="wikitext")
@@ -18,7 +18,7 @@ class WikiTextDataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "Salesforce/wikitext"
         data_args.text_column = "text"
diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py
index 27860aeb4..07b9ba1ef 100644
--- a/src/llmcompressor/transformers/finetune/session_mixin.py
+++ b/src/llmcompressor/transformers/finetune/session_mixin.py
@@ -7,13 +7,12 @@
 import torch
 from loguru import logger
 from torch.nn import Module
-from torch.utils.data import DataLoader, IterableDataset
+from torch.utils.data import IterableDataset
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import get_last_checkpoint

 from llmcompressor.core import (
     active_session,
-    apply,
     callbacks,
     create_session,
     finalize,
@@ -36,8 +35,10 @@
 from llmcompressor.utils.pytorch import qat_active

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments
-
+    from llmcompressor.transformers.utils.arg_parser import (
+        DatasetArguments,
+        ModelArguments,
+    )

 __all__ = [
     "SessionManagerMixIn",
@@ -68,12 +69,14 @@ def __init__(
         self,
         recipe: Optional[str] = None,
         recipe_args: Optional[Union[Dict[str, Any], str]] = None,
-        data_args: Optional["DataTrainingArguments"] = None,
+        data_args: Optional["DatasetArguments"] = None,
+        model_args: Optional["ModelArguments"] = None,
         teacher: Optional[Union[Module, str]] = None,
         **kwargs,
     ):
         self.recipe = recipe
         self.recipe_args = recipe_args
+        self.model_args = model_args
         self.teacher = teacher

         # parse training and metadata args
@@ -374,8 +377,8 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
         self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)

         # do not save checkpoints as compressed
-        original_save_compressed = self.args.save_compressed
-        self.args.save_compressed = False
+        original_save_compressed = self.model_args.save_compressed
+        self.model_args.save_compressed = False

         # train with accelerator
         self.accelerator.wait_for_everyone()
@@ -383,7 +386,7 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
         self.accelerator.wait_for_everyone()

         # restore original setting for saving final model
-        self.args.save_compressed = original_save_compressed
+        self.model_args.save_compressed = original_save_compressed

         # lifecycle
         self.finalize_session()
@@ -428,31 +431,6 @@ def predict(self, *args, **kwargs):

         return output

-    def one_shot(
-        self, calibration_data: Optional[DataLoader] = None, stage: Optional[str] = None
-    ):
-        """
-        Run oneshot calibration on the active model
-
-        :param stage: which stage of the recipe to run, or None to run whole recipe
-        :param calib_data: dataloader of calibration data
-        """
-        apply(
-            recipe=self.recipe,
-            recipe_stage=stage,
-            recipe_args=self.recipe_args,
-            model=self.model,
-            calib_data=calibration_data,
-            start=-1,
-            copy_data=False,
-            accelerator=self.accelerator,
-            min_tokens_per_module=self.min_tokens_per_module,
-        )
-
-        # log model sparsity
-        # self.maybe_log_model_sparsification()
-        self.accelerator.wait_for_everyone()
-
     def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
         """
         Override of the save_model function and expects it to exist in the parent.
@@ -474,7 +452,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
         if not is_fsdp_model(self.model):
             self.model.save_pretrained(
                 output_dir,
-                save_compressed=self.args.save_compressed,
+                save_compressed=self.model_args.save_compressed,
                 safe_serialization=self.args.save_safetensors,
             )
         else:  # FSDP model
@@ -482,7 +460,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
                 model=self.model,
                 accelerator=self.accelerator,
                 output_dir=output_dir,
-                save_compressed=self.args.save_compressed,
+                save_compressed=self.model_args.save_compressed,
                 save_safetensors=self.metadata.get("save_safetensors", False),
             )

diff --git a/src/llmcompressor/transformers/utils/arg_parser/data_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/data_arguments.py
new file mode 100644
index 000000000..50d3277f4
--- /dev/null
+++ b/src/llmcompressor/transformers/utils/arg_parser/data_arguments.py
@@ -0,0 +1,189 @@
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from transformers import DefaultDataCollator
+
+
+@dataclass
+class DVCDatasetArguments:
+    """
+    Arguments for training using DVC
+    """
+
+    dvc_data_repository: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to repository used for dvc_dataset_path"},
+    )
+
+
+@dataclass
+class CustomDatasetArguments(DVCDatasetArguments):
+    """
+    Arguments for training using custom datasets
+    """
+
+    dataset_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to the custom dataset. Supports json, csv, dvc. "
+                "For DVC, the to dvc dataset to load, of format dvc://path. "
+                "For csv or json, the path containing the dataset. "
+            ),
+        },
+    )
+
+    text_column: str = field(
+        default="text",
+        metadata={
+            "help": (
+                "Optional key to be used as the `text` input to tokenizer/processor "
+                "after dataset preprocesssing"
+            )
+        },
+    )
+
+    remove_columns: Union[None, str, List] = field(
+        default=None,
+        metadata={"help": "Column names to remove after preprocessing (deprecated)"},
+    )
+
+    preprocessing_func: Union[None, str, Callable] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Typically a function which applies a chat template. Can take the form "
+                "of either a function to apply to the dataset, a name defined in "
+                "src/llmcompressor/transformers/utils/preprocessing_functions.py, or "
+                "a path to a function definition of the form /path/to/file.py:func"
+            )
+        },
+    )
+
+    data_collator: Callable[[Any], Any] = field(
+        default_factory=lambda: DefaultDataCollator(),
+        metadata={"help": "The function to used to form a batch from the dataset"},
+    )
+
+
+@dataclass
+class DatasetArguments(CustomDatasetArguments):
+    """
+    Arguments pertaining to what data we are going to input our model for
+    calibration, training or eval
+
+    Using `HfArgumentParser` we can turn this class into argparse
+    arguments to be able to specify them on the command line
+    """
+
+    dataset: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The name of the dataset to use (via the datasets library). "
+                "Supports input as a string or DatasetDict from HF"
+            )
+        },
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": ("The configuration name of the dataset to use"),
+        },
+    )
+    max_seq_length: int = field(
+        default=384,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. "
+            "Sequences longer than this will be truncated, sequences shorter will "
+            "be padded."
+        },
+    )
+    concatenate_data: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to concatenate datapoints to fill max_seq_length"
+        },
+    )
+    raw_kwargs: Dict = field(
+        default_factory=dict,
+        metadata={"help": "Additional keyboard args to pass to datasets load_data"},
+    )
+    splits: Union[None, str, List, Dict] = field(
+        default=None,
+        metadata={"help": "Optional percentages of each split to download"},
+    )
+    num_calibration_samples: Optional[int] = field(
+        default=512,
+        metadata={"help": "Number of samples to use for one-shot calibration"},
+    )
+    shuffle_calibration_samples: Optional[bool] = field(
+        default=True,
+        metadata={
+            "help": "whether to shuffle the dataset before selecting calibration data"
+        },
+    )
+    streaming: Optional[bool] = field(
+        default=False,
+        metadata={"help": "True to stream data from a cloud dataset"},
+    )
+    overwrite_cache: bool = field(
+        default=False,
+        metadata={"help": "Overwrite the cached preprocessed datasets or not."},
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+    pad_to_max_length: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to pad all samples to `max_seq_length`. If False, "
+            "will pad the samples dynamically when batching to the maximum length "
+            "in the batch (which can be faster on GPU but will be slower on TPU)."
+        },
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number "
+            "of training examples to this value if set."
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number "
+            "of evaluation examples to this value if set."
+        },
+    )
+    max_predict_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of "
+                "prediction examples to this value if set."
+            ),
+        },
+    )
+    min_tokens_per_module: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The minimum percentage of tokens (out of the total number) "
+                "that the module should 'receive' throughout the forward "
+                "pass of the calibration. If a module receives fewer tokens, "
+                "a warning will be logged. Defaults to 1/num_of_experts."
+                "note: this argument is only relevant for MoE models"
+            ),
+        },
+    )
+    trust_remote_code_data: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to allow for datasets defined on the Hub using "
+            "a dataset script. This option should only be set to True for "
+            "repositories you trust and in which you have read the code, as it "
+            "will execute code present on the Hub on your local machine."
+        },
+    )
diff --git a/src/llmcompressor/transformers/utils/arg_parser/model_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/model_arguments.py
new file mode 100644
index 000000000..ce424812a
--- /dev/null
+++ b/src/llmcompressor/transformers/utils/arg_parser/model_arguments.py
@@ -0,0 +1,92 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ModelArguments:
+    """
+    Model variables used for oneshot calibration, training or finetuning and
+    stage runners (combination of oneshot and finetune going back and forth)
+
+    """
+
+    model: str = field(
+        metadata={
+            "help": (
+                "A pretrained model or a string as a path to pretrained model, "
+                "HF stub, or model identifier from huggingface.co/models."
+            )
+        },
+    )
+    distill_teacher: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Teacher model (a trained text generation model)",
+        },
+    )
+    config_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained config name or path if not the same as model_name"
+        },
+    )
+    tokenizer: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained tokenizer name or path if not the same as model_name"
+        },
+    )
+    processor: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained processor name or path if not the same as model_name"
+        },
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where to store the pretrained data from huggingface.co"},
+    )
+
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use token generated when running `transformers-cli login` "
+            "(necessary to use this script with private models)"
+        },
+    )
+    precision: str = field(
+        default="auto",
+        metadata={"help": "Precision to cast model weights to, default to auto"},
+    )
+
+    tie_word_embeddings: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether the model's input and output word embeddings "
+            "should be tied. Note that this is only relevant if the "
+            "model has a output word embedding layer."
+        },
+    )
+    trust_remote_code_model: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to allow for custom models to execute their "
+            "own modeling files. This option should only be set to True for "
+            "repositories you trust and in which you have read the code"
+        },
+    )
+    save_compressed: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to compress sparse models during save"},
+    )
+    oneshot_device: Optional[str] = field(
+        default="cuda:0",
+        metadata={"help": "Device to run oneshot calibration on"},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={
+            "help": "The specific model version to use "
+            "(can be a branch name, tag name or commit id)"
+        },
+    )
diff --git a/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py
new file mode 100644
index 000000000..7b61193b0
--- /dev/null
+++ b/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py
@@ -0,0 +1,32 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+from transformers import TrainingArguments as HFTrainingArgs
+
+__all__ = ["TrainingArguments", "DEFAULT_OUTPUT_DIR"]
+
+DEFAULT_OUTPUT_DIR = "./output"
+
+
+@dataclass
+class TrainingArguments(HFTrainingArgs):
+    """
+    Training arguments specific to LLM Compressor Transformers workflow using
+    HFTrainingArgs as base class
+
+    """
+
+    do_oneshot: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to run one-shot calibration in stages"},
+    )
+    run_stages: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to trigger recipe stage by stage"}
+    )
+    output_dir: str = field(
+        default=DEFAULT_OUTPUT_DIR,
+        metadata={
+            "help": "The output directory where the model predictions and "
+            "checkpoints will be written."
+        },
+    )
diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py
index c1dcef119..80c4b446e 100644
--- a/src/llmcompressor/transformers/utils/helpers.py
+++ b/src/llmcompressor/transformers/utils/helpers.py
@@ -14,7 +14,10 @@
 from transformers.trainer_utils import get_last_checkpoint

 if TYPE_CHECKING:
-    from llmcompressor.transformers import ModelArguments, TrainingArguments
+    from llmcompressor.transformers.utils.arg_parser import (
+        ModelArguments,
+        TrainingArguments,
+    )

 __all__ = [
     "RECIPE_FILE_NAME",
diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py
index 13eab66c9..cefcdaa54 100644
--- a/tests/llmcompressor/transformers/compression/test_quantization.py
+++ b/tests/llmcompressor/transformers/compression/test_quantization.py
@@ -13,7 +13,7 @@
 from llmcompressor.pytorch.utils import tensors_to_device
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.finetune.data import TextGenerationDataset
-from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
+from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 from tests.testing_utils import parse_params, requires_gpu

 CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/compression/configs"
@@ -59,10 +59,9 @@ def _run_oneshot(model, recipe, dataset, output_dir):
         max_seq_length = 512
         pad_to_max_length = False

-        oneshot(
+        oneshot_run = oneshot(
             model=model,
             dataset=dataset,
-            overwrite_output_dir=True,
             output_dir=output_dir,
             max_seq_length=max_seq_length,
             num_calibration_samples=num_calibration_samples,
@@ -72,10 +71,8 @@ def _run_oneshot(model, recipe, dataset, output_dir):
             splits={"calibration": "train_gen[:5%]"},
             save_compressed=False,
         )
-        from llmcompressor.pytorch.model_load.helpers import get_session_model

-        # note: get_session_model() is None outside of function scope
-        return get_session_model()
+        return oneshot_run.model

     def _get_quant_info(self, model):
         quant_info_weights = {}
@@ -147,7 +144,7 @@ def _get_dataloader(self, data_args, tokenizer):

     @torch.no_grad()
     def test_perplexity(self):
         tokenizer = AutoTokenizer.from_pretrained(self.model_stub)
-        data_args = DataTrainingArguments(
+        data_args = DatasetArguments(
             dataset="ultrachat-200k",
             max_seq_length=self.max_seq_length,
diff --git a/tests/llmcompressor/transformers/finetune/data/conftest.py b/tests/llmcompressor/transformers/finetune/data/conftest.py
index a7a347d99..a4182721d 100644
--- a/tests/llmcompressor/transformers/finetune/data/conftest.py
+++ b/tests/llmcompressor/transformers/finetune/data/conftest.py
@@ -1,7 +1,7 @@
 import pytest
 from transformers import AutoTokenizer

-from llmcompressor.transformers.finetune.model_args import ModelArguments
+from llmcompressor.transformers.utils.arg_parser import ModelArguments


 @pytest.fixture
diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
index 812b26a56..4b907b6a0 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
@@ -1,15 +1,15 @@
 import pytest

-from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
     get_raw_dataset,
     make_dataset_splits,
 )
+from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @pytest.mark.unit
 def test_combined_datasets():
-    data_args = DataTrainingArguments(
+    data_args = DatasetArguments(
         dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
     )
     raw_wikitext2 = get_raw_dataset(data_args)
@@ -33,7 +33,7 @@ def test_separate_datasets():
     splits = {"train": "train[:10%]", "validation": "train[10%:20%]"}
-    data_args = DataTrainingArguments(
+    data_args = DatasetArguments(
         dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
     )
     datasets = {}
diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py
index 9aee4c20f..11dc9034f 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_registry.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py
@@ -6,12 +6,12 @@
     TextGenerationDataset,
     WikiTextDataset,
 )
-from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
+from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @pytest.mark.usefixtures("tiny_llama_tokenizer")
 def test_c4_initializes(tiny_llama_tokenizer):
-    data_args = DataTrainingArguments(dataset="c4", concatenate_data=True)
+    data_args = DatasetArguments(dataset="c4", concatenate_data=True)
     c4_manager = TextGenerationDataset.load_from_registry(
         data_args.dataset,
         data_args=data_args,
@@ -27,7 +27,7 @@

 @pytest.mark.usefixtures("tiny_llama_tokenizer")
 def test_wikitext_initializes(tiny_llama_tokenizer):
-    data_args = DataTrainingArguments(
+    data_args = DatasetArguments(
         dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
     )
     wiki_manager = TextGenerationDataset.load_from_registry(
@@ -45,7 +45,7 @@

 @pytest.mark.usefixtures("tiny_llama_tokenizer")
 def test_open_platypus_initializes(tiny_llama_tokenizer):
-    data_args = DataTrainingArguments(dataset="open_platypus", pad_to_max_length=False)
+    data_args = DatasetArguments(dataset="open_platypus", pad_to_max_length=False)
     op_manager = TextGenerationDataset.load_from_registry(
         data_args.dataset,
         data_args=data_args,
diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py
index fe699570a..6a6fc9bf3 100644
--- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py
+++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py
@@ -23,12 +23,10 @@ def labeled_dataloader(self, dataset_name, model_name):
         from transformers import AutoTokenizer, DefaultDataCollator

         from llmcompressor.transformers.finetune.data import TextGenerationDataset
-        from llmcompressor.transformers.finetune.data.data_args import (
-            DataTrainingArguments,
-        )
+        from llmcompressor.transformers.utils.arg_parser import DatasetArguments

         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        data_args = DataTrainingArguments(
+        data_args = DatasetArguments(
             dataset=dataset_name,
             max_seq_length=512,
             pad_to_max_length=False,