diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index e669ce4db299d..77c4f6aa927e4 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -13,6 +13,7 @@
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.inputs import PromptType
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser
 
 
@@ -40,6 +41,21 @@ def main(args: argparse.Namespace):
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]
 
+    def llm_generate():
+        """Run one pass over the dummy prompts (sampling or beam search)."""
+        if not args.use_beam_search:
+            llm.generate(dummy_prompts,
+                         sampling_params=sampling_params,
+                         use_tqdm=False)
+        else:
+            llm.beam_search(
+                dummy_prompts,
+                BeamSearchParams(
+                    beam_width=args.n,
+                    max_tokens=args.output_len,
+                    ignore_eos=True,
+                ))
+
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
             with torch.profiler.profile(
@@ -49,15 +65,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
                     ],
                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
                         str(profile_dir))) as p:
-                llm.generate(dummy_prompts,
-                             sampling_params=sampling_params,
-                             use_tqdm=False)
+                llm_generate()
             print(p.key_averages().table(sort_by="self_cuda_time_total"))
         else:
             start_time = time.perf_counter()
-            llm.generate(dummy_prompts,
-                         sampling_params=sampling_params,
-                         use_tqdm=False)
+            llm_generate()
             end_time = time.perf_counter()
             latency = end_time - start_time
             return latency
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index e48fd1a4fa5e9..acb4db85632a8 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -21,7 +21,7 @@
                                          parse_chat_messages,
                                          resolve_chat_template_content_format)
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
@@ -457,7 +457,7 @@ def generate(
 
     def beam_search(
         self,
-        prompts: List[Union[str, List[int]]],
+        prompts: List[Union[TokensPrompt, TextPrompt, str, List[int]]],
         params: BeamSearchParams,
     ) -> List[BeamSearchOutput]:
         """
@@ -493,8 +493,16 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
         instances: List[BeamSearchInstance] = []
 
         for prompt in prompts:
-            prompt_tokens = prompt if isinstance(
-                prompt, list) else tokenizer.encode(prompt)
+            # Raw strings and token-id lists were accepted before prompt
+            # dicts were supported; keep them working for existing callers.
+            if isinstance(prompt, str):
+                prompt_tokens = tokenizer.encode(prompt)
+            elif isinstance(prompt, list):
+                prompt_tokens = prompt
+            elif is_token_prompt(prompt):
+                prompt_tokens = prompt["prompt_token_ids"]
+            else:
+                prompt_tokens = tokenizer.encode(prompt["prompt"])
             instances.append(BeamSearchInstance(prompt_tokens))
 
         for _ in range(max_tokens):