diff --git a/tests/conftest.py b/tests/conftest.py
index 5f82a5d6aa36..551b0e674805 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -741,13 +741,11 @@ def generate_greedy(
         prompts: List[str],
         max_tokens: int,
         best_of: Optional[int] = None,
-        use_beam_search: bool = False,
         images: Optional[PromptImageInput] = None,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0,
                                        max_tokens=max_tokens,
-                                       best_of=best_of,
-                                       use_beam_search=use_beam_search)
+                                       best_of=best_of)
         outputs = self.generate(prompts, greedy_params, images=images)
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]
@@ -759,7 +757,6 @@ def generate_greedy_logprobs(
         num_logprobs: int,
         num_prompt_logprobs: Optional[int] = None,
         best_of: Optional[int] = None,
-        use_beam_search: bool = False,
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
@@ -772,8 +769,7 @@ def generate_greedy_logprobs(
             logprobs=num_logprobs,
             prompt_logprobs=num_prompt_logprobs,
             stop_token_ids=stop_token_ids,
-            best_of=best_of,
-            use_beam_search=use_beam_search)
+            best_of=best_of)
 
         return self.generate_w_logprobs(prompts,
                                         greedy_logprobs_params,
diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py
index 5b08741e67e1..cbb7ada2fa54 100644
--- a/tests/multi_step/test_correctness_llm.py
+++ b/tests/multi_step/test_correctness_llm.py
@@ -210,39 +210,35 @@ def test_multi_step_llm_w_prompt_logprobs(
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs", [None, 5])
 @pytest.mark.parametrize("max_output_len", [7])
-@pytest.mark.parametrize("n_best_of_use_beam_search", [
-    (1, 2, False),
-    (2, 2, False),
-    (1, 3, False),
-    (2, 3, False),
-    (3, 3, False),
-    (1, 1, True),
-    (1, 2, True),
-    (2, 2, True),
+@pytest.mark.parametrize("n_best_of", [
+    (1, 2),
+    (2, 2),
+    (1, 3),
+    (2, 3),
+    (3, 3),
+    (1, 1),
+    (1, 2),
+    (2, 2),
 ])
-def test_multi_step_llm_best_of_beam_search_fallback(
-    hf_runner,
+def test_multi_step_llm_best_of_fallback(
     vllm_runner,
     example_prompts,
     model: str,
     dtype: str,
     tp_size: int,
-    max_tokens: int,
     enforce_eager: int,
     num_scheduler_steps: int,
     num_prompts: int,
-    num_logprobs: Optional[int],
     max_output_len: int,
-    n_best_of_use_beam_search: Tuple[int, int, bool],
+    n_best_of: Tuple[int, int],
 ) -> None:
-    """Test vLLM engine with multi-step & beam search enabled or best_of > 1
+    """Test vLLM engine with multi-step & best_of > 1
 
     Currently multi-step scheduling does not support best_of > 1 or beam
     search, however the default behavior is for the engine to fall back
     on single-step scheduling rather than failing.
 
     Args:
-      hf_runner: HF transformers model runner fixture
       vllm_runner: vLLM model runner fixture
       example_prompts: test fixture providing example prompts
       model: model under test (same for single- and multi-step engines)
@@ -265,16 +261,14 @@ def test_multi_step_llm_best_of_beam_search_fallback(
     prompts = prompts[:num_prompts]
     assert len(prompts) == num_prompts
 
-    n = n_best_of_use_beam_search[0]
-    best_of = n_best_of_use_beam_search[1]
-    use_beam_search = n_best_of_use_beam_search[2]
+    n = n_best_of[0]
+    best_of = n_best_of[1]
     sampling_params = SamplingParams(
         max_tokens=max_output_len,
         ignore_eos=True,
-        temperature=0.0 if use_beam_search else 1.0,
+        temperature=1.0,
         n=n,
         best_of=best_of,
-        use_beam_search=use_beam_search,
         seed=42,
     )
 
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 439f3769f9fb..2a455df048fe 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -401,6 +401,10 @@ def beam_search(
             penalty, and stopping criteria, etc.?
         """
 
+        assert not self.llm_engine.scheduler_config.is_multi_step, (
+            "Currently beam search is not supported in combination with "
+            "multi-step scheduling.")
+
         beam_width = params.beam_width
         max_tokens = params.max_tokens
         temperature = params.temperature
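
For reference, the fallback path exercised by test_multi_step_llm_best_of_fallback looks roughly like the sketch below when driven through the public API. The model name, prompt, and num_scheduler_steps value are illustrative choices, not part of this change.

# Rough sketch of the best_of fallback (model, prompt, and step count are
# illustrative, not taken from this diff).
from vllm import LLM, SamplingParams

# num_scheduler_steps > 1 enables multi-step scheduling.
llm = LLM(model="JackFram/llama-160m", num_scheduler_steps=8)

# Multi-step scheduling does not support best_of > 1; rather than failing,
# the engine falls back to single-step scheduling for such requests.
params = SamplingParams(max_tokens=7,
                        ignore_eos=True,
                        temperature=1.0,
                        n=2,
                        best_of=2,
                        seed=42)
outputs = llm.generate(["The future of AI is"], params)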
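
With use_beam_search gone from SamplingParams, beam search is reached only through the dedicated LLM.beam_search entrypoint that the new assertion guards. A minimal sketch of that call path follows, assuming BeamSearchParams is importable from vllm.sampling_params and that beam_search accepts plain string prompts; the accepted prompt types may vary by version.

# Minimal sketch of the guarded beam-search entrypoint; assumes
# BeamSearchParams comes from vllm.sampling_params and that beam_search
# accepts plain string prompts (version-dependent).
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="JackFram/llama-160m")  # illustrative model

# Beam search is configured via BeamSearchParams instead of the removed
# SamplingParams(use_beam_search=True, ...).
params = BeamSearchParams(beam_width=4, max_tokens=32)

# Fails fast via the assertion added in vllm/entrypoints/llm.py if the
# engine was built with multi-step scheduling.
outputs = llm.beam_search(["The future of AI is"], params)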