[Oneshot Refactor] dataclass Arguments #1103

Merged (39 commits) on Feb 11, 2025
5cdae11
Dataclass Arg refactor -- recipe_args
horheynm Jan 27, 2025
a7cf946
refactor dataclass args
horheynm Jan 27, 2025
6859adc
apply changes only for dataclass args - recipe, model, dataset, training
horheynm Jan 28, 2025
7d19625
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Jan 28, 2025
6a1e4b0
fix tests
horheynm Jan 28, 2025
44c67d7
Merge branch 'oneshot-refac-recipe_args' of github.com:vllm-project/l…
horheynm Jan 28, 2025
7bb2e9a
pass cli tests
horheynm Jan 28, 2025
0adb755
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Jan 29, 2025
12a2167
remove redundant code
horheynm Jan 29, 2025
cacfbed
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Jan 31, 2025
77205c0
comments
horheynm Jan 31, 2025
377a10b
Merge branch 'oneshot-refac-recipe_args' of github.com:vllm-project/l…
horheynm Jan 31, 2025
84ddf07
add type annotations to private func
horheynm Jan 31, 2025
3770e6c
fix tests
horheynm Feb 3, 2025
1f3110b
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 4, 2025
dbf7f8c
move to util
horheynm Feb 4, 2025
656824e
fix tests
horheynm Feb 5, 2025
48e382f
remove redudant code
horheynm Feb 5, 2025
be31960
examples TrainingArguments movement
horheynm Feb 5, 2025
1253435
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 5, 2025
f4491be
revert to only refactor wrt to dataclass
horheynm Feb 6, 2025
13ee157
remove unnec code
horheynm Feb 6, 2025
48f531f
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 6, 2025
ce42137
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 6, 2025
2fb8212
change optional to required in session_mixin
horheynm Feb 6, 2025
ef8fae0
Merge branch 'oneshot-refac-recipe_args' of github.com:vllm-project/l…
horheynm Feb 6, 2025
f0fc214
fix
horheynm Feb 6, 2025
6e4c8cc
fix test
horheynm Feb 7, 2025
bcbfd35
comments
horheynm Feb 10, 2025
049ddec
add
horheynm Feb 10, 2025
e671817
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 10, 2025
68dd3b4
consistency
horheynm Feb 10, 2025
63862c3
Update src/llmcompressor/arg_parser/README.md
horheynm Feb 10, 2025
afb8efa
comment
horheynm Feb 10, 2025
d50baba
comments
horheynm Feb 10, 2025
9e26589
rename data_arguments to dataset_arguments
horheynm Feb 10, 2025
7f49448
comments
horheynm Feb 10, 2025
a55a427
change directory name
horheynm Feb 10, 2025
319d1bd
fix
horheynm Feb 11, 2025
2 changes: 1 addition & 1 deletion examples/trl_mixin/ex_trl_constant.py
@@ -3,7 +3,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

from llmcompressor.transformers import TrainingArguments
from llmcompressor.args import TrainingArguments

model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"
9 changes: 3 additions & 6 deletions examples/trl_mixin/ex_trl_distillation.py
@@ -1,11 +1,8 @@
from sft_trainer import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

from llmcompressor.transformers import (
DataTrainingArguments,
TextGenerationDataset,
TrainingArguments,
)
from llmcompressor.args import DatasetArguments, TrainingArguments
from llmcompressor.transformers import TextGenerationDataset

model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
teacher_path = "neuralmagic/Llama-2-7b-gsm8k"
@@ -21,7 +18,7 @@
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load gsm8k using SparseML dataset tools
data_args = DataTrainingArguments(
data_args = DatasetArguments(
dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)
dataset_manager = TextGenerationDataset.load_from_registry(
2 changes: 1 addition & 1 deletion examples/trl_mixin/sft_trainer.py
@@ -1,7 +1,7 @@
from trl import SFTConfig as TRLSFTConfig
from trl import SFTTrainer as TRLSFTTrainer

from llmcompressor.transformers import TrainingArguments
from llmcompressor.args import TrainingArguments
from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn

__all__ = ["SFTTrainer"]
45 changes: 45 additions & 0 deletions src/llmcompressor/args/README.md
@@ -0,0 +1,45 @@
# Input arguments for `oneshot`, `train`, `eval` entrypoints

Parsers in `llm-compressor` define the input arguments required for various entry points, including `oneshot`, `train`, and `eval`.

Each entry point (e.g., oneshot) carries out its logic based on the provided input arguments, `model`, `recipe`, and `dataset`.

```python
from llmcompressor.transformers import oneshot

model = ...
recipe = ...
dataset = ...
oneshot(model=model, recipe=recipe, dataset=dataset)
```

In addition, users can further control execution by providing additional arguments. For example, to save the optimized model after completion, specify the `output_dir` parameter:

```python
oneshot(
...,
output_dir=...,
)
```

These input arguments are passed through the entry point's function signature and parsed using Hugging Face's [argument parser](https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py). The parsers define the set of acceptable inputs; therefore, any argument passed in must be defined in one of them. A minimal parsing sketch follows the list below.

`llm-compressor` uses four argument dataclasses, located in `src/llmcompressor/args`:
* ModelArguments
* DatasetArguments
* RecipeArguments
* TrainingArguments
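
The snippet below is a minimal sketch of how a flat set of keyword arguments can be routed into these dataclasses with `HfArgumentParser`; the exact wiring inside the entry points may differ, and the argument values are illustrative only.

```python
from transformers import HfArgumentParser

from llmcompressor.args import (
    DatasetArguments,
    ModelArguments,
    RecipeArguments,
    TrainingArguments,
)

# One parser over all four dataclasses; each key in the dict is matched
# to the dataclass that defines a field with that name.
parser = HfArgumentParser(
    (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments)
)
model_args, dataset_args, recipe_args, training_args = parser.parse_dict(
    {
        "model": "neuralmagic/Llama-2-7b-pruned50-retrained",  # ModelArguments
        "dataset": "gsm8k",                                    # DatasetArguments
        "recipe": "recipe.yaml",                               # RecipeArguments
        "output_dir": "./output",                              # TrainingArguments
    }
)
```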


## ModelArguments
Handles model loading and saving. For example, `ModelArguments.model` can be a Hugging Face model identifier or an instance of `AutoModelForCausalLM`. The `save_compressed` flag is a boolean that determines whether the model is saved in compressed safetensors format to minimize disk usage.
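
A minimal sketch, assuming the fields shown in `model_arguments.py`; the model stub is taken from the examples in this PR and the remaining values are illustrative:

```python
from llmcompressor.args import ModelArguments

model_args = ModelArguments(
    model="neuralmagic/Llama-2-7b-pruned50-retrained",  # HF stub or AutoModelForCausalLM instance
    save_compressed=True,      # save in compressed safetensors format
    oneshot_device="cuda:0",   # device used for oneshot calibration
)
```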

## DatasetArguments
Manages dataset loading and preprocessing. The `dataset` argument can specify a Hugging Face dataset stub or a local dataset compatible with [`load_dataset`](https://github.com/huggingface/datasets/blob/3a4e74a9ace62ecd5c9cde7dcb6bcabd65cc7857/src/datasets/load.py#L1905). The `preprocessing_func` is a callable that applies custom logic, such as formatting the data using a chat template.
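
A minimal sketch of a custom preprocessing function; the per-sample signature of `preprocessing_func` is an assumption here, and the field values mirror the GSM8K example used elsewhere in this PR:

```python
from llmcompressor.args import DatasetArguments

def format_sample(sample):
    # Hypothetical formatting: collapse a GSM8K row into a single text field.
    return {"text": f"Question: {sample['question']}\nAnswer: {sample['answer']}"}

dataset_args = DatasetArguments(
    dataset="gsm8k",                   # Hugging Face dataset stub or local dataset
    dataset_config_name="main",
    max_seq_length=512,
    preprocessing_func=format_sample,  # assumed to be applied once per sample
)
```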

## RecipeArguments
Defines the model recipe. A `recipe` consists of user-defined instructions for optimizing the model. Examples of recipes can be found in the `/examples` directory.
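
A minimal sketch; the recipe path and the `key=value` overrides are placeholders, following the `recipe_args` format defined in `recipe_arguments.py`:

```python
from llmcompressor.args import RecipeArguments

recipe_args = RecipeArguments(
    recipe="path/to/recipe.yaml",             # sparsification recipe
    recipe_args=["sparsity=0.5", "start=0"],  # key1=value1 key2=value2 overrides
    clear_sparse_session=False,               # keep session data between runs
)
```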

## TrainingArguments
Specifies training parameters based on Hugging Face's [TrainingArguments class](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py). These include settings such as the learning rate (`learning_rate`) and the optimizer to use (`optim`).
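
A minimal sketch combining fields added by LLM Compressor (`do_oneshot`, the `./output` default for `output_dir`) with inherited Hugging Face parameters; values are illustrative:

```python
from llmcompressor.args import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",   # where checkpoints and predictions are written
    learning_rate=2e-5,      # inherited from Hugging Face TrainingArguments
    optim="adamw_torch",     # optimizer to use
    do_oneshot=False,        # set True to run one-shot calibration in stages
)
```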

6 changes: 6 additions & 0 deletions src/llmcompressor/args/__init__.py
@@ -0,0 +1,6 @@
# flake8: noqa

from .dataset_arguments import DatasetArguments
from .model_arguments import ModelArguments
from .recipe_arguments import RecipeArguments
from .training_arguments import TrainingArguments
@@ -5,7 +5,7 @@


@dataclass
class DVCDatasetTrainingArguments:
class DVCDatasetArguments:
"""
Arguments for training using DVC
"""
@@ -17,7 +17,7 @@ class DVCDatasetTrainingArguments:


@dataclass
class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
class CustomDatasetArguments(DVCDatasetArguments):
"""
Arguments for training using custom datasets
"""
@@ -67,10 +67,10 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):


@dataclass
class DataTrainingArguments(CustomDataTrainingArguments):
class DatasetArguments(CustomDatasetArguments):
"""
Arguments pertaining to what data we are going to input our model for
training and eval
calibration, training or eval
Using `HfArgumentParser` we can turn this class into argparse
arguments to be able to specify them on the command line
@@ -5,7 +5,9 @@
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from
Model variables used for oneshot calibration, finetuning, and
stage runners (sequential runs of oneshot and finetune).

"""

model: str = field(
@@ -44,17 +46,7 @@ class ModelArguments:
default=None,
metadata={"help": "Where to store the pretrained data from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizers. Default True"},
)
model_revision: str = field(
default="main",
metadata={
"help": "The specific model version to use "
"(can be a branch name, tag name or commit id)"
},
)

use_auth_token: bool = field(
default=False,
metadata={
@@ -83,3 +75,18 @@ class ModelArguments:
"repositories you trust and in which you have read the code"
},
)
save_compressed: Optional[bool] = field(
default=True,
metadata={"help": "Whether to compress sparse models during save"},
)
oneshot_device: Optional[str] = field(
default="cuda:0",
metadata={"help": "Device to run oneshot calibration on"},
)
model_revision: str = field(
default="main",
metadata={
"help": "The specific model version to use "
"(can be a branch name, tag name or commit id)"
},
)
32 changes: 32 additions & 0 deletions src/llmcompressor/args/recipe_arguments.py
@@ -0,0 +1,32 @@
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class RecipeArguments:
"""Recipe and session variables"""

recipe: Optional[str] = field(
default=None,
metadata={
"help": "Path to a LLM Compressor sparsification recipe",
},
)
recipe_args: Optional[List[str]] = field(
default=None,
metadata={
"help": (
"List of recipe arguments to evaluate, of the format key1=value1 "
"key2=value2"
)
},
)
clear_sparse_session: Optional[bool] = field(
default=False,
metadata={
"help": (
"Whether to clear CompressionSession/CompressionLifecycle ",
"data between runs.",
)
},
)
32 changes: 32 additions & 0 deletions src/llmcompressor/args/training_arguments.py
@@ -0,0 +1,32 @@
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments as HFTrainingArgs

__all__ = [
"TrainingArguments",
]


@dataclass
class TrainingArguments(HFTrainingArgs):
"""
Training arguments specific to the LLM Compressor Transformers workflow, using
HFTrainingArgs as the base class
"""

do_oneshot: Optional[bool] = field(
default=False,
metadata={"help": "Whether to run one-shot calibration in stages"},
)
run_stages: Optional[bool] = field(
default=False, metadata={"help": "Whether to trigger recipe stage by stage"}
)
output_dir: str = field(
default="./output",
metadata={
"help": "The output directory where the model predictions and "
"checkpoints will be written."
},
)
7 changes: 4 additions & 3 deletions src/llmcompressor/transformers/finetune/README.md
@@ -74,9 +74,10 @@ train(

Finetuning arguments are split into four groups:

* ModelArguments: `src/llmcompressor/transformers/finetune/model_args.py`
* TrainingArguments: `src/llmcompressor/transformers/finetune/training_args.py`
* DataTrainingArguments: `src/llmcompressor/transformers/finetune/data/data_training_args.py`
* ModelArguments: `src/llmcompressor/args/model_arguments.py`
* TrainingArguments: `src/llmcompressor/args/training_arguments.py`
* DatasetArguments: `src/llmcompressor/args/dataset_arguments.py`
* RecipeArguments: `src/llmcompressor/args/recipe_arguments.py`


## Running One-Shot with FSDP
4 changes: 1 addition & 3 deletions src/llmcompressor/transformers/finetune/__init__.py
@@ -1,7 +1,5 @@
# flake8: noqa

from .data import DataTrainingArguments, TextGenerationDataset
from .model_args import ModelArguments
from .data import TextGenerationDataset
from .session_mixin import SessionManagerMixIn
from .text_generation import apply, compress, eval, oneshot, train
from .training_args import TrainingArguments
1 change: 0 additions & 1 deletion src/llmcompressor/transformers/finetune/data/__init__.py
@@ -4,7 +4,6 @@
from .c4 import C4Dataset
from .cnn_dailymail import CNNDailyMailDataset
from .custom import CustomDataset
from .data_args import DataTrainingArguments
from .evolcodealpaca import EvolCodeAlpacaDataset
from .flickr_30k import Flickr30K
from .gsm8k import GSM8KDataset
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -8,7 +8,7 @@
from datasets.formatting.formatting import LazyRow
from loguru import logger

from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.args import DatasetArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
LABELS_MASK_VALUE,
get_custom_datasets_from_path,
@@ -41,7 +41,7 @@ class TextGenerationDataset(RegistryMixin):

def __init__(
self,
data_args: DataTrainingArguments,
data_args: DatasetArguments,
split: str,
processor: Processor,
):
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/c4.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.args import DatasetArguments


@TextGenerationDataset.register(name="c4")
@@ -18,7 +18,7 @@ class C4Dataset(TextGenerationDataset):
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "allenai/c4"
data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.args import DatasetArguments


@TextGenerationDataset.register(name="cnn_dailymail")
@@ -20,7 +20,7 @@ class CNNDailyMailDataset(TextGenerationDataset):

SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "cnn_dailymail"
data_args.dataset_config_name = "3.0.0"
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.args import DatasetArguments


@TextGenerationDataset.register(name="evolcodealpaca")
@@ -25,7 +25,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
"\n\n### Response:\n"
)

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "theblackcat102/evol-codealpaca-v1"
data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/flickr_30k.py
@@ -7,7 +7,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.args import DatasetArguments


@TextGenerationDataset.register(name="flickr", alias="flickr30k")
@@ -31,7 +31,7 @@ class Flickr30K(TextGenerationDataset):
"{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
)

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "lmms-lab/flickr30k"

4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/gsm8k.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.args import DatasetArguments


@TextGenerationDataset.register(name="gsm8k")
@@ -20,7 +20,7 @@ class GSM8KDataset(TextGenerationDataset):

GSM_TEMPLATE = "Question: {question}\nAnswer:"

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "gsm8k"
data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/open_platypus.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.args import DatasetArguments


@TextGenerationDataset.register(name="open_platypus")
@@ -28,7 +28,7 @@ class OpenPlatypusDataset(TextGenerationDataset):
"instruction}\n\n### Response:\n",
}

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "garage-bAInd/Open-Platypus"
data_args.text_column = "text"