Skip to content

Commit

Permalink
[Bugfix][Frontend] Reject guided decoding in multistep mode (vllm-project#9892)
Browse files Browse the repository at this point in the history

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
  • Loading branch information
joerunde authored Nov 1, 2024
1 parent b63c64d commit 031a799
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/source/serving/compatibility_matrix.rst
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ Feature x Feature
- ✅
- ✅
- ✅
- `<https://github.com/vllm-project/vllm/issues/8985>`__
- `<https://github.com/vllm-project/vllm/issues/9893>`__
- ?
- ✅
- ✅
Expand Down
20 changes: 20 additions & 0 deletions tests/entrypoints/openai/test_prompt_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,23 @@ async def test_out_of_vocab_token_ids():
prompt=[999999],
max_tokens=5,
temperature=0.0)


@pytest.mark.asyncio
async def test_reject_multistep_with_guided_decoding():
    """Verify that a guided-decoding request fails on a multi-step server.

    Starts a server with ``--num-scheduler-steps 8`` and submits a completion
    request asking for a ``json_object`` response format (the guided-decoding
    case this test targets). The call must raise ``openai.BadRequestError``
    whose message matches the guided-decoding / multi-step pattern.
    """
    model_name = "gpt2"
    # A scheduler step count above one is what puts the server in
    # multi-step mode.
    server_args = ["--enforce-eager", "--num-scheduler-steps", "8"]

    # Error text the server is expected to return for the rejected request.
    expected_error = re.compile('.*Guided decoding .* multi-step decoding.*')
    # Extra request fields forwarded verbatim in the request body.
    guided_request_body = {"response_format": {"type": "json_object"}}

    with RemoteOpenAIServer(model_name, server_args) as server:
        async_client = server.get_async_client()

        with pytest.raises(openai.BadRequestError, match=expected_error):
            await async_client.completions.create(
                model=model_name,
                prompt="Hello",
                max_tokens=5,
                temperature=0.0,
                extra_body=guided_request_body,
            )
7 changes: 7 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,13 @@ def add_request(
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")

if isinstance(params, SamplingParams) \
and (params.guided_decoding or params.logits_processors) \
and self.scheduler_config.num_scheduler_steps > 1:
raise ValueError(
"Guided decoding and logits processors are not supported "
"in multi-step decoding")

if arrival_time is None:
arrival_time = time.time()

Expand Down
4 changes: 2 additions & 2 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,8 @@ def __repr__(self) -> str:
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
f"guided_decoding={self.guided_decoding}")
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
f"guided_decoding={self.guided_decoding})")


class BeamSearchParams(
Expand Down

0 comments on commit 031a799

Please sign in to comment.