From 197b4484a3fba4a98921f903d6242677f97c63db Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 27 Nov 2024 21:02:27 +0200 Subject: [PATCH] [Bugfix][Mamba] Fix Multistep on Mamba-like models (#10705) Signed-off-by: mzusman --- .../decoder_only/language/test_jamba.py | 38 +++++++++++++++++++ .../decoder_only/language/test_mamba.py | 36 ++++++++++++++++++ vllm/engine/async_llm_engine.py | 7 +++- vllm/engine/llm_engine.py | 7 +++- 4 files changed, 84 insertions(+), 4 deletions(-) diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 6542689c3f277..87a05b3011393 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -275,6 +275,44 @@ def test_state_cleanup( "could be related to finished_requests_ids") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_multistep( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is verifying that multistep works correctly + #on mamba-like models + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 10, 1) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_multistep_correctness(vllm_runner, model: str, dtype: str, + max_tokens: int, example_prompts) -> None: + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_outputs_multistep = vllm_model.generate_greedy( + example_prompts, max_tokens) + + with vllm_runner(model, num_scheduler_steps=1, + max_num_seqs=2) as vllm_model: + vllm_outputs_single_step = vllm_model.generate_greedy( + example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_multistep, + outputs_1_lst=vllm_outputs_single_step, + name_0="vllm_outputs_multistep", + name_1="vllm_outputs_single_step", + ) + + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 78eab8d5354fd..01e208347bff4 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -283,3 +283,39 @@ def test_state_cleanup( except ValueError: pytest.fail("Mamba inner state wasn't cleaned up between states, " "could be related to finished_requests_ids") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_multistep( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 10, 1) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_multistep_correctness(vllm_runner, model: str, dtype: str, + max_tokens: int, example_prompts) -> None: + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_outputs_multistep = vllm_model.generate_greedy( + example_prompts, max_tokens) + + with vllm_runner(model, num_scheduler_steps=1, + max_num_seqs=2) as vllm_model: + vllm_outputs_single_step = vllm_model.generate_greedy( + example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_multistep, + outputs_1_lst=vllm_outputs_single_step, + name_0="vllm_outputs_multistep", + name_1="vllm_outputs_single_step", + ) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3224577c567f8..31a15b04314d5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -300,6 +300,9 @@ async def step_async( ctx.seq_group_metadata_list = seq_group_metadata_list ctx.scheduler_outputs = scheduler_outputs + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) @@ -311,13 +314,13 @@ async def step_async( self._cache_scheduler_outputs_for_multi_step( virtual_engine, seq_group_metadata_list, scheduler_outputs, allow_async_output_proc) + else: + finished_requests_ids = list() assert seq_group_metadata_list is not None assert scheduler_outputs is not None if not scheduler_outputs.is_empty(): - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() # Check if we have a cached last_output from the previous iteration. # For supporting PP this is probably the best way to pass the diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a4975cece9a81..ecc222f692c41 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1398,6 +1398,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: ctx.seq_group_metadata_list = seq_group_metadata_list ctx.scheduler_outputs = scheduler_outputs + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) @@ -1409,13 +1412,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self._cache_scheduler_outputs_for_multi_step( virtual_engine, seq_group_metadata_list, scheduler_outputs, allow_async_output_proc) + else: + finished_requests_ids = list() assert seq_group_metadata_list is not None assert scheduler_outputs is not None if not scheduler_outputs.is_empty(): - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() # Check if we have a cached last_output from the previous iteration. # For supporting PP this is probably the best way to pass the