
Commit c47ba4a

[Bugfix] Add validation for seed (vllm-project#4529)

Author: sasha0552
Committed: May 1, 2024
Parent: 24bb4fe

Showing 2 changed files with 26 additions and 2 deletions.
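
For context on the fix below: torch.long is a 64-bit signed integer, so any seed outside [-2**63, 2**63 - 1] cannot be represented once it reaches the engine. A minimal sketch of the bounds the new validation enforces:

    import torch

    # torch.long is int64, so these are the exact bounds the new
    # Field(ge=..., le=...) constraints place on incoming seeds.
    print(torch.iinfo(torch.long).min)  # -9223372036854775808 == -2**63
    print(torch.iinfo(torch.long).max)  #  9223372036854775807 == 2**63 - 1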
20 changes: 20 additions & 0 deletions tests/entrypoints/test_openai_server.py
@@ -13,6 +13,7 @@
 # and debugging.
 import ray
 import requests
+import torch
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
@@ -870,5 +871,24 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
     assert len(logprobs.tokens) > 5
 
 
+async def test_long_seed(server, client: openai.AsyncOpenAI):
+    for seed in [
+            torch.iinfo(torch.long).min - 1,
+            torch.iinfo(torch.long).max + 1
+    ]:
+        with pytest.raises(BadRequestError) as exc_info:
+            await client.chat.completions.create(
+                model=MODEL_NAME,
+                messages=[{
+                    "role": "system",
+                    "content": "You are a helpful assistant.",
+                }],
+                temperature=0,
+                seed=seed)
+
+        assert ("greater_than_equal" in exc_info.value.message
+                or "less_than_equal" in exc_info.value.message)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
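
A note on why min - 1 and max + 1 are expressible at all: Python integers are arbitrary precision and JSON puts no bound on integer size, so a client can send a seed that no 64-bit torch.long can hold. A short sketch of that mismatch (the final comment is an assumption about the pre-fix behavior):

    import json
    import torch

    # This value exists fine client-side even though it overflows int64.
    oversized = torch.iinfo(torch.long).max + 1
    print(json.dumps({"seed": oversized}))  # {"seed": 9223372036854775808}
    # Assumption: without the ge/le bounds added below, such a request would
    # pass request parsing and only fail later inside the engine, rather than
    # returning the clean 400 BadRequestError the test expects.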
8 changes: 6 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -79,7 +79,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
-    seed: Optional[int] = None
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     temperature: Optional[float] = 0.7
@@ -228,7 +230,9 @@ class CompletionRequest(OpenAIBaseModel):
     max_tokens: Optional[int] = 16
     n: int = 1
     presence_penalty: Optional[float] = 0.0
-    seed: Optional[int] = None
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     suffix: Optional[str] = None
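
To see the constraint in isolation: a minimal sketch assuming pydantic v2, where SeedModel is a hypothetical stand-in for the request classes above. The error-type strings greater_than_equal and less_than_equal that the test matches are pydantic's names for violated ge/le bounds:

    from typing import Optional

    import torch
    from pydantic import BaseModel, Field, ValidationError

    # Hypothetical stand-in for ChatCompletionRequest / CompletionRequest.
    class SeedModel(BaseModel):
        seed: Optional[int] = Field(None,
                                    ge=torch.iinfo(torch.long).min,
                                    le=torch.iinfo(torch.long).max)

    try:
        SeedModel(seed=torch.iinfo(torch.long).max + 1)
    except ValidationError as e:
        print(e.errors()[0]["type"])  # "less_than_equal"

On the server these validation failures surface to the client as the BadRequestError asserted in the test above.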
