[Oneshot Refactor] dataclass Arguments #1103

Merged
merged 39 commits into main from oneshot-refac-recipe_args on Feb 11, 2025
Changes from 20 commits
Commits
39 commits
5cdae11
Dataclass Arg refactor -- recipe_args
horheynm Jan 27, 2025
a7cf946
refactor dataclass args
horheynm Jan 27, 2025
6859adc
apply changes only for dataclass args - recipe, model, dataset, training
horheynm Jan 28, 2025
7d19625
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Jan 28, 2025
6a1e4b0
fix tests
horheynm Jan 28, 2025
44c67d7
Merge branch 'oneshot-refac-recipe_args' of github.com:vllm-project/l…
horheynm Jan 28, 2025
7bb2e9a
pass cli tests
horheynm Jan 28, 2025
0adb755
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Jan 29, 2025
12a2167
remove redundant code
horheynm Jan 29, 2025
cacfbed
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Jan 31, 2025
77205c0
comments
horheynm Jan 31, 2025
377a10b
Merge branch 'oneshot-refac-recipe_args' of github.com:vllm-project/l…
horheynm Jan 31, 2025
84ddf07
add type annotations to private func
horheynm Jan 31, 2025
3770e6c
fix tests
horheynm Feb 3, 2025
1f3110b
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 4, 2025
dbf7f8c
move to util
horheynm Feb 4, 2025
656824e
fix tests
horheynm Feb 5, 2025
48e382f
remove redudant code
horheynm Feb 5, 2025
be31960
examples TrainingArguments movement
horheynm Feb 5, 2025
1253435
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 5, 2025
f4491be
revert to only refactor wrt to dataclass
horheynm Feb 6, 2025
13ee157
remove unnec code
horheynm Feb 6, 2025
48f531f
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 6, 2025
ce42137
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 6, 2025
2fb8212
change optional to required in session_mixin
horheynm Feb 6, 2025
ef8fae0
Merge branch 'oneshot-refac-recipe_args' of github.com:vllm-project/l…
horheynm Feb 6, 2025
f0fc214
fix
horheynm Feb 6, 2025
6e4c8cc
fix test
horheynm Feb 7, 2025
bcbfd35
comments
horheynm Feb 10, 2025
049ddec
add
horheynm Feb 10, 2025
e671817
Merge branch 'main' into oneshot-refac-recipe_args
horheynm Feb 10, 2025
68dd3b4
consistency
horheynm Feb 10, 2025
63862c3
Update src/llmcompressor/arg_parser/README.md
horheynm Feb 10, 2025
afb8efa
comment
horheynm Feb 10, 2025
d50baba
comments
horheynm Feb 10, 2025
9e26589
rename data_arguments to dataset_arguments
horheynm Feb 10, 2025
7f49448
comments
horheynm Feb 10, 2025
a55a427
change directory name
horheynm Feb 10, 2025
319d1bd
fix
horheynm Feb 11, 2025
2 changes: 1 addition & 1 deletion examples/trl_mixin/ex_trl_constant.py
@@ -3,7 +3,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

from llmcompressor.transformers import TrainingArguments
from llmcompressor.transformers.utils.arg_parser import TrainingArguments

model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"
8 changes: 4 additions & 4 deletions examples/trl_mixin/ex_trl_distillation.py
@@ -1,9 +1,9 @@
from sft_trainer import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

from llmcompressor.transformers import (
DataTrainingArguments,
TextGenerationDataset,
from llmcompressor.transformers import TextGenerationDataset
from llmcompressor.transformers.utils.arg_parser import (
DatasetArguments,
TrainingArguments,
)

@@ -21,7 +21,7 @@
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load gsm8k using SparseML dataset tools
data_args = DataTrainingArguments(
data_args = DatasetArguments(
dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)
dataset_manager = TextGenerationDataset.load_from_registry(
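The hunk above cuts off at the `TextGenerationDataset.load_from_registry(` call. As a rough sketch of the surrounding usage under the new names — the keyword names passed to the registry call are assumptions, while the `(data_args, split, processor)` shape matches the dataset constructors changed later in this PR and the model path is only a placeholder:

```python
# Sketch only: load_from_registry keyword names are assumptions; the
# DatasetArguments fields mirror the values shown in the hunk above.
from transformers import AutoTokenizer

from llmcompressor.transformers import TextGenerationDataset
from llmcompressor.transformers.utils.arg_parser import DatasetArguments

model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_path)

data_args = DatasetArguments(
    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)
dataset_manager = TextGenerationDataset.load_from_registry(
    data_args.dataset,   # registered dataset name, e.g. "gsm8k"
    data_args=data_args,
    split="train",
    processor=tokenizer,
)
```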
2 changes: 1 addition & 1 deletion examples/trl_mixin/sft_trainer.py
@@ -1,7 +1,7 @@
from trl import SFTConfig as TRLSFTConfig
from trl import SFTTrainer as TRLSFTTrainer

from llmcompressor.transformers import TrainingArguments
from llmcompressor.transformers.utils.arg_parser import TrainingArguments
from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn

__all__ = ["SFTTrainer"]
7 changes: 4 additions & 3 deletions src/llmcompressor/transformers/finetune/README.md
@@ -74,9 +74,10 @@ train(

Finetuning arguments are split into four groups:

* ModelArguments: `src/llmcompressor/transformers/finetune/model_args.py`
* TrainingArguments: `src/llmcompressor/transformers/finetune/training_args.py`
* DataTrainingArguments: `src/llmcompressor/transformers/finetune/data/data_training_args.py`
* ModelArguments: `src/llmcompressor/transformers/utils/arg_parser/model_arguments.py`
* TrainingArguments: `src/llmcompressor/transformers/utils/arg_parser/training_arguments.py`
* DatasetArguments: `src/llmcompressor/transformers/utils/arg_parser/data_arguments.py`
* RecipeArguments: `src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py`
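
As a rough illustration of how these four groups might be constructed together — the field names shown here all appear elsewhere in this PR, the concrete values are placeholders, and this sketch is editorial rather than part of the refactored README itself:

```python
# Illustrative sketch: fields such as model, save_compressed, dataset, recipe,
# and output_dir appear in this PR; the values are placeholders.
from llmcompressor.transformers.utils.arg_parser import (
    DatasetArguments,
    ModelArguments,
    RecipeArguments,
    TrainingArguments,
)

model_args = ModelArguments(model="./my-model", save_compressed=True)
data_args = DatasetArguments(
    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)
recipe_args = RecipeArguments(recipe="recipe.yaml")
training_args = TrainingArguments(output_dir="./output")
```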


## Running One-Shot with FSDP
4 changes: 1 addition & 3 deletions src/llmcompressor/transformers/finetune/__init__.py
@@ -1,7 +1,5 @@
# flake8: noqa

from .data import DataTrainingArguments, TextGenerationDataset
from .model_args import ModelArguments
from .data import TextGenerationDataset
from .session_mixin import SessionManagerMixIn
from .text_generation import apply, compress, eval, oneshot, train
from .training_args import TrainingArguments
1 change: 0 additions & 1 deletion src/llmcompressor/transformers/finetune/data/__init__.py
@@ -4,7 +4,6 @@
from .c4 import C4Dataset
from .cnn_dailymail import CNNDailyMailDataset
from .custom import CustomDataset
from .data_args import DataTrainingArguments
from .evolcodealpaca import EvolCodeAlpacaDataset
from .flickr_30k import Flickr30K
from .gsm8k import GSM8KDataset
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -8,12 +8,12 @@
from datasets.formatting.formatting import LazyRow
from loguru import logger

from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
LABELS_MASK_VALUE,
get_custom_datasets_from_path,
get_raw_dataset,
)
from llmcompressor.transformers.utils.arg_parser import DatasetArguments
from llmcompressor.transformers.utils.preprocessing_functions import (
PreprocessingFunctionRegistry,
)
@@ -41,7 +41,7 @@ class TextGenerationDataset(RegistryMixin):

def __init__(
self,
data_args: DataTrainingArguments,
data_args: DatasetArguments,
split: str,
processor: Processor,
):
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/c4.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="c4")
@@ -18,7 +18,7 @@ class C4Dataset(TextGenerationDataset):
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "allenai/c4"
data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="cnn_dailymail")
@@ -20,7 +20,7 @@ class CNNDailyMailDataset(TextGenerationDataset):

SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "cnn_dailymail"
data_args.dataset_config_name = "3.0.0"
src/llmcompressor/transformers/finetune/data/evolcodealpaca.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="evolcodealpaca")
@@ -25,7 +25,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
"\n\n### Response:\n"
)

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "theblackcat102/evol-codealpaca-v1"
data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/flickr_30k.py
@@ -7,7 +7,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="flickr", alias="flickr30k")
@@ -31,7 +31,7 @@ class Flickr30K(TextGenerationDataset):
"{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
)

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "lmms-lab/flickr30k"

4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/gsm8k.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="gsm8k")
@@ -20,7 +20,7 @@ class GSM8KDataset(TextGenerationDataset):

GSM_TEMPLATE = "Question: {question}\nAnswer:"

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "gsm8k"
data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/open_platypus.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="open_platypus")
@@ -28,7 +28,7 @@ class OpenPlatypusDataset(TextGenerationDataset):
"instruction}\n\n### Response:\n",
}

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "garage-bAInd/Open-Platypus"
data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/ptb.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="ptb")
@@ -18,7 +18,7 @@ class PtbDataset(TextGenerationDataset):
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "ptb_text_only"
data_args.text_column = "sentence"
src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
@@ -7,7 +7,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="ultrachat_200k")
@@ -33,7 +33,7 @@ class UltraChatDataset(TextGenerationDataset):
"{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
)

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "HuggingFaceH4/ultrachat_200k"
data_args.text_column = "messages"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/wikitext.py
@@ -5,7 +5,7 @@
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="wikitext")
@@ -18,7 +18,7 @@ class WikiTextDataset(TextGenerationDataset):
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "Salesforce/wikitext"
data_args.text_column = "text"
22 changes: 12 additions & 10 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -16,13 +16,16 @@
from llmcompressor.pytorch.utils import tensors_to_device
from llmcompressor.recipe import Recipe, StageRunType
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
format_calibration_data,
make_dataset_splits,
)
from llmcompressor.transformers.finetune.model_args import ModelArguments
from llmcompressor.transformers.finetune.training_args import TrainingArguments
from llmcompressor.transformers.utils.arg_parser import (
DatasetArguments,
ModelArguments,
RecipeArguments,
TrainingArguments,
)
from llmcompressor.typing import Processor
from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe

Expand All @@ -46,13 +49,15 @@ class StageRunner:

def __init__(
self,
data_args: "DataTrainingArguments",
data_args: "DatasetArguments",
model_args: "ModelArguments",
training_args: "TrainingArguments",
recipe_args: "RecipeArguments",
):
self._data_args = data_args
self._model_args = model_args
self._training_args = training_args
self._recipe_args = recipe_args

self.datasets = {}
self.trainer = None
@@ -214,7 +219,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
:param checkpoint: optional checkpoint to pick up a stage from
"""

recipe_obj = Recipe.create_instance(self._training_args.recipe)
recipe_obj = Recipe.create_instance(self._recipe_args.recipe)
with self.trainer.accelerator.main_process_first():
checkpoint_dir = self._model_args.model
completed_stages = get_completed_stages(checkpoint_dir)
@@ -256,16 +261,13 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
self.train(checkpoint=checkpoint, stage=stage_name)
checkpoint = None

if (
self._training_args.output_dir
!= TrainingArguments.__dataclass_fields__["output_dir"].default
):
if self._training_args.output_dir:
save_model_and_recipe(
model=self.trainer.model,
save_path=self._output_dir,
processor=self.processor,
save_safetensors=self._training_args.save_safetensors,
save_compressed=self._training_args.save_compressed,
save_compressed=self._model_args.save_compressed,
)

# save stage to checkpoint dir
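For orientation, a minimal sketch of how the refactored `StageRunner` constructor might now be called, based only on the keyword parameters visible in the hunk above; the module path comes from this file's location, and the dataclass field values are placeholders rather than library defaults:

```python
# Sketch: constructor keywords mirror the diff above; values are placeholders.
from llmcompressor.transformers.finetune.runner import StageRunner
from llmcompressor.transformers.utils.arg_parser import (
    DatasetArguments,
    ModelArguments,
    RecipeArguments,
    TrainingArguments,
)

runner = StageRunner(
    data_args=DatasetArguments(dataset="wikitext"),
    model_args=ModelArguments(model="./my-model"),
    training_args=TrainingArguments(output_dir="./output"),
    recipe_args=RecipeArguments(recipe="recipe.yaml"),
)
```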
21 changes: 12 additions & 9 deletions src/llmcompressor/transformers/finetune/session_mixin.py
@@ -36,8 +36,10 @@
from llmcompressor.utils.pytorch import qat_active

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments

from llmcompressor.transformers.utils.arg_parser import (
DatasetArguments,
ModelArguments,
)

__all__ = [
"SessionManagerMixIn",
@@ -68,12 +70,14 @@ def __init__(
self,
recipe: Optional[str] = None,
recipe_args: Optional[Union[Dict[str, Any], str]] = None,
data_args: Optional["DataTrainingArguments"] = None,
data_args: Optional["DatasetArguments"] = None,
model_args: Optional["ModelArguments"] = None,
teacher: Optional[Union[Module, str]] = None,
**kwargs,
):
self.recipe = recipe
self.recipe_args = recipe_args
self.model_args = model_args
self.teacher = teacher

# parse training and metadata args
@@ -374,16 +378,16 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)

# do not save checkpoints as compressed
original_save_compressed = self.args.save_compressed
self.args.save_compressed = False
original_save_compressed = self.model_args.save_compressed
self.model_args.save_compressed = False

# train with accelerator
self.accelerator.wait_for_everyone()
output = super().train(*args, **kwargs)
self.accelerator.wait_for_everyone()

# restore original setting for saving final model
self.args.save_compressed = original_save_compressed
self.model_args.save_compressed = original_save_compressed

# lifecycle
self.finalize_session()
@@ -433,7 +437,6 @@ def one_shot(
):
"""
Run oneshot calibration on the active model

:param stage: which stage of the recipe to run, or None to run whole recipe
:param calib_data: dataloader of calibration data
"""
@@ -474,15 +477,15 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
if not is_fsdp_model(self.model):
self.model.save_pretrained(
output_dir,
save_compressed=self.args.save_compressed,
save_compressed=self.model_args.save_compressed,
safe_serialization=self.args.save_safetensors,
)
else: # FSDP model
save_pretrained_fsdp(
model=self.model,
accelerator=self.accelerator,
output_dir=output_dir,
save_compressed=self.args.save_compressed,
save_compressed=self.model_args.save_compressed,
save_safetensors=self.metadata.get("save_safetensors", False),
)
