Commit

pass /tests/llmcompressor/transformers/finetune/test_finetune_no_recipe_custom_dataset.py
horheynm committed Jan 15, 2025
1 parent 7e84319 commit e53ece6
Showing 10 changed files with 662 additions and 51 deletions.
src/llmcompressor/__init__.py (3 changes: 1 addition & 2 deletions)
@@ -34,10 +34,9 @@
     "LoggerConfig",
 ]
 
-from llmcompressor.core.session_functions import (
+from llmcompressor.core.session_functions import ( # callbacks,
     active_session,
     apply,
-    callbacks,
     create_session,
     finalize,
     initialize,
src/llmcompressor/core/__init__.py (5 changes: 2 additions & 3 deletions)
@@ -8,11 +8,10 @@
 from llmcompressor.core.lifecycle import CompressionLifecycle
 from llmcompressor.core.model_layer import ModelParameterizedLayer
 from llmcompressor.core.session import CompressionSession
-from llmcompressor.core.session_functions import (
+from llmcompressor.core.session_functions import ( # callbacks,
     LifecycleCallbacks,
     active_session,
     apply,
-    callbacks,
     create_session,
     finalize,
     initialize,
@@ -41,6 +40,6 @@
     "initialize",
     "finalize",
     "apply",
-    "callbacks",
+    # "callbacks",
     "LifecycleCallbacks",
 ]
src/llmcompressor/core/session_functions.py (4 changes: 2 additions & 2 deletions)
@@ -15,7 +15,7 @@
     "initialize",
     "finalize",
     "apply",
-    "callbacks",
+    # "callbacks",
     "LifecycleCallbacks",
 ]
 
@@ -281,4 +281,4 @@ def batch_end(cls, **kwargs) -> ModifiedState:
         return cls.event(EventType.BATCH_END, **kwargs)
 
 
-callbacks = LifecycleCallbacks
+# callbacks = LifecycleCallbacks
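With the module-level callbacks alias commented out across these three files, callers that previously did "from llmcompressor.core import callbacks" would reference the LifecycleCallbacks class directly. A minimal sketch of the replacement pattern, assuming LifecycleCallbacks remains exported as the diff shows:

    from llmcompressor.core import LifecycleCallbacks

    # Previously: from llmcompressor.core import callbacks; callbacks.optim_post_step()
    LifecycleCallbacks.optim_post_step()
    LifecycleCallbacks.batch_end()
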
src/llmcompressor/transformers/finetune/callbacks.py (24 changes: 14 additions & 10 deletions)
@@ -4,7 +4,8 @@
 from transformers.trainer_callback import TrainerState
 
 from llmcompressor.core import active_session
-from llmcompressor.core import callbacks as session_callbacks
+
+# from llmcompressor.core import callbacks as self.callbacks
 
 __all__ = [
     "DisableHalfPrecisionCallback",
@@ -23,9 +24,10 @@ class TrainingLoopCallbacks(TrainerCallback):
     :param kwargs: key word arguments to be passed to base TrainerCallback
     """
 
-    def __init__(self, trainer, *args, **kwargs):
+    def __init__(self, trainer, callbacks, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.trainer = trainer
+        self.callbacks = callbacks
 
     def on_train_begin(
         self,
@@ -39,8 +41,9 @@ def on_train_begin(
         model, as it will have changed to a wrapper if FSDP is enabled
         """
         super().on_train_begin(args, state, control, **kwargs)
-        session = active_session()
-        session.state.model = self.trainer.model
+        # session = active_session()
+        # session.state.model = self.trainer.model
+        self.trainer.lifecycle.state.model = self.trainer.model
 
     def on_step_end(
         self,
@@ -56,8 +59,8 @@ def on_step_end(
         Triggers optimizer post_step and batch_end in the active CompressionSession
         """
        super().on_step_end(args, state, control, **kwargs)
-        session_callbacks.optim_post_step()
-        session_callbacks.batch_end()
+        self.callbacks.optim_post_step()
+        self.callbacks.batch_end()
 
     def on_substep_end(
         self,
@@ -72,8 +75,8 @@ def on_substep_end(
         Triggers optimizer post_step and batch_end in the active CompressionSession
         """
         super().on_substep_end(args, state, control, **kwargs)
-        session_callbacks.optim_post_step()
-        session_callbacks.batch_end()
+        self.callbacks.optim_post_step()
+        self.callbacks.batch_end()
 
 
 class DisableHalfPrecisionCallback(TrainerCallback):
@@ -95,8 +98,9 @@ def qat_active(self) -> bool:
         """
         :return: True if a quantization modifier is active in the current session
         """
-        session = active_session()
-        return session.state.model.qat_active()
+        # session = active_session()
+        # return session.state.model.qat_active()
+        return self.trainer.lifecycle.state.model.qat_active()
 
     def on_epoch_begin(
         self,
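Since TrainingLoopCallbacks now receives its callback handle through the constructor instead of a module-level import, the trainer has to pass it in when wiring up the callbacks. A minimal sketch of that wiring, assuming the trainer supplies LifecycleCallbacks (the class the old callbacks alias pointed to) and that trainer is an existing Hugging Face-style trainer instance; the trainer-side change itself is not shown in this diff:

    from llmcompressor.core import LifecycleCallbacks
    from llmcompressor.transformers.finetune.callbacks import TrainingLoopCallbacks

    def attach_compression_callbacks(trainer):
        # Hypothetical helper: construct the callback with the lifecycle handle
        # and register it on the trainer (Trainer.add_callback is a transformers API).
        loop_callbacks = TrainingLoopCallbacks(trainer=trainer, callbacks=LifecycleCallbacks)
        trainer.add_callback(loop_callbacks)
        return trainer
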
src/llmcompressor/transformers/finetune/data/data_helpers.py (19 changes: 12 additions & 7 deletions)
@@ -253,6 +253,7 @@ def get_calibration_dataloader(
     processor,
     add_labels: bool = False, # for oneshot
     do_oneshot=True,
+    do_train=False,
 ):
     """
     Loads datasets for each flow based on data_args, stores a Dataset for each
@@ -309,13 +310,17 @@ def _get_split_name(inp_str):
     datasets = make_dataset_splits(
         tokenized_datasets,
         do_oneshot=do_oneshot,
+        do_train=do_train,
     )
 
-    calibration_dataset = datasets.get("calibration")
+    if do_oneshot:
+        calibration_dataset = datasets.get("calibration")
 
-    return format_calibration_data(
-        tokenized_dataset=calibration_dataset,
-        num_calibration_samples=data_args.num_calibration_samples,
-        do_shuffle=data_args.shuffle_calibration_samples,
-        collate_fn=data_args.data_collator,
-    )
+        return format_calibration_data(
+            tokenized_dataset=calibration_dataset,
+            num_calibration_samples=data_args.num_calibration_samples,
+            do_shuffle=data_args.shuffle_calibration_samples,
+            collate_fn=data_args.data_collator,
+        )
+    if do_train:
+        return datasets.get("train"), datasets.get("validation")
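With the new do_train flag, get_calibration_dataloader serves two flows: the oneshot path still returns a calibration dataloader, while the train path returns the train and validation splits. A minimal sketch of the two call sites, assuming a leading data_args parameter that is not visible in this hunk and an already-built processor:

    # Oneshot path (default do_oneshot=True): a DataLoader over the calibration split.
    calib_loader = get_calibration_dataloader(data_args, processor, add_labels=False)

    # Training path: pass do_train=True (and do_oneshot=False) to get the dataset splits instead.
    train_ds, eval_ds = get_calibration_dataloader(
        data_args, processor, do_oneshot=False, do_train=True
    )
    # Note: if neither flag is set, the function falls through and returns None.
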
src/llmcompressor/transformers/finetune/session_mixin.py (3 changes: 1 addition & 2 deletions)
@@ -11,10 +11,9 @@
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import get_last_checkpoint
 
-from llmcompressor.core import (
+from llmcompressor.core import ( # callbacks,
     active_session,
     apply,
-    callbacks,
     create_session,
     finalize,
     initialize,
src/llmcompressor/transformers/finetune/text_generation.py (12 changes: 8 additions & 4 deletions)
@@ -62,12 +62,17 @@
 
 
 def train(**kwargs):
+    from llmcompressor.transformers.train.train import Train
+
     """
     CLI entrypoint for running training
     """
-    model_args, data_args, recipe_args, training_args = parse_args(**kwargs)
-    training_args.do_train = True
-    main(model_args, data_args, recipe_args, training_args)
+    # model_args, data_args, recipe_args, training_args = parse_args(**kwargs)
+    # training_args.do_train = True
+    # main(model_args, data_args, recipe_args, training_args)
+    trainer = Train(**kwargs)
+    trainer.run()
+    return trainer
 
 
 def eval(**kwargs):
@@ -261,7 +266,6 @@ def initialize_model_from_path(
     if teacher is not None and "sequence_length" in teacher_kwargs:
         teacher.seqlen = teacher_kwargs["sequence_length"]
 
-    # return teacher, model_path, model
     return model, teacher
 
 
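The train() entrypoint now constructs the new Train wrapper, runs it, and returns the wrapper to the caller. A minimal sketch of how a caller might invoke it after this change, assuming Train accepts the same keyword arguments the old parse_args() path consumed; the argument names and values below are illustrative, not taken from this diff:

    from llmcompressor.transformers.finetune.text_generation import train

    # Keyword arguments mirror the old parse_args() inputs; names are illustrative.
    trainer = train(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        dataset="open_platypus",
        output_dir="./finetuned-model",
        num_train_epochs=1,
    )
    # train() now returns the Train instance, so its state can be inspected afterwards.
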
Empty file.
