Implement lazy loading for traceable models #1105
Open
kylesayrs wants to merge 8 commits into main from kylesayrs/lazy-tracing-import
f03037e implement lazy loading for traceable models (kylesayrs)
f3be499 wip (kylesayrs)
ab052d5 extend to support import with aliases (kylesayrs)
6862da6 clean tests (kylesayrs)
5d5dbd4 rename file (kylesayrs)
d4a36a4 fix typo (kylesayrs)
df31d83 add marks (kylesayrs)
b9b3248 Merge remote-tracking branch 'origin' into kylesayrs/lazy-tracing-import (kylesayrs)
@@ -1,19 +1,29 @@
-from .llava import (
-    LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration,
-)
-from .mllama import (
-    MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration,
-)
-from .qwen2_vl import (
-    Qwen2VLForConditionalGeneration as TraceableQwen2VLForConditionalGeneration,
-)
-from .idefics3 import (
-    Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration
-)
+from typing import TYPE_CHECKING
 
-__all__ = [
-    "TraceableLlavaForConditionalGeneration",
-    "TraceableMllamaForConditionalGeneration",
-    "TraceableQwen2VLForConditionalGeneration",
-    "TraceableIdefics3ForConditionalGeneration"
-]
+import sys
+import importlib
+from llmcompressor.utils.AliasableLazyModule import _AliasableLazyModule
+from transformers.utils.import_utils import define_import_structure
+
+_aliases = {
+    "TraceableLlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),  # noqa: E501
+    "TraceableMllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
+    "TraceableQwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
+    "TraceableIdefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
+}
+
+if TYPE_CHECKING:
+    for alias, (module_name, class_name) in _aliases.items():
+        module = importlib.import_module(f".{module_name}", __package__)
+        locals()[alias] = getattr(module, class_name)
+else:
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _AliasableLazyModule(
+        name=__name__,
+        module_file=_file,
+        import_structure=define_import_structure(_file),
+        module_spec=__spec__,
+        aliases=_aliases
+    )
+
+__all__ = list(_aliases.keys())
@@ -0,0 +1,34 @@
+from typing import Any, Dict, Tuple
+
+from transformers.utils import _LazyModule
+
+__all__ = ["_AliasableLazyModule"]
+
+
+class _AliasableLazyModule(_LazyModule):
+    """
+    Extends _LazyModule to support aliased names
+
+    >>> _file = globals()["__file__"]
+    >>> sys.modules["animals"] = _AliasableLazyModule(
+        name="animals",
+        module_file=_file,
+        import_structure=define_import_structure(_file),
+        module_spec=__spec__,
+        aliases={
+            "PigWithLipstick": ("mammals", "Pig"),
+        })
+    >>> from animals import PigWithLipstick
+    """
+
+    def __init__(self, *args, aliases: Dict[str, Tuple[str, str]], **kwargs):
+        super().__init__(*args, **kwargs)
+        self._aliases = aliases
+
+    def __getattr__(self, name: str) -> Any:
+        if name in self._aliases:
+            module_name, name = self._aliases[name]
+            module = self._get_module(module_name)
+            return getattr(module, name)
+
+        return super().__getattr__(name)
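The alias lookup in the class above can be tried without transformers installed. The following is a minimal sketch, not the PR's implementation: AliasLazyModule, the module name lazy_demo, and the standard-library targets are all made up for illustration, and plain importlib stands in for _LazyModule._get_module.

```python
import importlib
import sys
from types import ModuleType
from typing import Any, Dict, Tuple


class AliasLazyModule(ModuleType):
    """Hypothetical stand-in for an alias-aware lazy module (stdlib only)."""

    def __init__(self, name: str, aliases: Dict[str, Tuple[str, str]]):
        super().__init__(name)
        self._aliases = aliases
        self.loaded = []  # names of target modules imported so far

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal attribute lookup fails,
        # so attributes cached below never reach this method again.
        aliases = self.__dict__.get("_aliases", {})
        if name not in aliases:
            raise AttributeError(
                f"module {self.__name__!r} has no attribute {name!r}"
            )
        module_name, attr_name = aliases[name]
        module = importlib.import_module(module_name)  # deferred import
        self.loaded.append(module_name)
        value = getattr(module, attr_name)
        setattr(self, name, value)  # cache under the alias
        return value


# Standard-library classes stand in for the traceable model classes.
lazy = AliasLazyModule(
    "lazy_demo",
    aliases={
        "OrderedMapping": ("collections", "OrderedDict"),
        "JsonDecoder": ("json", "JSONDecoder"),
    },
)
sys.modules["lazy_demo"] = lazy

from lazy_demo import OrderedMapping  # noqa: E402

print(OrderedMapping.__name__)  # OrderedDict
print("json" in lazy.loaded)  # False
```

The class keeps its original __name__ because the alias only changes the name it is looked up by, which is the property test_class_names checks in the PR.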
@@ -0,0 +1,55 @@
+import importlib
+import sys
+from types import ModuleType
+from unittest.mock import patch
+
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def clean_imports():
+    # Remove any existing imports before each test
+    module_names = list(sys.modules.keys())
+    for module_name in module_names:
+        if module_name.startswith("llmcompressor"):
+            del sys.modules[module_name]
+
+    importlib.invalidate_caches()
+
+    yield
+
+
+def test_lazy_loading(clean_imports):
+    # mock import_module
+    imported_module_names = []
+    original_import_module = importlib.import_module
+
+    def mock_import_module(name, *args, **kwargs):
+        nonlocal imported_module_names
+        imported_module_names.append(name)
+        return original_import_module(name, *args, **kwargs)
+
+    # import with alias
+    with patch("importlib.import_module", mock_import_module):
+        from llmcompressor.transformers.tracing import (  # noqa: F401
+            TraceableLlavaForConditionalGeneration,
+        )
+
+    # test that llava was imported but mllama was not
+    assert ".llava" in imported_module_names
+    assert ".mllama" not in imported_module_names
+
+    # test that tracing module has a llava attribute but not an mllama attribute
+    attributes = ModuleType.__dir__(sys.modules["llmcompressor.transformers.tracing"])
+    assert "llava" in attributes
+    assert "mllama" not in attributes
+
+
+def test_class_names(clean_imports):
+    import llmcompressor.transformers.tracing as TracingModule
+
+    # test that the class names are not the aliased names
+    # this is important for correctly saving model configs
+    for cls_alias, (_loc, cls_name) in TracingModule._aliases.items():
+        cls = getattr(TracingModule, cls_alias)
+        assert cls.__name__ == cls_name
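The laziness check in test_lazy_loading can also be made by inspecting sys.modules directly instead of patching importlib.import_module. The sketch below builds a throwaway package on disk and uses a module-level __getattr__ (PEP 562), which is a different mechanism from the _LazyModule subclass in this PR. The package name lazy_pkg_demo and its cats/dogs submodules are invented for the example.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Build a throwaway package (made-up name) with two submodules.
root = Path(tempfile.mkdtemp())
pkg = root / "lazy_pkg_demo"
pkg.mkdir()
(pkg / "cats.py").write_text("class Cat:\n    pass\n")
(pkg / "dogs.py").write_text("class Dog:\n    pass\n")

# The package __init__ resolves aliases on first access (PEP 562).
(pkg / "__init__.py").write_text(textwrap.dedent('''
    import importlib

    _aliases = {"HouseCat": ("cats", "Cat"), "HouseDog": ("dogs", "Dog")}

    def __getattr__(name):
        if name in _aliases:
            module_name, attr_name = _aliases[name]
            module = importlib.import_module(f".{module_name}", __name__)
            return getattr(module, attr_name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
'''))

sys.path.insert(0, str(root))
importlib.invalidate_caches()

from lazy_pkg_demo import HouseCat  # noqa: E402

# Only the submodule behind the requested alias should have been imported.
cats_loaded = "lazy_pkg_demo.cats" in sys.modules
dogs_loaded = "lazy_pkg_demo.dogs" in sys.modules
print(HouseCat.__name__, cats_loaded, dogs_loaded)  # Cat True False
```

Checking sys.modules does not depend on which import function the lazy module calls internally, so this style of assertion keeps working if the loading mechanism changes.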
Doesn't the registry take care of lazy loading?
Not if we want to maintain the API of
from llmcompressor.transformers.tracing import TraceableX
I see, can we do something like
This way the user can just use transformers instead of our llmcompressor Model.
transformers library
We could implement lazy loading using a registry, but this would involve having to add registry code into the traceable definition.
And imho this makes for a clunkier top level interface
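For comparison, a registry-style lookup might look roughly like the sketch below. It is only an illustration of the trade-off discussed here, not code from LLM Compressor: _REGISTRY, get_traceable, and the entry names are hypothetical, standard-library classes stand in for the traceable models, and a central mapping is used instead of registration code inside each definition.

```python
import importlib
from typing import Any, Dict, Tuple

# Hypothetical registry: public name -> (module to import, attribute to fetch).
# Standard-library targets stand in for the traceable model modules.
_REGISTRY: Dict[str, Tuple[str, str]] = {
    "TraceableDecoder": ("json", "JSONDecoder"),
    "TraceableFraction": ("fractions", "Fraction"),
}


def get_traceable(name: str) -> Any:
    """Import the target module only when its entry is first requested."""
    try:
        module_name, attr_name = _REGISTRY[name]
    except KeyError:
        raise KeyError(
            f"no traceable definition registered under {name!r}"
        ) from None
    return getattr(importlib.import_module(module_name), attr_name)


decoder_cls = get_traceable("TraceableDecoder")
print(decoder_cls.__name__)  # JSONDecoder
```

Loading stays lazy, but callers have to go through a string lookup such as get_traceable("TraceableDecoder") rather than an ordinary from ... import statement, which is the interface cost raised in this thread.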
We don't need to push code to HF upstream; we just use what's in HF right now.
In llm-comp, we just add a mapping that says which traceable model each multimodal model points to.
Very simply something like
But using the registry.
Here we can just use HF code in the example UX, but the backend would use your traceable model.
I understand what a registry dictionary is. I don't understand how it is possible to implement a registry in LLM Compressor while only using the HF library at the top level
If you're referring to dynamically replacing the model definition within oneshot, I consider that to be an antipattern which makes it unclear to the user what model they're really loading. It also opens up unintended consequences of loading a model definition twice, such as when the user modifies the model config prior to oneshot.
Also, traceable definitions are not needed for recipes which do not use the sequential pipeline. Because the traceable definitions include things like processing error checks, they're useful to keep in for most cases to allow the user to better debug their dataloading