Skip to content

Commit

Permalink
[Core][VLM] Test registration for OOT multimodal models (#8717)
Browse files Browse the repository at this point in the history
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
  • Loading branch information
ywang96 and DarkLight1337 authored Oct 4, 2024
1 parent e5dc713 commit 26aa325
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 49 deletions.
18 changes: 15 additions & 3 deletions docs/source/models/adding_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,16 @@ When it comes to the linear layers, we provide the following options to parallel
* :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
* :code:`RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
* :code:`ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.
* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple `ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple :code:`ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
* :code:`QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When number of key/value heads are less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices.

Note that all the linear layers above take `linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
Note that all the linear layers above take :code:`linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.

4. Implement the weight loading logic
-------------------------------------

You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class.
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for `MergedColumnParallelLinear` and `QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately.
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for :code:`MergedColumnParallelLinear` and :code:`QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately.

5. Register your model
----------------------
Expand All @@ -114,6 +114,18 @@ Just add the following lines in your code:
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
If your model imports modules that initialize CUDA, consider instead lazy-importing it to avoid an error like :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`:

.. code-block:: python
from vllm import ModelRegistry
ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM")
.. important::
If your model is a multimodal model, make sure the model class implements the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
Read more about that :ref:`here <enabling_multimodal_inputs>`.

If you are running api server with :code:`vllm serve <args>`, you can wrap the entrypoint with the following code:

.. code-block:: python
Expand Down
33 changes: 33 additions & 0 deletions find_cuda_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import importlib
import traceback
from typing import Callable
from unittest.mock import patch


def find_cuda_init(fn: Callable[[], object]) -> None:
"""
Helper function to debug CUDA re-initialization errors.
If `fn` initializes CUDA, prints the stack trace of how this happens.
"""
from torch.cuda import _lazy_init

stack = None

def wrapper():
nonlocal stack
stack = traceback.extract_stack()
return _lazy_init()

with patch("torch.cuda._lazy_init", wrapper):
fn()

if stack is not None:
print("==== CUDA Initialized ====")
print("".join(traceback.format_list(stack)).strip())
print("==========================")


if __name__ == "__main__":
find_cuda_init(
lambda: importlib.import_module("vllm.model_executor.models.llava"))
30 changes: 25 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,15 +879,16 @@ def num_gpus_available():


temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")


@pytest.fixture
def dummy_opt_path():
json_path = os.path.join(_dummy_path, "config.json")
if not os.path.exists(_dummy_path):
json_path = os.path.join(_dummy_opt_path, "config.json")
if not os.path.exists(_dummy_opt_path):
snapshot_download(repo_id="facebook/opt-125m",
local_dir=_dummy_path,
local_dir=_dummy_opt_path,
ignore_patterns=[
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
"*.msgpack"
Expand All @@ -898,4 +899,23 @@ def dummy_opt_path():
config["architectures"] = ["MyOPTForCausalLM"]
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_path
return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
json_path = os.path.join(_dummy_llava_path, "config.json")
if not os.path.exists(_dummy_llava_path):
snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
local_dir=_dummy_llava_path,
ignore_patterns=[
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
"*.msgpack"
])
assert os.path.exists(json_path)
with open(json_path, "r") as f:
config = json.load(f)
config["architectures"] = ["MyLlava"]
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_llava_path
4 changes: 3 additions & 1 deletion tests/entrypoints/openai/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def server():
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"2048",
"--max-num-seqs",
"5",
"--enforce-eager",
]

Expand Down
13 changes: 10 additions & 3 deletions tests/entrypoints/openai/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@
@pytest.fixture(scope="module")
def server():
args = [
"--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs",
"5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}"
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"5",
"--enforce-eager",
"--trust-remote-code",
"--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
Expand Down
38 changes: 38 additions & 0 deletions tests/models/test_oot_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

from ..utils import fork_new_process_for_each_test

Expand All @@ -29,3 +30,40 @@ def test_oot_registration(dummy_opt_path):
# make sure only the first token is generated
rest = generated_text.replace(first_token, "")
assert rest == ""


image = ImageAsset("cherry_blossom").pil_image.convert("RGB")


@fork_new_process_for_each_test
def test_oot_multimodal_registration(dummy_llava_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = [{
"prompt": "What's in the image?<image>",
"multi_modal_data": {
"image": image
},
}, {
"prompt": "Describe the image<image>",
"multi_modal_data": {
"image": image
},
}]

sampling_params = SamplingParams(temperature=0)
llm = LLM(model=dummy_llava_path,
load_format="dummy",
max_num_seqs=1,
trust_remote_code=True,
gpu_memory_utilization=0.98,
max_model_len=4096,
enforce_eager=True,
limit_mm_per_prompt={"image": 1})
first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
generated_text = output.outputs[0].text
# make sure only the first token is generated
rest = generated_text.replace(first_token, "")
assert rest == ""
Original file line number Diff line number Diff line change
@@ -1,26 +1,14 @@
from typing import Optional

import torch

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(
self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
if logits is not None:
logits.zero_()
logits[:, 0] += 1.0
return logits


def register():
# register our dummy model
# Test directly passing the model
from .my_opt import MyOPTForCausalLM

if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)

# Test passing lazy model
if "MyLlava" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyLlava",
"vllm_add_dummy_model.my_llava:MyLlava")
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Optional

import torch

from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
dummy_data_for_llava,
get_max_llava_image_tokens,
input_processor_for_llava)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class MyLlava(LlavaForConditionalGeneration):

def compute_logits(
self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
if logits is not None:
logits.zero_()
logits[:, 0] += 1.0
return logits
19 changes: 19 additions & 0 deletions tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Optional

import torch

from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(
self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
if logits is not None:
logits.zero_()
logits[:, 0] += 1.0
return logits
2 changes: 2 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ class EngineArgs:
def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
from vllm.plugins import load_general_plugins
load_general_plugins()

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Expand Down
3 changes: 0 additions & 3 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,6 @@ def __init__(
model_config.mm_processor_kwargs,
)
# TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins
load_general_plugins()

self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
Expand Down
Loading

0 comments on commit 26aa325

Please sign in to comment.