apply changes only for dataclass args - recipe, model, dataset, training
horheynm committed Jan 28, 2025
1 parent a7cf946 commit 6859adc
Showing 8 changed files with 77 additions and 434 deletions.
4 changes: 1 addition & 3 deletions src/llmcompressor/transformers/finetune/__init__.py
@@ -1,7 +1,5 @@
 # flake8: noqa
 
-from .data import DataTrainingArguments, TextGenerationDataset
-from .model_args import ModelArguments
+from .data import TextGenerationDataset
 from .session_mixin import SessionManagerMixIn
 from .text_generation import apply, compress, eval, oneshot, train
-from .training_args import TrainingArguments
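Note: after this change the argument dataclasses are no longer exported from the finetune package. A minimal import sketch of where they presumably live now, inferred only from the surviving runner.py import of DEFAULT_OUTPUT_DIR from llmcompressor.transformers.utils.arg_parser.training_arguments; the exact re-export surface is an assumption, not confirmed by this diff:

# Hypothetical import sketch -- module path inferred, not shown in this commit
from llmcompressor.transformers.utils.arg_parser import (
    DatasetArguments,  # presumed successor to DataTrainingArguments
    ModelArguments,
    RecipeArguments,
    TrainingArguments,
)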
1 change: 0 additions & 1 deletion src/llmcompressor/transformers/finetune/data/__init__.py
@@ -4,7 +4,6 @@
 from .c4 import C4Dataset
 from .cnn_dailymail import CNNDailyMailDataset
 from .custom import CustomDataset
-from .data_args import DataTrainingArguments
 from .evolcodealpaca import EvolCodeAlpacaDataset
 from .flickr_30k import Flickr30K
 from .gsm8k import GSM8KDataset
189 changes: 0 additions & 189 deletions src/llmcompressor/transformers/finetune/data/data_args.py

This file was deleted.

85 changes: 0 additions & 85 deletions src/llmcompressor/transformers/finetune/model_args.py

This file was deleted.
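Both deleted modules held argument dataclasses. For orientation, a minimal sketch of the dataclass-args pattern the commit title refers to, assuming the deleted definitions were folded into the four consolidated groups (recipe, model, dataset, training); the field names below are illustrative, not taken from the deleted files:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class DatasetArguments:
    # Illustrative fields only -- the real class lives under arg_parser
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "name or path of the calibration/training dataset"},
    )
    max_seq_length: int = field(
        default=512,
        metadata={"help": "maximum sequence length after tokenization"},
    )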

15 changes: 1 addition & 14 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -29,7 +29,6 @@
 from llmcompressor.transformers.utils.arg_parser.training_arguments import (
     DEFAULT_OUTPUT_DIR,
 )
-from llmcompressor.transformers.utils.arg_parser.utils import get_dataclass_as_dict
 from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe

@@ -260,19 +259,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
 
             # run stage
             if run_type is StageRunType.ONESHOT:
-                from llmcompressor.transformers.calibration import Oneshot
-
-                model = get_session_model()
-                self._model_args.model = model
-
-                oneshot = Oneshot(
-                    output_dir=self._training_args.output_dir,
-                    **get_dataclass_as_dict(self._model_args, ModelArguments),
-                    **get_dataclass_as_dict(self._data_args, DatasetArguments),
-                    **get_dataclass_as_dict(self._recipe_args, RecipeArguments),
-                )
-
-                oneshot.run(stage_name=stage_name)
+                self.one_shot(stage=stage_name)
             elif run_type is StageRunType.TRAIN:
                 self.train(checkpoint=checkpoint, stage=stage_name)
                 checkpoint = None
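The removed block built a standalone Oneshot entrypoint by flattening three argument dataclasses into kwargs; the oneshot stage now delegates to the mixin's one_shot method (added in session_mixin.py below). For context, a minimal sketch of what the deleted get_dataclass_as_dict helper presumably did, inferred from its call sites rather than the actual implementation:

from dataclasses import fields

def get_dataclass_as_dict(instance, dataclass_type) -> dict:
    # Keep only the fields declared on dataclass_type, so kwargs taken from
    # several argument dataclasses can be merged into a single call.
    return {f.name: getattr(instance, f.name) for f in fields(dataclass_type)}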
27 changes: 26 additions & 1 deletion src/llmcompressor/transformers/finetune/session_mixin.py
@@ -7,12 +7,13 @@
 import torch
 from loguru import logger
 from torch.nn import Module
-from torch.utils.data import IterableDataset
+from torch.utils.data import DataLoader, IterableDataset
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import get_last_checkpoint
 
 from llmcompressor.core import (
     active_session,
+    apply,
     callbacks,
     create_session,
     finalize,
@@ -431,6 +432,30 @@ def predict(self, *args, **kwargs):
 
         return output
 
+    def one_shot(
+        self, calibration_data: Optional[DataLoader] = None, stage: Optional[str] = None
+    ):
+        """
+        Run oneshot calibration on the active model.
+        :param calibration_data: dataloader of calibration data
+        :param stage: which stage of the recipe to run, or None to run the whole recipe
+        """
+        apply(
+            recipe=self.recipe,
+            recipe_stage=stage,
+            recipe_args=self.recipe_args,
+            model=self.model,
+            calib_data=calibration_data,
+            start=-1,
+            copy_data=False,
+            accelerator=self.accelerator,
+            min_tokens_per_module=self.min_tokens_per_module,
+        )
+
+        # log model sparsity
+        # self.maybe_log_model_sparsification()
+        self.accelerator.wait_for_everyone()
+
     def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
         """
         Override of the save_model function and expects it to exist in the parent.
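Usage sketch: any Trainer subclass that mixes in SessionManagerMixIn can now run oneshot calibration directly from a DataLoader. The trainer, dataset, and stage name below are hypothetical stand-ins, not objects defined in this commit:

from torch.utils.data import DataLoader

# calibration_dataset and trainer are assumed to exist; stage name is illustrative
calibration_loader = DataLoader(calibration_dataset, batch_size=8)
trainer.one_shot(calibration_data=calibration_loader, stage="oneshot_stage")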