diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index 39410a1ef..069bd43cf 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -1,23 +1,30 @@ -from .llava import ( - LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration, -) -from .mllama import ( - MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration, -) -from .qwen2_vl import ( - Qwen2VLForConditionalGeneration as TraceableQwen2VLForConditionalGeneration, -) -from .idefics3 import ( - Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration -) -from .whisper import ( - WhisperForConditionalGeneration as TraceableWhisperForConditionalGeneration -) +from typing import TYPE_CHECKING -__all__ = [ - "TraceableLlavaForConditionalGeneration", - "TraceableMllamaForConditionalGeneration", - "TraceableQwen2VLForConditionalGeneration", - "TraceableIdefics3ForConditionalGeneration", - "TraceableWhisperForConditionalGeneration", -] +import sys +import importlib +from llmcompressor.utils.AliasableLazyModule import _AliasableLazyModule +from transformers.utils.import_utils import define_import_structure + +_aliases = { + "TraceableLlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), # noqa: E501 + "TraceableMllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 + "TraceableQwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 + "TraceableIdefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"), # noqa: E501 + "TraceableWhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration") # noqa: E501 +} + +if TYPE_CHECKING: + for alias, (module_name, class_name) in _aliases.items(): + module = importlib.import_module(f".{module_name}", __package__) + locals()[alias] = getattr(module, class_name) +else: + _file = globals()["__file__"] + sys.modules[__name__] = _AliasableLazyModule( + name=__name__, + module_file=_file, + import_structure=define_import_structure(_file), + module_spec=__spec__, + aliases=_aliases + ) + +__all__ = list(_aliases.keys()) diff --git a/src/llmcompressor/utils/AliasableLazyModule.py b/src/llmcompressor/utils/AliasableLazyModule.py new file mode 100644 index 000000000..44d85fc49 --- /dev/null +++ b/src/llmcompressor/utils/AliasableLazyModule.py @@ -0,0 +1,34 @@ +from typing import Any, Dict, Tuple + +from transformers.utils import _LazyModule + +__all__ = ["_AliasableLazyModule"] + + +class _AliasableLazyModule(_LazyModule): + """ + Extends _LazyModule to support aliases names + + >>> _file = globals()["__file__"] + >>> sys.modules["animals"] = _AliasableLazyModule( + name="animals, + module_file=_file, + import_structure=define_import_structure(_file), + module_spec=__spec__, + aliases={ + "PigWithLipstick": ("mammals", "Pig"), + } + >>> from animals import PigWithLipstick + """ + + def __init__(self, *args, aliases: Dict[str, Tuple[str, str]], **kwargs): + super().__init__(*args, **kwargs) + self._aliases = aliases + + def __getattr__(self, name: str) -> Any: + if name in self._aliases: + module_name, name = self._aliases[name] + module = self._get_module(module_name) + return getattr(module, name) + + return super().__getattr__(name) diff --git a/tests/llmcompressor/transformers/tracing/test_init.py b/tests/llmcompressor/transformers/tracing/test_init.py new file mode 100644 index 000000000..1f75f98c3 --- /dev/null +++ b/tests/llmcompressor/transformers/tracing/test_init.py @@ -0,0 +1,57 @@ +import importlib +import sys +from types import ModuleType +from unittest.mock import patch + +import pytest + + +@pytest.fixture(autouse=True) +def clean_imports(): + # Remove any existing imports before each test + module_names = list(sys.modules.keys()) + for module_name in module_names: + if module_name.startswith("llmcompressor"): + del sys.modules[module_name] + + importlib.invalidate_caches() + + yield + + +@pytest.mark.unit +def test_lazy_loading(clean_imports): + # mock import_module + imported_module_names = [] + original_import_module = importlib.import_module + + def mock_import_module(name, *args, **kwargs): + nonlocal imported_module_names + imported_module_names.append(name) + return original_import_module(name, *args, **kwargs) + + # import with alias + with patch("importlib.import_module", mock_import_module): + from llmcompressor.transformers.tracing import ( # noqa: F401 + TraceableLlavaForConditionalGeneration, + ) + + # test that llava was imported and mllama was not + assert ".llava" in imported_module_names + assert ".mllama" not in imported_module_names + + # test that tracing module has a llava attribute but not an mllama attribute + attributes = ModuleType.__dir__(sys.modules["llmcompressor.transformers.tracing"]) + assert "llava" in attributes + assert "mllama" not in attributes + + +@pytest.mark.unit +def test_class_names(clean_imports): + import llmcompressor.transformers.tracing as TracingModule + + # test that the class names are not the aliased names + # this is important for correctly saving model configs + for cls_alias, (_loc, cls_name) in TracingModule._aliases.items(): + cls = getattr(TracingModule, cls_alias) + assert cls.__name__ == cls_name