Skip to content

Commit

Permalink
[Frontend] re-enable multi-modality input in the new beam search impl…
Browse files Browse the repository at this point in the history
…ementation (vllm-project#9427)

Signed-off-by: Qishuai Ferdinandzhong@gmail.com
  • Loading branch information
FerdinandZhong authored Oct 29, 2024
1 parent eae3d48 commit ef7865b
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 40 deletions.
71 changes: 71 additions & 0 deletions tests/entrypoints/openai/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
model_name: str,
image_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]

chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
n=2,
max_tokens=10,
logprobs=True,
top_logprobs=5,
extra_body=dict(use_beam_search=True))
assert len(chat_completion.choices) == 2
assert chat_completion.choices[
0].message.content != chat_completion.choices[1].message.content


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
Expand Down Expand Up @@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded(
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_base64encoded_beamsearch(
client: openai.AsyncOpenAI, model_name: str, image_url: str,
base64_encoded_image: Dict[str, str]):

messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url":
f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
n=2,
max_tokens=10,
extra_body=dict(use_beam_search=True))
assert len(chat_completion.choices) == 2
assert chat_completion.choices[
0].message.content != chat_completion.choices[1].message.content


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
Expand Down
9 changes: 8 additions & 1 deletion vllm/beam_search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from vllm.sequence import Logprob

if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict


@dataclass
class BeamSearchSequence:
Expand All @@ -16,6 +19,10 @@ class BeamSearchSequence:
logprobs: List[Dict[int, Logprob]]
cum_logprob: float = 0.0
text: Optional[str] = None
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None
multi_modal_data: Optional["MultiModalDataDict"] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None


@dataclass
Expand Down
88 changes: 57 additions & 31 deletions vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
Expand Down Expand Up @@ -59,7 +60,8 @@ def generate(

async def beam_search(
self,
prompt: Union[str, List[int]],
prompt: Union[PromptType, List[int]],
model_config: ModelConfig,
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
Expand All @@ -69,32 +71,40 @@ async def beam_search(
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output

tokenizer = await self.get_tokenizer(lora_request=None)
if isinstance(prompt, str):
tokenized_prompt = tokenizer.encode(prompt)
prompt_text = prompt
else:
tokenized_prompt = prompt
prompt_text = None
tokenized_length = len(tokenized_prompt)
tokenizer = await self.get_tokenizer()
input_preprocessor = InputPreprocessor(model_config, tokenizer)

(prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
prompt,
request_id=request_id,
)
tokenized_length = len(prompt_token_ids)

sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
beam_search_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
)
all_beams = [
BeamSearchSequence(tokens=tokenized_prompt,
BeamSearchSequence(tokens=prompt_token_ids,
cum_logprob=0,
logprobs=[],
cum_logprob=0)
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs)
]
completed = []

for _ in range(max_tokens):
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
TokensPrompt(prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs)
for beam in all_beams
]

Expand All @@ -120,17 +130,31 @@ async def beam_search(
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
not ignore_eos:
completed.append(new_beam)
completed.append(
BeamSearchSequence(
tokens=current_beam.tokens +
[token_id] if include_stop_str_in_output
else current_beam.tokens,
logprobs=current_beam.logprobs +
[logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
finish_reason="stop",
stop_reason=tokenizer.eos_token_id))
else:
new_beams.append(new_beam)
new_beams.append(
BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs +
[logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.
multi_modal_data,
mm_processor_kwargs=current_beam.
mm_processor_kwargs))

sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]
Expand All @@ -151,16 +175,18 @@ async def beam_search(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens[tokenized_length:],
index=i,
logprobs=beam.logprobs,
) for (i, beam) in enumerate(best_beams)
CompletionOutput(text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens[tokenized_length:],
index=i,
logprobs=beam.logprobs,
finish_reason=beam.finish_reason if
beam.finish_reason is not None else "length",
stop_reason=beam.stop_reason)
for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=tokenized_prompt,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None)

yield beam_search_output
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def to_beam_search_params(self,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
include_stop_str_in_output=self.include_stop_str_in_output)

def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
Expand Down Expand Up @@ -606,7 +606,7 @@ def to_beam_search_params(self,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
include_stop_str_in_output=self.include_stop_str_in_output)

def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
Expand Down
7 changes: 4 additions & 3 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,10 @@ async def create_chat_completion(

if isinstance(sampling_params, BeamSearchParams):
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'],
request_id,
sampling_params,
prompt=engine_inputs,
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
else:
result_generator = self.engine_client.generate(
Expand Down
10 changes: 7 additions & 3 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,13 @@ async def create_completion(

if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"],
request_id_item,
sampling_params,
prompt={
"prompt_token_ids":
prompt_inputs["prompt_token_ids"]
},
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
else:
generator = self.engine_client.generate(
Expand Down
1 change: 1 addition & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,4 @@ class BeamSearchParams(
ignore_eos: bool = False
temperature: float = 0.0
length_penalty: float = 1.0
include_stop_str_in_output: bool = False

0 comments on commit ef7865b

Please sign in to comment.