[Oneshot Refactor] dataclass Arguments #1103

Open · wants to merge 20 commits into main

Changes from 2 commits
8 changes: 4 additions & 4 deletions examples/trl_mixin/ex_trl_distillation.py
@@ -1,9 +1,9 @@
 from sft_trainer import SFTTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 
-from llmcompressor.transformers import (
-    DataTrainingArguments,
-    TextGenerationDataset,
+from llmcompressor.transformers import TextGenerationDataset
+from llmcompressor.transformers.utils.arg_parser import (
+    DatasetArguments,
     TrainingArguments,
 )
 
@@ -21,7 +21,7 @@
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 
 # Load gsm8k using SparseML dataset tools
-data_args = DataTrainingArguments(
+data_args = DatasetArguments(
     dataset="gsm8k", dataset_config_name="main", max_seq_length=512
 )
 dataset_manager = TextGenerationDataset.load_from_registry(
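Taken together, the two hunks above leave the example importing and constructing the renamed dataclass roughly like this (a condensed excerpt of the post-change code, shown only to make the rename easier to follow):

from llmcompressor.transformers import TextGenerationDataset
from llmcompressor.transformers.utils.arg_parser import (
    DatasetArguments,
    TrainingArguments,
)

# DatasetArguments replaces the old DataTrainingArguments dataclass here
data_args = DatasetArguments(
    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)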
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -8,12 +8,12 @@
 from datasets.formatting.formatting import LazyRow
 from loguru import logger
 
-from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
     LABELS_MASK_VALUE,
     get_custom_datasets_from_path,
     get_raw_dataset,
 )
+from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 from llmcompressor.transformers.utils.preprocessing_functions import (
     PreprocessingFunctionRegistry,
 )
@@ -41,7 +41,7 @@ class TextGenerationDataset(RegistryMixin):

     def __init__(
         self,
-        data_args: DataTrainingArguments,
+        data_args: DatasetArguments,
         split: str,
         processor: Processor,
     ):
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/c4.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="c4")
@@ -18,7 +18,7 @@ class C4Dataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "allenai/c4"
         data_args.text_column = "text"
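Each dataset module now imports DatasetArguments only under TYPE_CHECKING and references it as a string annotation, so type checkers resolve the name while no runtime import (and no potential import cycle) is created. A minimal sketch of the pattern in isolation, using a hypothetical module name:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only; never imported at runtime.
    from some_heavy_or_circular_module import DatasetArguments


def configure(data_args: "DatasetArguments") -> None:
    # The quoted annotation defers name resolution, so the module above
    # does not need to be importable when this function actually runs.
    print(data_args)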
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="cnn_dailymail")
@@ -20,7 +20,7 @@ class CNNDailyMailDataset(TextGenerationDataset):

     SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "cnn_dailymail"
         data_args.dataset_config_name = "3.0.0"
src/llmcompressor/transformers/finetune/data/evolcodealpaca.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="evolcodealpaca")
@@ -25,7 +25,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
"\n\n### Response:\n"
)

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "theblackcat102/evol-codealpaca-v1"
data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/flickr_30k.py
@@ -7,7 +7,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="flickr", alias="flickr30k")
@@ -31,7 +31,7 @@ class Flickr30K(TextGenerationDataset):
"{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
)

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "lmms-lab/flickr30k"

4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/gsm8k.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="gsm8k")
@@ -20,7 +20,7 @@ class GSM8KDataset(TextGenerationDataset):

     GSM_TEMPLATE = "Question: {question}\nAnswer:"
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "gsm8k"
         data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/open_platypus.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="open_platypus")
@@ -28,7 +28,7 @@ class OpenPlatypusDataset(TextGenerationDataset):
"instruction}\n\n### Response:\n",
}

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "garage-bAInd/Open-Platypus"
data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/ptb.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="ptb")
@@ -18,7 +18,7 @@ class PtbDataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "ptb_text_only"
         data_args.text_column = "sentence"
src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
@@ -7,7 +7,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="ultrachat_200k")
@@ -33,7 +33,7 @@ class UltraChatDataset(TextGenerationDataset):
"{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
)

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "HuggingFaceH4/ultrachat_200k"
data_args.text_column = "messages"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/wikitext.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


@TextGenerationDataset.register(name="wikitext")
@@ -18,7 +18,7 @@ class WikiTextDataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "Salesforce/wikitext"
         data_args.text_column = "text"
40 changes: 29 additions & 11 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -16,13 +16,20 @@
 from llmcompressor.pytorch.utils import tensors_to_device
 from llmcompressor.recipe import Recipe, StageRunType
 from llmcompressor.transformers.finetune.data import TextGenerationDataset
-from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
     format_calibration_data,
     make_dataset_splits,
 )
-from llmcompressor.transformers.finetune.model_args import ModelArguments
-from llmcompressor.transformers.finetune.training_args import TrainingArguments
+from llmcompressor.transformers.utils.arg_parser import (
+    DatasetArguments,
+    ModelArguments,
+    RecipeArguments,
+    TrainingArguments,
+)
+from llmcompressor.transformers.utils.arg_parser.training_arguments import (
+    DEFAULT_OUTPUT_DIR,
+)
+from llmcompressor.transformers.utils.arg_parser.utils import get_dataclass_as_dict
 from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe

@@ -46,13 +53,15 @@ class StageRunner:

     def __init__(
         self,
-        data_args: "DataTrainingArguments",
+        data_args: "DatasetArguments",
         model_args: "ModelArguments",
        training_args: "TrainingArguments",
+        recipe_args: "RecipeArguments",
     ):
         self._data_args = data_args
         self._model_args = model_args
         self._training_args = training_args
+        self._recipe_args = recipe_args
 
         self.datasets = {}
         self.trainer = None
@@ -214,7 +223,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
         :param checkpoint: optional checkpoint to pick up a stage from
         """
 
-        recipe_obj = Recipe.create_instance(self._training_args.recipe)
+        recipe_obj = Recipe.create_instance(self._recipe_args.recipe)
         with self.trainer.accelerator.main_process_first():
             checkpoint_dir = self._model_args.model
             completed_stages = get_completed_stages(checkpoint_dir)
@@ -251,21 +260,30 @@

             # run stage
             if run_type is StageRunType.ONESHOT:
-                self.one_shot(stage=stage_name)
+                from llmcompressor.transformers.calibration import Oneshot
+
+                model = get_session_model()
+                self._model_args.model = model
+
+                oneshot = Oneshot(
+                    output_dir=self._training_args.output_dir,
+                    **get_dataclass_as_dict(self._model_args, ModelArguments),
+                    **get_dataclass_as_dict(self._data_args, DatasetArguments),
+                    **get_dataclass_as_dict(self._recipe_args, RecipeArguments),
+                )
+
+                oneshot.run(stage_name=stage_name)
             elif run_type is StageRunType.TRAIN:
                 self.train(checkpoint=checkpoint, stage=stage_name)
                 checkpoint = None
 
-            if (
-                self._training_args.output_dir
-                != TrainingArguments.__dataclass_fields__["output_dir"].default
-            ):
+            if self._training_args.output_dir != DEFAULT_OUTPUT_DIR:
                 save_model_and_recipe(
                     model=self.trainer.model,
                     save_path=self._output_dir,
                     processor=self.processor,
                     save_safetensors=self._training_args.save_safetensors,
-                    save_compressed=self._training_args.save_compressed,
+                    save_compressed=self._model_args.save_compressed,
                 )

# save stage to checkpoint dir
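The oneshot branch above rebuilds the Oneshot keyword arguments from the runner's stored dataclasses via get_dataclass_as_dict. Its implementation is not part of this diff; a rough sketch of what such a helper could look like, shown here with a hypothetical stand-in dataclass rather than the real ModelArguments:

from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class ExampleModelArguments:  # hypothetical stand-in, for illustration only
    model: Optional[str] = None
    save_compressed: bool = False


def get_dataclass_as_dict(instance: Any, dataclass_type: Any) -> Dict[str, Any]:
    # Collect only the fields declared on dataclass_type from instance.
    # Sketch only: the real helper lives in
    # llmcompressor.transformers.utils.arg_parser.utils and may differ.
    return {f.name: getattr(instance, f.name) for f in fields(dataclass_type)}


args = ExampleModelArguments(model="stub-model")
print(get_dataclass_as_dict(args, ExampleModelArguments))
# -> {'model': 'stub-model', 'save_compressed': False}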