diff --git a/src/llmcompressor/transformers/finetune/__init__.py b/src/llmcompressor/transformers/finetune/__init__.py
index aad70ae2c..6c75b902b 100644
--- a/src/llmcompressor/transformers/finetune/__init__.py
+++ b/src/llmcompressor/transformers/finetune/__init__.py
@@ -1,7 +1,5 @@
 # flake8: noqa
 
-from .data import DataTrainingArguments, TextGenerationDataset
-from .model_args import ModelArguments
+from .data import TextGenerationDataset
 from .session_mixin import SessionManagerMixIn
 from .text_generation import apply, compress, eval, oneshot, train
-from .training_args import TrainingArguments
diff --git a/src/llmcompressor/transformers/finetune/data/__init__.py b/src/llmcompressor/transformers/finetune/data/__init__.py
index ddf0b2364..a53caed1b 100644
--- a/src/llmcompressor/transformers/finetune/data/__init__.py
+++ b/src/llmcompressor/transformers/finetune/data/__init__.py
@@ -4,7 +4,6 @@
 from .c4 import C4Dataset
 from .cnn_dailymail import CNNDailyMailDataset
 from .custom import CustomDataset
-from .data_args import DataTrainingArguments
 from .evolcodealpaca import EvolCodeAlpacaDataset
 from .flickr_30k import Flickr30K
 from .gsm8k import GSM8KDataset
diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py
deleted file mode 100644
index 7d0bc14ce..000000000
--- a/src/llmcompressor/transformers/finetune/data/data_args.py
+++ /dev/null
@@ -1,189 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Union
-
-from transformers import DefaultDataCollator
-
-
-@dataclass
-class DVCDatasetTrainingArguments:
-    """
-    Arguments for training using DVC
-    """
-
-    dvc_data_repository: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to repository used for dvc_dataset_path"},
-    )
-
-
-@dataclass
-class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
-    """
-    Arguments for training using custom datasets
-    """
-
-    dataset_path: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": (
-                "Path to the custom dataset. Supports json, csv, dvc. "
-                "For DVC, the to dvc dataset to load, of format dvc://path. "
-                "For csv or json, the path containing the dataset. "
-            ),
-        },
-    )
-
-    text_column: str = field(
-        default="text",
-        metadata={
-            "help": (
-                "Optional key to be used as the `text` input to tokenizer/processor "
-                "after dataset preprocesssing"
-            )
-        },
-    )
-
-    remove_columns: Union[None, str, List] = field(
-        default=None,
-        metadata={"help": "Column names to remove after preprocessing (deprecated)"},
-    )
-
-    preprocessing_func: Union[None, str, Callable] = field(
-        default=None,
-        metadata={
-            "help": (
-                "Typically a function which applies a chat template. Can take the form "
-                "of either a function to apply to the dataset, a name defined in "
-                "src/llmcompressor/transformers/utils/preprocessing_functions.py, or "
-                "a path to a function definition of the form /path/to/file.py:func"
-            )
-        },
-    )
-
-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
-        metadata={"help": "The function to used to form a batch from the dataset"},
-    )
-
-
-@dataclass
-class DataTrainingArguments(CustomDataTrainingArguments):
-    """
-    Arguments pertaining to what data we are going to input our model for
-    training and eval
-
-    Using `HfArgumentParser` we can turn this class into argparse
-    arguments to be able to specify them on the command line
-    """
-
-    dataset: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": (
-                "The name of the dataset to use (via the datasets library). "
-                "Supports input as a string or DatasetDict from HF"
-            )
-        },
-    )
-    dataset_config_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": ("The configuration name of the dataset to use"),
-        },
-    )
-    max_seq_length: int = field(
-        default=384,
-        metadata={
-            "help": "The maximum total input sequence length after tokenization. "
-            "Sequences longer than this will be truncated, sequences shorter will "
-            "be padded."
-        },
-    )
-    concatenate_data: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether or not to concatenate datapoints to fill max_seq_length"
-        },
-    )
-    raw_kwargs: Dict = field(
-        default_factory=dict,
-        metadata={"help": "Additional keyboard args to pass to datasets load_data"},
-    )
-    splits: Union[None, str, List, Dict] = field(
-        default=None,
-        metadata={"help": "Optional percentages of each split to download"},
-    )
-    num_calibration_samples: Optional[int] = field(
-        default=512,
-        metadata={"help": "Number of samples to use for one-shot calibration"},
-    )
-    shuffle_calibration_samples: Optional[bool] = field(
-        default=True,
-        metadata={
-            "help": "whether to shuffle the dataset before selecting calibration data"
-        },
-    )
-    streaming: Optional[bool] = field(
-        default=False,
-        metadata={"help": "True to stream data from a cloud dataset"},
-    )
-    overwrite_cache: bool = field(
-        default=False,
-        metadata={"help": "Overwrite the cached preprocessed datasets or not."},
-    )
-    preprocessing_num_workers: Optional[int] = field(
-        default=None,
-        metadata={"help": "The number of processes to use for the preprocessing."},
-    )
-    pad_to_max_length: bool = field(
-        default=True,
-        metadata={
-            "help": "Whether to pad all samples to `max_seq_length`. If False, "
-            "will pad the samples dynamically when batching to the maximum length "
-            "in the batch (which can be faster on GPU but will be slower on TPU)."
-        },
-    )
-    max_train_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "For debugging purposes or quicker training, truncate the number "
-            "of training examples to this value if set."
-        },
-    )
-    max_eval_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "For debugging purposes or quicker training, truncate the number "
-            "of evaluation examples to this value if set."
-        },
-    )
-    max_predict_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": (
-                "For debugging purposes or quicker training, truncate the number of "
-                "prediction examples to this value if set."
-            ),
-        },
-    )
-    min_tokens_per_module: Optional[float] = field(
-        default=None,
-        metadata={
-            "help": (
-                "The minimum percentage of tokens (out of the total number) "
-                "that the module should 'receive' throughout the forward "
-                "pass of the calibration. If a module receives fewer tokens, "
-                "a warning will be logged. Defaults to 1/num_of_experts."
-                "note: this argument is only relevant for MoE models"
-            ),
-        },
-    )
-    trust_remote_code_data: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether or not to allow for datasets defined on the Hub using "
-            "a dataset script. This option should only be set to True for "
-            "repositories you trust and in which you have read the code, as it "
-            "will execute code present on the Hub on your local machine."
-        },
-    )
diff --git a/src/llmcompressor/transformers/finetune/model_args.py b/src/llmcompressor/transformers/finetune/model_args.py
deleted file mode 100644
index c81900ee2..000000000
--- a/src/llmcompressor/transformers/finetune/model_args.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Optional
-
-
-@dataclass
-class ModelArguments:
-    """
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from
-    """
-
-    model: str = field(
-        metadata={
-            "help": (
-                "A pretrained model or a string as a path to pretrained model, "
-                "HF stub, or model identifier from huggingface.co/models."
-            )
-        },
-    )
-    distill_teacher: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Teacher model (a trained text generation model)",
-        },
-    )
-    config_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Pretrained config name or path if not the same as model_name"
-        },
-    )
-    tokenizer: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Pretrained tokenizer name or path if not the same as model_name"
-        },
-    )
-    processor: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Pretrained processor name or path if not the same as model_name"
-        },
-    )
-    cache_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Where to store the pretrained data from huggingface.co"},
-    )
-    use_fast_tokenizer: bool = field(
-        default=True,
-        metadata={"help": "Whether to use one of the fast tokenizers. Default True"},
-    )
-    model_revision: str = field(
-        default="main",
-        metadata={
-            "help": "The specific model version to use "
-            "(can be a branch name, tag name or commit id)"
-        },
-    )
-    use_auth_token: bool = field(
-        default=False,
-        metadata={
-            "help": "Will use token generated when running `transformers-cli login` "
-            "(necessary to use this script with private models)"
-        },
-    )
-    precision: str = field(
-        default="auto",
-        metadata={"help": "Precision to cast model weights to, default to auto"},
-    )
-
-    tie_word_embeddings: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether the model's input and output word embeddings "
-            "should be tied. Note that this is only relevant if the "
-            "model has a output word embedding layer."
-        },
-    )
-    trust_remote_code_model: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether or not to allow for custom models to execute their "
-            "own modeling files. This option should only be set to True for "
-            "repositories you trust and in which you have read the code"
-        },
-    )
diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py
index c1aec5164..137ba53dd 100644
--- a/src/llmcompressor/transformers/finetune/runner.py
+++ b/src/llmcompressor/transformers/finetune/runner.py
@@ -29,7 +29,6 @@
 from llmcompressor.transformers.utils.arg_parser.training_arguments import (
     DEFAULT_OUTPUT_DIR,
 )
-from llmcompressor.transformers.utils.arg_parser.utils import get_dataclass_as_dict
 from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe
 
@@ -260,19 +259,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
 
             # run stage
             if run_type is StageRunType.ONESHOT:
-                from llmcompressor.transformers.calibration import Oneshot
-
-                model = get_session_model()
-                self._model_args.model = model
-
-                oneshot = Oneshot(
-                    output_dir=self._training_args.output_dir,
-                    **get_dataclass_as_dict(self._model_args, ModelArguments),
-                    **get_dataclass_as_dict(self._data_args, DatasetArguments),
-                    **get_dataclass_as_dict(self._recipe_args, RecipeArguments),
-                )
-
-                oneshot.run(stage_name=stage_name)
+                self.one_shot(stage=stage_name)
             elif run_type is StageRunType.TRAIN:
                 self.train(checkpoint=checkpoint, stage=stage_name)
                 checkpoint = None
diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py
index 07b9ba1ef..56eabefcb 100644
--- a/src/llmcompressor/transformers/finetune/session_mixin.py
+++ b/src/llmcompressor/transformers/finetune/session_mixin.py
@@ -7,12 +7,13 @@
 import torch
 from loguru import logger
 from torch.nn import Module
-from torch.utils.data import IterableDataset
+from torch.utils.data import DataLoader, IterableDataset
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import get_last_checkpoint
 
 from llmcompressor.core import (
     active_session,
+    apply,
     callbacks,
     create_session,
     finalize,
@@ -431,6 +432,30 @@ def predict(self, *args, **kwargs):
 
         return output
 
+    def one_shot(
+        self, calibration_data: Optional[DataLoader] = None, stage: Optional[str] = None
+    ):
+        """
+        Run oneshot calibration on the active model
+        :param stage: which stage of the recipe to run, or None to run whole recipe
+        :param calib_data: dataloader of calibration data
+        """
+        apply(
+            recipe=self.recipe,
+            recipe_stage=stage,
+            recipe_args=self.recipe_args,
+            model=self.model,
+            calib_data=calibration_data,
+            start=-1,
+            copy_data=False,
+            accelerator=self.accelerator,
+            min_tokens_per_module=self.min_tokens_per_module,
+        )
+
+        # log model sparsity
+        # self.maybe_log_model_sparsification()
+        self.accelerator.wait_for_everyone()
+
     def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
         """
         Override of the save_model function and expects it to exist in the parent.
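The restored `SessionManagerMixIn.one_shot()` above is what `StageRunner.run_sequential_stages()` now delegates to for ONESHOT stages. A minimal sketch of exercising it directly is shown below; `run_oneshot_stage`, `trainer`, and `calib_loader` are illustrative names that are not part of this patch, and the sketch assumes `trainer` is a Trainer instance that includes the mixin.

```python
from typing import Optional

from torch.utils.data import DataLoader


def run_oneshot_stage(trainer, calib_loader: DataLoader, stage_name: Optional[str] = None):
    """Illustrative helper, not part of the patch.

    Assumes `trainer` mixes in SessionManagerMixIn and therefore exposes the
    one_shot() method re-added above. Passing stage_name=None runs the whole
    recipe, mirroring StageRunner's `self.one_shot(stage=stage_name)` call.
    """
    trainer.one_shot(calibration_data=calib_loader, stage=stage_name)
```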
diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index 6c71610a9..922e556ee 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -49,7 +49,7 @@
     patch_tied_tensors_bug,
 )
 from llmcompressor.transformers.sparsification.sparse_model import (
-    get_processor_from_model,
+    get_shared_processor_src,
 )
 from llmcompressor.transformers.utils.arg_parser import (
     DEFAULT_OUTPUT_DIR,
@@ -89,14 +89,15 @@ def eval(**kwargs):
 
 
 def oneshot(**kwargs):
-    from llmcompressor.transformers.calibration.oneshot import Oneshot
-
     """
     CLI entrypoint for running oneshot calibration
     """
-    oneshot = Oneshot(**kwargs)
-    oneshot.run()
-    return oneshot
+    # TODO: Get rid of training args when Oneshot refactor comes in
+    model_args, data_args, recipe_args, training_args, _ = parse_args(
+        include_training_args=True, **kwargs
+    )
+    training_args.do_oneshot = True
+    main(model_args, data_args, recipe_args, training_args)
 
 
 # alias
@@ -161,21 +162,6 @@ def parse_args(include_training_args: bool = False, **kwargs):
     parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments))
 
     if not kwargs:
-        # if output_dir passed from cli, pop to avoid using training_args
-        def _get_output_dir_from_argv() -> Optional[str]:
-            import sys
-
-            output_dir = None
-            if "--output_dir" in sys.argv:
-                index = sys.argv.index("--output_dir")
-                sys.argv.pop(index)
-                if index < len(sys.argv):  # Check if value exists afer the flag
-                    output_dir = sys.argv.pop(index)
-
-            return output_dir
-
-        output_dir = _get_output_dir_from_argv() or output_dir
-
         parsed_args = parser.parse_args_into_dataclasses()
     else:
         parsed_args = parser.parse_dict(kwargs)
@@ -217,8 +203,9 @@ def _get_output_dir_from_argv() -> Optional[str]:
 
 def initialize_model_from_path(
     model_args: ModelArguments,
-    training_args: Optional[TrainingArguments] = None,
+    training_args: TrainingArguments,
 ):
+    last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
     # Load pretrained model
     # The .from_pretrained methods guarantee that only one local process can
     # concurrently download model & vocab.
@@ -231,23 +218,16 @@ def initialize_model_from_path(
         tie_word_embeddings=model_args.tie_word_embeddings,
         trust_remote_code=model_args.trust_remote_code_model,
     )
-
-    last_checkpoint = None
-
-    if training_args is not None:
-        teacher_config = (
-            AutoConfig.from_pretrained(
-                model_args.distill_teacher,
-                use_auth_token=True if model_args.use_auth_token else None,
-                tie_word_embeddings=model_args.tie_word_embeddings,
-                trust_remote_code=model_args.trust_remote_code_model,
-            )
-            if model_args.distill_teacher
-            else None
+    teacher_config = (
+        AutoConfig.from_pretrained(
+            model_args.distill_teacher,
+            use_auth_token=True if model_args.use_auth_token else None,
+            tie_word_embeddings=model_args.tie_word_embeddings,
+            trust_remote_code=model_args.trust_remote_code_model,
         )
-        last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
-        # Set seed before initializing model.
-        set_seed(training_args.seed)
+        if model_args.distill_teacher
+        else None
+    )
 
     model_path = (
         last_checkpoint or model_args.model
@@ -255,18 +235,21 @@ def initialize_model_from_path(
         else model_args.model_name_or_path
     )
 
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
     # Fallback to CPU if GPU requested and not available
-    model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
+    training_args.oneshot_device = fallback_to_cpu(training_args.oneshot_device)
 
     # Trainer handles device assignment for FSDP and training, don't do mapping here
     # if running oneshot outside of FSDP, apply user device settings
-
+    device_map = None
     fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
-
-    device_map = model_args.oneshot_device
-    if not fsdp_enabled and training_args is not None and training_args.do_train:
+    if not fsdp_enabled and training_args.do_oneshot:
+        device_map = training_args.oneshot_device
+        logger.warning(f"Moving {model_path} to device {device_map} for One-Shot")
+    elif not fsdp_enabled:
         device_map = "auto"
-
     model_kwargs = {
         "config": config,
         "cache_dir": model_args.cache_dir,
@@ -276,7 +259,15 @@
         "use_auth_token": True if model_args.use_auth_token else None,
         "torch_dtype": parse_dtype(model_args.precision),
         "device_map": device_map,
         "trust_remote_code": model_args.trust_remote_code_model,
     }
-
+    teacher_device_map = None if fsdp_enabled else "auto"
+    teacher_kwargs = {
+        "config": teacher_config,
+        "cache_dir": model_args.cache_dir,
+        "use_auth_token": True if model_args.use_auth_token else None,
+        "torch_dtype": parse_dtype(model_args.precision),
+        "device_map": teacher_device_map,
+        "trust_remote_code": model_args.trust_remote_code_model,
+    }
     # this calls from_pretrained under the hood so should be FSDP safe
 
     # optimized models must be decompressed to carry out oneshot/train/etc
@@ -292,30 +283,18 @@
     if "sequence_length" in model_kwargs:
         model.seqlen = model_kwargs["sequence_length"]
 
-    teacher = None
-    if training_args is not None:
-        teacher_device_map = None if fsdp_enabled else "auto"
-        teacher_kwargs = {
-            "config": teacher_config,
-            "cache_dir": model_args.cache_dir,
-            "use_auth_token": True if model_args.use_auth_token else None,
-            "torch_dtype": parse_dtype(model_args.precision),
-            "device_map": teacher_device_map,
-            "trust_remote_code": model_args.trust_remote_code_model,
-        }
-
-        teacher = (
-            AutoModelForCausalLM.from_pretrained(
-                model_args.distill_teacher,
-                **teacher_kwargs,
-            )
-            if model_args.distill_teacher is not None
-            else None
+    teacher = (
+        AutoModelForCausalLM.from_pretrained(
+            model_args.distill_teacher,
+            **teacher_kwargs,
         )
-        if teacher is not None and "sequence_length" in teacher_kwargs:
-            teacher.seqlen = teacher_kwargs["sequence_length"]
+        if model_args.distill_teacher is not None
+        else None
+    )
+    if teacher is not None and "sequence_length" in teacher_kwargs:
+        teacher.seqlen = teacher_kwargs["sequence_length"]
 
-    return model, teacher
+    return teacher, model_path, model
 
 
 def initialize_processor_from_path(
@@ -323,7 +302,8 @@
     model: PreTrainedModel,
     teacher: Optional[PreTrainedModel] = None,
 ) -> Processor:
-    processor_src = model_args.processor or get_processor_from_model(model, teacher)
+    processor_src = model_args.processor
+    processor_src = processor_src or get_shared_processor_src(model, teacher)
     # The use_fast=True option is not currently supported safely in Transformers
     # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727  # noqa: E501
     try:
@@ -412,11 +392,10 @@ def main(
 
     model = model_args.model
     if isinstance(model, str) or isinstance(model, PosixPath):
-        (model, teacher) = initialize_model_from_path(
+        (teacher, _model_path, model) = initialize_model_from_path(
             model_args,
             training_args,
         )
-
         # patch a shared tensor bug in HF transformers
         # https://github.com/huggingface/transformers/issues/33689
         patch_tied_tensors_bug(model)
diff --git a/src/llmcompressor/transformers/finetune/training_args.py b/src/llmcompressor/transformers/finetune/training_args.py
deleted file mode 100644
index c04fa2807..000000000
--- a/src/llmcompressor/transformers/finetune/training_args.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from dataclasses import dataclass, field
-from typing import List, Optional
-
-from transformers import TrainingArguments as HFTrainingArgs
-
-__all__ = ["TrainingArguments"]
-
-
-@dataclass
-class TrainingArguments(HFTrainingArgs):
-    """
-    Training arguments specific to LLM Compressor Transformers workflow
-
-    :param best_model_after_epoch (`int`, *optional*, defaults to None):
-        The epoch after which best model will be saved; used in conjunction
-        with `load_best_model_at_end` and `metric_for_best_model` training
-        arguments
-    """
-
-    recipe: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Path to a LLM Compressor sparsification recipe",
-        },
-    )
-    recipe_args: Optional[List[str]] = field(
-        default=None,
-        metadata={
-            "help": (
-                "List of recipe arguments to evaluate, of the format key1=value1 "
-                "key2=value2"
-            )
-        },
-    )
-    save_compressed: Optional[bool] = field(
-        default=True,
-        metadata={"help": "Whether to compress sparse models during save"},
-    )
-    do_oneshot: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether to run one-shot calibration"},
-    )
-    run_stages: Optional[bool] = field(
-        default=False, metadata={"help": "Whether to trigger recipe stage by stage"}
-    )
-    oneshot_device: Optional[str] = field(
-        default="cuda:0",
-        metadata={"help": "Device to run oneshot calibration on"},
-    )
-    clear_sparse_session: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether to clear CompressionSession data between runs."},
-    )
-    save_safetensors: Optional[bool] = field(
-        default=True,
-        metadata={
-            "help": "Use safetensors saving and loading for state dicts instead of "
-            "default torch.load and torch.save."
-        },
-    )
-    output_dir: str = field(
-        default="./output",
-        metadata={
-            "help": "The output directory where the model predictions and "
-            "checkpoints will be written."
-        },
-    )
-
-    @property
-    def place_model_on_device(self):
-        return False
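With `oneshot()` routed back through `parse_args(include_training_args=True, ...)` and `main()`, the entrypoint is again driven by the argument dataclasses referenced in this patch. A rough usage sketch under those assumptions is shown below; the import path follows the re-export in `finetune/__init__.py` above, the model stub, dataset name, recipe path, and output directory are placeholders, and the accepted keyword names are simply fields of ModelArguments, DatasetArguments, RecipeArguments, and TrainingArguments.

```python
from llmcompressor.transformers import oneshot

# Placeholder values; any field accepted by the argument dataclasses can be
# passed as a keyword (e.g. num_calibration_samples from the dataset args,
# output_dir and oneshot_device from the training args).
oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    dataset="open_platypus",
    recipe="recipe.yaml",
    num_calibration_samples=512,
    output_dir="./oneshot_output",
)
```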