diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 3361fdefc960..7778732dd8be 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -420,6 +420,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
@@ -433,6 +434,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
@@ -449,6 +451,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -460,6 +463,9 @@
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
+        if priority != 0 and not self.scheduler_config.policy == "priority":
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
         if arrival_time is None:
             arrival_time = time.time()
@@ -479,6 +485,7 @@ async def add_request_async(
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
+            priority=priority,
         )

     async def check_health_async(self) -> None:
@@ -829,6 +836,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -843,6 +851,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -860,6 +869,7 @@ async def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
@@ -877,6 +887,11 @@
                 "error that caused the background loop to stop "
                 "(AsyncEngineDeadError).")

+        if (priority != 0
+                and not self.engine.scheduler_config.policy == "priority"):
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
+
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
@@ -885,7 +900,9 @@
             arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
+        )

         return stream.generator()
@@ -896,7 +913,8 @@ async def generate(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
@@ -913,6 +931,8 @@
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.
+            priority: The priority of the request.
+                Only applicable with priority scheduling.

         Yields:
             The output `RequestOutput` objects from the LLMEngine
@@ -968,6 +988,7 @@ async def generate(
             lora_request=lora_request,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 19f88ac3e7c5..e3cd822f648f 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -796,7 +796,7 @@
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
-        if priority > 0 and not self.scheduler_config.policy == "priority":
+        if priority != 0 and not self.scheduler_config.policy == "priority":
             raise ValueError(f"Got priority {priority} but "
                              "Priority scheduling is not enabled.")
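For context, a minimal usage sketch of the new parameter. Only the priority keyword argument and the scheduler_config.policy == "priority" check come from this diff; the scheduling_policy engine argument, the model name, and the request identifiers are illustrative assumptions, not part of the patch.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    # The check added in this diff requires scheduler_config.policy ==
    # "priority" for any non-zero priority; otherwise add_request() raises
    # ValueError. The "scheduling_policy" engine arg used here is assumed.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        scheduling_policy="priority"))

    final = None
    # priority=0 stays the default; non-zero values are only accepted when
    # priority scheduling is enabled, per the new validation above.
    async for output in engine.generate("Hello, my name is",
                                        SamplingParams(max_tokens=16),
                                        request_id="req-0",
                                        priority=1):
        final = output
    if final is not None:
        print(final.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())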