Commit: refactor dataclass args
horheynm committed Jan 27, 2025
1 parent 5cdae11 commit a7cf946
Showing 21 changed files with 368 additions and 79 deletions.
8 changes: 4 additions & 4 deletions examples/trl_mixin/ex_trl_distillation.py
@@ -1,9 +1,9 @@
 from sft_trainer import SFTTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

-from llmcompressor.transformers import (
-    DataTrainingArguments,
-    TextGenerationDataset,
+from llmcompressor.transformers import TextGenerationDataset
+from llmcompressor.transformers.utils.arg_parser import (
+    DatasetArguments,
     TrainingArguments,
 )

@@ -21,7 +21,7 @@
 tokenizer = AutoTokenizer.from_pretrained(model_path)

 # Load gsm8k using SparseML dataset tools
-data_args = DataTrainingArguments(
+data_args = DatasetArguments(
     dataset="gsm8k", dataset_config_name="main", max_seq_length=512
 )
 dataset_manager = TextGenerationDataset.load_from_registry(
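For reference, a minimal sketch of how the renamed dataclass is used after this change, assembled only from the lines shown above. The model_path value and the keyword names passed to load_from_registry are assumptions rather than part of the diff:

    from transformers import AutoTokenizer

    from llmcompressor.transformers import TextGenerationDataset
    from llmcompressor.transformers.utils.arg_parser import DatasetArguments

    model_path = "path/to/model"  # placeholder; the example sets this outside the shown hunk
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # DatasetArguments replaces the old DataTrainingArguments dataclass
    data_args = DatasetArguments(
        dataset="gsm8k", dataset_config_name="main", max_seq_length=512
    )

    # keyword names here are assumed; the diff truncates this call
    dataset_manager = TextGenerationDataset.load_from_registry(
        data_args.dataset,
        data_args=data_args,
        split="train",
        processor=tokenizer,
    )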
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -8,12 +8,12 @@
 from datasets.formatting.formatting import LazyRow
 from loguru import logger

-from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
     LABELS_MASK_VALUE,
     get_custom_datasets_from_path,
     get_raw_dataset,
 )
+from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 from llmcompressor.transformers.utils.preprocessing_functions import (
     PreprocessingFunctionRegistry,
 )
@@ -41,7 +41,7 @@ class TextGenerationDataset(RegistryMixin):

     def __init__(
         self,
-        data_args: DataTrainingArguments,
+        data_args: DatasetArguments,
         split: str,
         processor: Processor,
     ):
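The registered dataset classes below all follow the same pattern against this base class. A minimal sketch of a subclass under the new type hint, where the registry name and the dataset id are hypothetical:

    from copy import deepcopy
    from typing import TYPE_CHECKING

    from llmcompressor.transformers import TextGenerationDataset
    from llmcompressor.typing import Processor

    if TYPE_CHECKING:
        from llmcompressor.transformers.utils.arg_parser import DatasetArguments


    @TextGenerationDataset.register(name="my_dataset")  # hypothetical registry name
    class MyDataset(TextGenerationDataset):
        def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
            data_args = deepcopy(data_args)
            data_args.dataset = "org/my-dataset"  # hypothetical dataset id
            data_args.text_column = "text"
            super().__init__(data_args=data_args, split=split, processor=processor)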
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/c4.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="c4")
@@ -18,7 +18,7 @@ class C4Dataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "allenai/c4"
         data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="cnn_dailymail")
@@ -20,7 +20,7 @@ class CNNDailyMailDataset(TextGenerationDataset):

     SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "cnn_dailymail"
         data_args.dataset_config_name = "3.0.0"
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="evolcodealpaca")
@@ -25,7 +25,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
         "\n\n### Response:\n"
     )

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "theblackcat102/evol-codealpaca-v1"
         data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/flickr_30k.py
@@ -7,7 +7,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="flickr", alias="flickr30k")
@@ -31,7 +31,7 @@ class Flickr30K(TextGenerationDataset):
         "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
     )

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "lmms-lab/flickr30k"

4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/gsm8k.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="gsm8k")
@@ -20,7 +20,7 @@ class GSM8KDataset(TextGenerationDataset):

     GSM_TEMPLATE = "Question: {question}\nAnswer:"

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "gsm8k"
         data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/open_platypus.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="open_platypus")
@@ -28,7 +28,7 @@ class OpenPlatypusDataset(TextGenerationDataset):
         "instruction}\n\n### Response:\n",
     }

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "garage-bAInd/Open-Platypus"
         data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/ptb.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="ptb")
@@ -18,7 +18,7 @@ class PtbDataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "ptb_text_only"
         data_args.text_column = "sentence"
@@ -7,7 +7,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="ultrachat_200k")
@@ -33,7 +33,7 @@ class UltraChatDataset(TextGenerationDataset):
         "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
     )

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "HuggingFaceH4/ultrachat_200k"
         data_args.text_column = "messages"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/wikitext.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


 @TextGenerationDataset.register(name="wikitext")
@@ -18,7 +18,7 @@ class WikiTextDataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """

-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "Salesforce/wikitext"
         data_args.text_column = "text"
48 changes: 13 additions & 35 deletions src/llmcompressor/transformers/finetune/session_mixin.py
@@ -7,13 +7,12 @@
 import torch
 from loguru import logger
 from torch.nn import Module
-from torch.utils.data import DataLoader, IterableDataset
+from torch.utils.data import IterableDataset
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import get_last_checkpoint

 from llmcompressor.core import (
     active_session,
-    apply,
     callbacks,
     create_session,
     finalize,
@@ -36,8 +35,10 @@
 from llmcompressor.utils.pytorch import qat_active

 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments
-
+    from llmcompressor.transformers.utils.arg_parser import (
+        DatasetArguments,
+        ModelArguments,
+    )

 __all__ = [
     "SessionManagerMixIn",
@@ -68,12 +69,14 @@ def __init__(
         self,
         recipe: Optional[str] = None,
         recipe_args: Optional[Union[Dict[str, Any], str]] = None,
-        data_args: Optional["DataTrainingArguments"] = None,
+        data_args: Optional["DatasetArguments"] = None,
+        model_args: Optional["ModelArguments"] = None,
         teacher: Optional[Union[Module, str]] = None,
         **kwargs,
     ):
         self.recipe = recipe
         self.recipe_args = recipe_args
+        self.model_args = model_args
         self.teacher = teacher

         # parse training and metadata args
@@ -374,16 +377,16 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
         self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)

         # do not save checkpoints as compressed
-        original_save_compressed = self.args.save_compressed
-        self.args.save_compressed = False
+        original_save_compressed = self.model_args.save_compressed
+        self.model_args.save_compressed = False

         # train with accelerator
         self.accelerator.wait_for_everyone()
         output = super().train(*args, **kwargs)
         self.accelerator.wait_for_everyone()

         # restore original setting for saving final model
-        self.args.save_compressed = original_save_compressed
+        self.model_args.save_compressed = original_save_compressed

         # lifecycle
         self.finalize_session()
@@ -428,31 +431,6 @@ def predict(self, *args, **kwargs):

         return output

-    def one_shot(
-        self, calibration_data: Optional[DataLoader] = None, stage: Optional[str] = None
-    ):
-        """
-        Run oneshot calibration on the active model
-
-        :param stage: which stage of the recipe to run, or None to run whole recipe
-        :param calib_data: dataloader of calibration data
-        """
-        apply(
-            recipe=self.recipe,
-            recipe_stage=stage,
-            recipe_args=self.recipe_args,
-            model=self.model,
-            calib_data=calibration_data,
-            start=-1,
-            copy_data=False,
-            accelerator=self.accelerator,
-            min_tokens_per_module=self.min_tokens_per_module,
-        )
-
-        # log model sparsity
-        # self.maybe_log_model_sparsification()
-        self.accelerator.wait_for_everyone()
-
     def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
         """
         Override of the save_model function and expects it to exist in the parent.
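The one_shot helper removed above reduced to a single llmcompressor.core.apply call. If equivalent behavior is needed outside the trainer, a standalone sketch follows; it assumes apply is still exported from llmcompressor.core with the signature used by the deleted code, and it omits the trainer-specific accelerator and min_tokens_per_module arguments:

    from typing import Optional

    from torch.utils.data import DataLoader

    from llmcompressor.core import apply


    def one_shot_calibration(
        model,
        recipe: str,
        calibration_data: Optional[DataLoader] = None,
        stage: Optional[str] = None,
        recipe_args=None,
    ):
        """Standalone equivalent of the removed mixin method (sketch only)."""
        apply(
            recipe=recipe,
            recipe_stage=stage,       # or None to run the whole recipe
            recipe_args=recipe_args,
            model=model,
            calib_data=calibration_data,
            start=-1,
            copy_data=False,
        )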
@@ -474,15 +452,15 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
         if not is_fsdp_model(self.model):
             self.model.save_pretrained(
                 output_dir,
-                save_compressed=self.args.save_compressed,
+                save_compressed=self.model_args.save_compressed,
                 safe_serialization=self.args.save_safetensors,
             )
         else: # FSDP model
             save_pretrained_fsdp(
                 model=self.model,
                 accelerator=self.accelerator,
                 output_dir=output_dir,
-                save_compressed=self.args.save_compressed,
+                save_compressed=self.model_args.save_compressed,
                 save_safetensors=self.metadata.get("save_safetensors", False),
             )

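Taken together, the mixin now receives the dataset-side and model-side dataclasses separately, and save_compressed is read from model_args rather than the training arguments. A minimal sketch of the two dataclasses under the new import path; the model field name on ModelArguments is an assumption, since it is not shown in this diff:

    from llmcompressor.transformers.utils.arg_parser import (
        DatasetArguments,
        ModelArguments,
    )

    # dataset-side arguments (previously DataTrainingArguments)
    data_args = DatasetArguments(
        dataset="gsm8k", dataset_config_name="main", max_seq_length=512
    )

    # model-side arguments; save_compressed now lives here, which is why the
    # mixin reads self.model_args.save_compressed during train() and save_model()
    model_args = ModelArguments(
        model="path/to/model",  # field name assumed; not shown in this diff
        save_compressed=True,
    )

    # both objects are passed to the SessionManagerMixIn constructor, e.g.
    # SomeTrainer(recipe=..., data_args=data_args, model_args=model_args, teacher=...)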
