added assert which fails when beam search is performed with multi-step; removed beam search + multi-step fallback tests
abf149 committed Oct 7, 2024
1 parent 3ed32e2 commit fe12a95
Showing 3 changed files with 21 additions and 27 deletions.
tests/conftest.py: 8 changes (2 additions, 6 deletions)
@@ -741,13 +741,11 @@ def generate_greedy(
         prompts: List[str],
         max_tokens: int,
         best_of: Optional[int] = None,
-        use_beam_search: bool = False,
         images: Optional[PromptImageInput] = None,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0,
                                        max_tokens=max_tokens,
-                                       best_of=best_of,
-                                       use_beam_search=use_beam_search)
+                                       best_of=best_of)
         outputs = self.generate(prompts, greedy_params, images=images)
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]
@@ -759,7 +757,6 @@ def generate_greedy_logprobs(
         num_logprobs: int,
         num_prompt_logprobs: Optional[int] = None,
         best_of: Optional[int] = None,
-        use_beam_search: bool = False,
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
@@ -772,8 +769,7 @@ def generate_greedy_logprobs(
             logprobs=num_logprobs,
             prompt_logprobs=num_prompt_logprobs,
             stop_token_ids=stop_token_ids,
-            best_of=best_of,
-            use_beam_search=use_beam_search)
+            best_of=best_of)

         return self.generate_w_logprobs(prompts,
                                         greedy_logprobs_params,
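For context, a hedged sketch of a call site for the slimmed-down helper (the runner fixture, prompt, and token budget are illustrative, not taken from this diff): callers no longer pass use_beam_search, and greedy decoding comes entirely from the temperature=0.0 the helper sets internally.

    # Hypothetical call site for generate_greedy after this change; the
    # use_beam_search argument is gone, so greedy decoding is configured
    # solely by the temperature=0.0 baked into the helper.
    outputs = vllm_model.generate_greedy(
        prompts=["Hello, my name is"],
        max_tokens=32,
    )
    for token_ids, text in outputs:
        print(text)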
tests/multi_step/test_correctness_llm.py: 36 changes (15 additions, 21 deletions)
@@ -210,39 +210,35 @@ def test_multi_step_llm_w_prompt_logprobs(
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs", [None, 5])
 @pytest.mark.parametrize("max_output_len", [7])
-@pytest.mark.parametrize("n_best_of_use_beam_search", [
-    (1, 2, False),
-    (2, 2, False),
-    (1, 3, False),
-    (2, 3, False),
-    (3, 3, False),
-    (1, 1, True),
-    (1, 2, True),
-    (2, 2, True),
+@pytest.mark.parametrize("n_best_of", [
+    (1, 2),
+    (2, 2),
+    (1, 3),
+    (2, 3),
+    (3, 3),
+    (1, 1),
+    (1, 2),
+    (2, 2),
 ])
-def test_multi_step_llm_best_of_beam_search_fallback(
-    hf_runner,
+def test_multi_step_llm_best_of_fallback(
     vllm_runner,
     example_prompts,
     model: str,
     dtype: str,
     tp_size: int,
     max_tokens: int,
     enforce_eager: int,
     num_scheduler_steps: int,
     num_prompts: int,
     num_logprobs: Optional[int],
     max_output_len: int,
-    n_best_of_use_beam_search: Tuple[int, int, bool],
+    n_best_of: Tuple[int, int, bool],
 ) -> None:
-    """Test vLLM engine with multi-step & beam search enabled or best_of > 1
+    """Test vLLM engine with multi-step & best_of > 1

     Currently multi-step scheduling does not support best_of > 1 or beam search,
     however the default behavior is for the engine to fall back on single-step
     scheduling rather than failing.

     Args:
-      hf_runner: HF transformers model runner fixture
       vllm_runner: vLLM model runner fixture
       example_prompts: test fixture providing example prompts
       model: model under test (same for single- and multi-step engines)
@@ -265,16 +261,14 @@ def test_multi_step_llm_best_of_beam_search_fallback(
     prompts = prompts[:num_prompts]
     assert len(prompts) == num_prompts

-    n = n_best_of_use_beam_search[0]
-    best_of = n_best_of_use_beam_search[1]
-    use_beam_search = n_best_of_use_beam_search[2]
+    n = n_best_of[0]
+    best_of = n_best_of[1]
     sampling_params = SamplingParams(
         max_tokens=max_output_len,
         ignore_eos=True,
-        temperature=0.0 if use_beam_search else 1.0,
+        temperature=1.0,
         n=n,
         best_of=best_of,
-        use_beam_search=use_beam_search,
         seed=42,
     )
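The docstring above describes the behavior this test pins down: best_of > 1 makes the engine fall back from multi-step to single-step scheduling instead of failing. A minimal sketch of that path outside the test harness, assuming a vLLM build from around this commit (the model name and num_scheduler_steps value are illustrative):

    # best_of > 1 is unsupported under multi-step scheduling; the engine
    # is expected to fall back to single-step scheduling rather than raise.
    from vllm import LLM, SamplingParams

    llm = LLM(model="JackFram/llama-160m", num_scheduler_steps=8)
    params = SamplingParams(max_tokens=7,
                            ignore_eos=True,
                            temperature=1.0,
                            n=2,
                            best_of=2,
                            seed=42)
    outputs = llm.generate(["The president of the United States is"], params)
    for completion in outputs[0].outputs:  # n=2 completions per prompt
        print(completion.text)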
vllm/entrypoints/llm.py: 4 changes (4 additions, 0 deletions)
@@ -401,6 +401,10 @@ def beam_search(
         penalty, and stopping criteria, etc.?
         """

+        assert not self.llm_engine.scheduler_config.is_multi_step, (
+            "Currently beam search is not supported in combination with "
+            "multi-step scheduling.")
+
         beam_width = params.beam_width
         max_tokens = params.max_tokens
         temperature = params.temperature
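With this assert, the unsupported combination fails fast instead of silently falling back to single-step scheduling (the behavior the removed tests used to cover). A hedged sketch of what a caller now sees, assuming BeamSearchParams is the params type beam_search reads beam_width, max_tokens, and temperature from, per the lines above (model name and engine settings are illustrative):

    # Calling beam_search on a multi-step engine now trips the new assert.
    from vllm import LLM
    from vllm.sampling_params import BeamSearchParams

    llm = LLM(model="JackFram/llama-160m", num_scheduler_steps=8)
    try:
        llm.beam_search(["The capital of France is"],
                        BeamSearchParams(beam_width=2, max_tokens=7))
    except AssertionError as err:
        print(err)  # "Currently beam search is not supported in
                    #  combination with multi-step scheduling."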
