[Bugfix] Fix weight loading for Chameleon when TP>1 (vllm-project#7410)
DarkLight1337 authored Aug 13, 2024
1 parent 5469146 commit 7025b11
Showing 59 changed files with 414 additions and 205 deletions.
26 changes: 18 additions & 8 deletions tests/conftest.py
@@ -4,7 +4,8 @@
import sys
from collections import UserList
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
TypeVar, Union)

import pytest
import torch
@@ -27,7 +28,7 @@
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu)
identity, is_cpu)

logger = init_logger(__name__)

@@ -197,6 +198,8 @@ def __init__(
is_embedding_model: bool = False,
is_vision_model: bool = False,
is_encoder_decoder_model: bool = False,
postprocess_inputs: Callable[[BatchEncoding],
BatchEncoding] = identity,
) -> None:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

@@ -242,12 +245,14 @@ def __init__(
torch_dtype=torch_dtype,
trust_remote_code=True,
)
except Exception:
except Exception as exc:
logger.warning(
"Unable to auto-load processor from HuggingFace for "
"model %s. Using tokenizer instead.", model_name)
"Unable to auto-load HuggingFace processor for model (%s). "
"Using tokenizer instead. Reason: %s", model_name, exc)
self.processor = self.tokenizer

self.postprocess_inputs = postprocess_inputs

def generate(
self,
prompts: List[str],
@@ -267,6 +272,7 @@ def generate(
processor_kwargs["images"] = images[i]

inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)

output_ids = self.model.generate(
**self.wrap_device(inputs),
@@ -336,6 +342,7 @@ def generate_greedy_logprobs(
processor_kwargs["images"] = images[i]

inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)

output = self.model.generate(
**self.wrap_device(inputs),
@@ -420,6 +427,7 @@ def generate_greedy_logprobs_limit(
processor_kwargs["images"] = images[i]

inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)

output = self.model.generate(
**self.wrap_device(inputs),
@@ -552,7 +560,8 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
@@ -587,7 +596,7 @@ def _final_steps_generate_w_logprobs(
for req_output in req_outputs:
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
return outputs
@@ -596,7 +605,8 @@ def generate_w_logprobs(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None

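For context, the postprocess_inputs hook added to HfRunner above lets a test adjust the HuggingFace BatchEncoding after the processor runs (the default, identity, leaves it unchanged). Below is a minimal sketch of the pattern, mirroring the updated vision tests later in this diff; the model name, dtype, and helper name are illustrative, and prompts/images are assumed to be prepared by the test.

    import torch
    from transformers import BatchEncoding

    def cast_pixel_values(hf_inputs: BatchEncoding) -> BatchEncoding:
        # Cast the image tensor to the dtype the model under test runs in.
        hf_inputs["pixel_values"] = hf_inputs["pixel_values"].to(torch.bfloat16)
        return hf_inputs

    # Inside a test body: hf_runner is the pytest fixture, and prompts/images
    # are assumed to be set up by the test.
    with hf_runner("facebook/chameleon-7b",
                   dtype="bfloat16",
                   postprocess_inputs=cast_pixel_values,
                   is_vision_model=True) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            prompts, max_tokens=8, num_logprobs=5, images=images)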
4 changes: 4 additions & 0 deletions tests/distributed/test_multimodal_broadcast.py
@@ -18,8 +18,10 @@
@pytest.mark.parametrize("model, distributed_executor_backend", [
("llava-hf/llava-1.5-7b-hf", "ray"),
("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
("facebook/chameleon-7b", "ray"),
("llava-hf/llava-1.5-7b-hf", "mp"),
("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
("facebook/chameleon-7b", "mp"),
])
@fork_new_process_for_each_test
def test_models(hf_runner, vllm_runner, image_assets, model: str,
@@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
from ..models.test_llava import models, run_test
elif model.startswith("llava-hf/llava-v1.6"):
from ..models.test_llava_next import models, run_test
elif model.startswith("facebook/chameleon"):
from ..models.test_chameleon import models, run_test
else:
raise NotImplementedError(f"Unsupported model: {model}")

8 changes: 6 additions & 2 deletions tests/entrypoints/openai/test_oot_registration.py
@@ -1,5 +1,6 @@
import sys
import time
from typing import Optional

import torch
from openai import OpenAI, OpenAIError
@@ -17,8 +18,11 @@

class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
91 changes: 58 additions & 33 deletions tests/models/test_chameleon.py
@@ -1,11 +1,13 @@
import re
from typing import List, Optional, Type

import pytest
from transformers import BatchEncoding

from vllm.multimodal.utils import rescale_image_size
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_outputs_equal

pytestmark = pytest.mark.vlm

@@ -19,23 +21,29 @@
models = ["facebook/chameleon-7b"]


#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Test if the model can generate text given
a batch of images and prompts.
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
@@ -50,35 +58,49 @@ def run_test(
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:

for prompts, images in inputs_per_image:
vllm_outputs = vllm_model.generate_greedy(prompts,
max_tokens,
images=images)
for i in range(len(vllm_outputs)):

# format prompt back to original
replacements = {
"<racm3:break>": "",
"<eoss>": "",
"<reserved08706>": ""
}
pattern = '|'.join(replacements.keys())
vllm_result = re.sub(
pattern,
lambda match: replacements[match.group(0)], #noqa B023
vllm_outputs[i][1])
vllm_result = vllm_result.replace("<image>", "", 1023)
assert vllm_result[:len(prompts[i])] == prompts[i]

# assert at least 10 new characters are generated
# (to take stop token into account)
assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]

def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs

with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# HF Logprobs include image tokens, unlike vLLM, so we don't directly
# compare them
check_outputs_equal(
outputs_0_lst=[outputs[:2] for outputs in hf_outputs],
outputs_1_lst=[outputs[:2] for outputs in vllm_outputs],
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
Expand All @@ -88,15 +110,18 @@ def run_test(
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
max_tokens: int) -> None:
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype, max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
21 changes: 13 additions & 8 deletions tests/models/test_llava.py
@@ -1,7 +1,7 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, BatchEncoding

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@@ -110,16 +110,21 @@ def run_test(
for prompts, images in inputs_per_image
]

with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
if mantis_processor is not None:
if mantis_processor is not None:

def process(*args, **kwargs):
output = mantis_processor(*args, **kwargs)
output["pixel_values"] = output["pixel_values"].to(torch_dtype)
return output
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
else:

hf_model.processor = process
def process(hf_inputs: BatchEncoding):
return hf_inputs

with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
33 changes: 9 additions & 24 deletions tests/models/test_minicpmv.py
@@ -1,10 +1,9 @@
from collections import UserDict
from typing import List, Optional, Tuple, Type

import pytest
import torch
import torch.types
from transformers import BatchFeature
from transformers import BatchEncoding

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@@ -14,18 +13,6 @@

pytestmark = pytest.mark.vlm


class NestedInputs(UserDict):

def __init__(self, model_inputs: BatchFeature):
super().__init__({"model_inputs": model_inputs})

self.model_inputs = model_inputs

def to(self, device: torch.types.Device):
return NestedInputs(self.model_inputs.to(device))


# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
@@ -41,6 +28,10 @@ def to(self, device: torch.types.Device):
models = ["openbmb/MiniCPM-Llama3-V-2_5"]


def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
return BatchEncoding({"model_inputs": hf_inputs})


def trunc_hf_output(hf_output: Tuple[List[int], str,
Optional[SampleLogprobs]]):
output_ids, output_str, out_logprobs = hf_output
@@ -105,11 +96,8 @@ def run_test(
for prompts, images in inputs_per_image
]

with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
with hf_model, torch.no_grad():
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
@@ -224,11 +212,8 @@ def run_multi_image_test(
for prompts, images in inputs_per_case
]

with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
with hf_model, torch.no_grad():
hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,