Implement lazy loading for traceable models #1105

Open · wants to merge 8 commits into main · Changes from 5 commits
46 changes: 28 additions & 18 deletions src/llmcompressor/transformers/tracing/__init__.py
@@ -1,19 +1,29 @@
-from .llava import (
-    LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration,
-)
-from .mllama import (
-    MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration,
-)
-from .qwen2_vl import (
-    Qwen2VLForConditionalGeneration as TraceableQwen2VLForConditionalGeneration,
-)
-from .idefics3 import (
-    Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration
-)
+from typing import TYPE_CHECKING

-__all__ = [
-    "TraceableLlavaForConditionalGeneration",
-    "TraceableMllamaForConditionalGeneration",
-    "TraceableQwen2VLForConditionalGeneration",
-    "TraceableIdefics3ForConditionalGeneration"
-]
+import sys
+import importlib
+from llmcompressor.utils.AliasableLazyModule import _AliasableLazyModule
+from transformers.utils.import_utils import define_import_structure
+
+_aliases = {
+    "TraceableLlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),  # noqa: E501
+    "TraceableMllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
+    "TraceableQwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
+    "TraceableIdefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
+}
+
+if TYPE_CHECKING:
+    for alias, (module_name, class_name) in _aliases.items():
+        module = importlib.import_module(f".{module_name}", __package__)
+        locals()[alias] = getattr(module, class_name)
+else:
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _AliasableLazyModule(
+        name=__name__,
+        module_file=_file,
+        import_structure=define_import_structure(_file),
+        module_spec=__spec__,
+        aliases=_aliases,
+    )
+
+__all__ = list(_aliases.keys())
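As a standalone illustration of the trick the new `__init__.py` relies on (not code from this PR; the module name `demo_mod` is made up): at import time, a module may replace its own entry in `sys.modules` with any module-like object, and all later imports resolve to that replacement.

```python
# Self-contained sketch of sys.modules self-replacement.
# "demo_mod" is a hypothetical module name; no file on disk backs it.
import sys
import types

proxy = types.ModuleType("demo_mod")
proxy.answer = 42  # attribute served by the proxy object
sys.modules["demo_mod"] = proxy

# a normal import statement now resolves to the proxy,
# because the import system checks sys.modules first
import demo_mod

assert demo_mod is proxy
assert demo_mod.answer == 42
```

This is the same mechanism as `sys.modules[__name__] = _AliasableLazyModule(...)` above, except the PR substitutes a lazy-loading proxy rather than a plain module object.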
34 changes: 34 additions & 0 deletions src/llmcompressor/utils/AliasableLazyModule.py
@@ -0,0 +1,34 @@
from typing import Any, Dict, Tuple

from transformers.utils import _LazyModule

__all__ = ["_AliasableLazyModule"]


class _AliasableLazyModule(_LazyModule):
Collaborator:
Doesn't the registry take care of lazy loading?

Collaborator (Author):
Not if we want to maintain the API of from llmcompressor.transformers.tracing import TraceableX

Collaborator:
I see, can we do something like

# examples script
from transformers import LlavaForConditionalGeneration
model = LlavaForConditionalGeneration.from_pretrained(...)

oneshot(model=model, ...)

# in the backend, map LlavaForConditionalGeneration to your llmcompressor.transformers.tracing model using the registry,
# e.g. LlavaForConditionalGeneration maps to TraceableLlavaForConditionalGeneration

This way the user can just use transformers instead of our llmcompressor model.

Collaborator (Author):
> model = LlavaForConditionalGeneration.from_pretrained(...)

1. I'm not really sure what this API is pointing to; as written, there's no distinction between code which loads the normal vs. the traceable definitions.
2. As we spoke about before, we'd need to change upstream code. These changes almost certainly wouldn't be accepted by the transformers team, as they're outside the responsibilities of the transformers library.
3. What is traceable for LLM Compressor is not what is traceable for other users, because LLM Compressor contains specialized code to make tracing easier.

Collaborator (Author):
> Doesn't the registry take care of lazy loading?

We could implement lazy loading using a registry, but this would involve adding registry code into the traceable definition:

@TracingRegistry.register(name="TraceableLlavaForConditionalGeneration")
class LlavaForConditionalGeneration:
    ...

And imho this makes for a clunkier top-level interface:

from llmcompressor.transformers.tracing import TracingRegistry

model = TracingRegistry.load_from_registry("LlavaForConditionalGeneration").from_pretrained("path")
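For concreteness, the decorator-based registry being debated could be sketched as follows. All names here (TracingRegistry, the stub model class, the from_pretrained body) are hypothetical illustrations, not code from this PR.

```python
# Minimal sketch of a decorator-based class registry, as discussed above.
class TracingRegistry:
    _registry = {}

    @classmethod
    def register(cls, name):
        """Return a class decorator that records the class under `name`."""
        def decorator(target):
            cls._registry[name] = target
            return target
        return decorator

    @classmethod
    def load_from_registry(cls, name):
        return cls._registry[name]


@TracingRegistry.register(name="TraceableLlavaForConditionalGeneration")
class LlavaForConditionalGeneration:
    @classmethod
    def from_pretrained(cls, path):
        return cls()  # stand-in for real weight loading


model_cls = TracingRegistry.load_from_registry("TraceableLlavaForConditionalGeneration")
model = model_cls.from_pretrained("path")
assert isinstance(model, LlavaForConditionalGeneration)
```

Note that a registry like this defers nothing by itself: the decorated class is only "lazy" if the module defining it is never imported, which is exactly the property the thread is arguing about.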

Collaborator (@horheynm, Jan 28, 2025):
We don't need to push code to HF upstream; we just use what is in HF right now. In llm-comp, we just add a mapping from each HF multimodal model to your tracing model.

Very simply, something like

{
    "HFMODEL": "YOUR_TRACEABLE_MODEL",
    "LlavaForConditionalGeneration": "TraceableLlavaForConditionalGeneration",
}

but using the registry. Here we can just use HF code in the example UX, but the backend would use your traceable model.

Collaborator (Author, @kylesayrs, Jan 29, 2025):
> In llm-comp, we just add a mapping from each HF multimodal model to your tracing model.
> Here we can just use HF code in the example UX, but the backend would use your traceable model.

I understand what a registry dictionary is. I don't understand how it is possible to implement a registry in LLM Compressor while only using the HF library at the top level.

Collaborator (Author):
If you're referring to dynamically replacing the model definition within oneshot, I consider that to be an antipattern which makes it unclear to the user what model they're really loading. It also opens up unintended consequences of loading a model definition twice, such as if the user modifies the model config prior to oneshot.

Collaborator (Author):
Also, traceable definitions are not needed for recipes which do not use the sequential pipeline. Because the traceable definitions include things like processing error checks, they're useful to keep in most cases so the user can better debug their dataloading.

"""
Extends _LazyModule to support aliases names

>>> _file = globals()["__file__"]
>>> sys.modules["animals"] = _AliasableLazyModule(
name="animals,
module_file=_file,
import_structure=define_import_structure(_file),
module_spec=__spec__,
aliases={
"PigWithLipstick": ("mammals", "Pig"),
}
>>> from animals import PigWithLipstick
"""

    def __init__(self, *args, aliases: Dict[str, Tuple[str, str]], **kwargs):
        super().__init__(*args, **kwargs)
        self._aliases = aliases

    def __getattr__(self, name: str) -> Any:
        if name in self._aliases:
            module_name, name = self._aliases[name]
            module = self._get_module(module_name)
            return getattr(module, name)

        return super().__getattr__(name)
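To see the alias-resolution pattern end to end without transformers installed, here is a minimal, self-contained sketch. All names (AliasedLazyModule, fake_pkg, LazyPath) are hypothetical; the real _AliasableLazyModule above delegates submodule loading to _LazyModule._get_module instead of calling importlib directly.

```python
# Sketch of a module stand-in that resolves alias -> (submodule, attribute)
# on first attribute access, mirroring _AliasableLazyModule.__getattr__.
import importlib
import sys
import types


class AliasedLazyModule(types.ModuleType):
    def __init__(self, name, aliases):
        super().__init__(name)
        self._aliases = aliases
        self._loaded = {}  # submodules imported so far

    def __getattr__(self, attr):
        if attr in self._aliases:
            submodule_name, real_name = self._aliases[attr]
            if submodule_name not in self._loaded:
                # the actual import happens here, on first attribute access
                self._loaded[submodule_name] = importlib.import_module(submodule_name)
            value = getattr(self._loaded[submodule_name], real_name)
            setattr(self, attr, value)  # cache so __getattr__ is not hit again
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")


# register a fake package whose "LazyPath" alias points at pathlib.Path
lazy = AliasedLazyModule("fake_pkg", {"LazyPath": ("pathlib", "Path")})
sys.modules["fake_pkg"] = lazy
assert lazy._loaded == {}  # nothing resolved yet

from fake_pkg import LazyPath  # first access triggers the pathlib import

import pathlib
assert LazyPath is pathlib.Path
assert "pathlib" in lazy._loaded
```

The `from fake_pkg import LazyPath` statement works because `from X import Y` first finds X in `sys.modules`, then fetches Y via attribute access, which routes through `__getattr__` on the proxy.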
55 changes: 55 additions & 0 deletions tests/llmcompressor/transformers/tracing/test_init.py
@@ -0,0 +1,55 @@
import importlib
import sys
from types import ModuleType
from unittest.mock import patch

import pytest


@pytest.fixture(autouse=True)
def clean_imports():
    # Remove any existing imports before each test
    module_names = list(sys.modules.keys())
    for module_name in module_names:
        if module_name.startswith("llmcompressor"):
            del sys.modules[module_name]

    importlib.invalidate_caches()

    yield


def test_lazy_loading(clean_imports):
    # mock import_module
    imported_module_names = []
    original_import_module = importlib.import_module

    def mock_import_module(name, *args, **kwargs):
        nonlocal imported_module_names
        imported_module_names.append(name)
        return original_import_module(name, *args, **kwargs)

    # import with alias
    with patch("importlib.import_module", mock_import_module):
        from llmcompressor.transformers.tracing import (  # noqa: F401
            TraceableLlavaForConditionalGeneration,
        )

    # test that llava was imported but mllama was not
    assert ".llava" in imported_module_names
    assert ".mllama" not in imported_module_names

    # test that the tracing module has a llava attribute but not an mllama attribute
    attributes = ModuleType.__dir__(sys.modules["llmcompressor.transformers.tracing"])
    assert "llava" in attributes
    assert "mllama" not in attributes


def test_class_names(clean_imports):
    import llmcompressor.transformers.tracing as TracingModule

    # test that the class names are not the aliased names
    # this is important for correctly saving model configs
    for cls_alias, (_loc, cls_name) in TracingModule._aliases.items():
        cls = getattr(TracingModule, cls_alias)
        assert cls.__name__ == cls_name