rename get_shared_processor_src to get_processor_from_model
horheynm committed Jan 28, 2025
1 parent 6fa5a5e commit 9ea2bcd
Showing 2 changed files with 8 additions and 6 deletions.
10 changes: 6 additions & 4 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -20,6 +20,7 @@
 import os
 import warnings
 from pathlib import PosixPath
+from typing import Optional

 from loguru import logger
 from transformers import (
@@ -51,7 +52,7 @@
     patch_tied_tensors_bug,
 )
 from llmcompressor.transformers.sparsification.sparse_model import (
-    get_shared_processor_src,
+    get_processor_from_model,
 )
 from llmcompressor.transformers.utils.helpers import (
     detect_last_checkpoint,
@@ -257,10 +258,11 @@ def initialize_model_from_path(


 def initialize_processor_from_path(
-    model_args: ModelArguments, model: PreTrainedModel, teacher: PreTrainedModel
+    model_args: ModelArguments,
+    model: PreTrainedModel,
+    teacher: Optional[PreTrainedModel] = None,
 ) -> Processor:
-    processor_src = model_args.processor
-    processor_src = processor_src or get_shared_processor_src(model, teacher)
+    processor_src = model_args.processor or get_processor_from_model(model, teacher)
     # The use_fast=True option is not currently supported safely in Transformers
     # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727  # noqa: E501
     try:
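Taken together, this hunk collapses the old two-step assignment into a single short-circuit expression and makes the teacher optional. A minimal sketch of the resulting fallback behavior (the string sources and the trivial helper body below are illustrative stand-ins, not part of this commit):

from typing import Optional

def get_processor_from_model(model: str, teacher: Optional[str]) -> str:
    # Illustrative stand-in body: prefer the teacher's source when one
    # is given, otherwise fall back to the student model's source.
    return teacher or model

# An explicit processor argument wins via short-circuit `or`:
assert ("my/processor" or get_processor_from_model("student", None)) == "my/processor"

# Without one, the source is derived from the models, and `teacher`
# may now be omitted entirely (it defaults to None):
assert (None or get_processor_from_model("student", None)) == "student"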
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -7,7 +7,7 @@

 __all__ = [
     "SparseAutoModelForCausalLM",
-    "get_shared_processor_src",
+    "get_processor_from_model",
 ]


@@ -20,7 +20,7 @@ def from_pretrained(*args, **kwargs):
        return AutoModelForCausalLM.from_pretrained(*args, **kwargs)


-def get_shared_processor_src(student: Module, teacher: Optional[Module]) -> str:
+def get_processor_from_model(student: Module, teacher: Optional[Module]) -> str:
     """
     Get a processor/tokenizer source used for both student and teacher, assuming
     that they could be shared
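The renamed helper's body is elided from this diff; only its signature and docstring are shown. A plausible sketch of the shared-source logic the docstring describes, assuming each model carries a Hugging Face config with a _name_or_path attribute (an assumption not confirmed by this commit):

from typing import Optional

from torch.nn import Module

def get_processor_from_model(student: Module, teacher: Optional[Module]) -> str:
    # Sketch only: reuse the teacher's checkpoint path when a teacher is
    # supplied, so student and teacher share one processor/tokenizer source;
    # otherwise fall back to the student's own path.
    if teacher is not None:
        return teacher.config._name_or_path  # assumed attribute
    return student.config._name_or_path  # assumed attribute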
