diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 7fe8053fffb7b..7041fba870f4c 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -23,7 +23,11 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.guided_decoding import (
+    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -1004,7 +1008,9 @@ async def generate(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                               GuidedDecodingRequest]] = None
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
 
@@ -1070,6 +1076,15 @@
         >>> # Process and return the final output
         >>> ...
         """
+        if isinstance(sampling_params, list):
+            sampling_params = [
+                self._add_guided_processor(param, guided_options_request)
+                if isinstance(param, SamplingParams) else param
+                for param in sampling_params
+            ]
+        elif isinstance(sampling_params, SamplingParams):
+            sampling_params = self._add_guided_processor(sampling_params, guided_options_request)
+
         async for output in await self.add_request(
             request_id,
             inputs,
@@ -1079,6 +1094,24 @@
             prompt_adapter_request=prompt_adapter_request,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
+
+    def _add_guided_processor(
+            self,
+            params: SamplingParams,
+            guided_options: Optional[GuidedDecodingRequest] = None):
+        if guided_options:
+            if guided_options.guided_decoding_backend is None:
+                decoding_config = self.engine.get_decoding_config()
+                guided_options.guided_decoding_backend = (
+                    decoding_config.guided_decoding_backend)
+            guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
+                guided_options.guided_decoding_backend, guided_options,
+                self.engine.get_tokenizer_group(TokenizerGroup).tokenizer)
+            if guided_logits_processor:
+                if params.logits_processors is None:
+                    params.logits_processors = []
+                params.logits_processors.append(guided_logits_processor)
+        return params
 
     async def encode(
         self,
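
Usage note (not part of the patch): a minimal sketch of how the new guided_options_request parameter could be exercised against the patched generate() signature above. The model name, prompt, regex, and request id are illustrative placeholders; GuidedDecodingRequest is the dataclass this change imports from vllm.model_executor.guided_decoding.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.model_executor.guided_decoding import GuidedDecodingRequest


async def main():
    # Placeholder engine args; any model supported by vLLM works here.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # Constrain decoding to an ISO-style date. With this patch, generate()
    # attaches the guided logits processor to the SamplingParams via
    # _add_guided_processor before the request is queued.
    guided = GuidedDecodingRequest(guided_regex=r"\d{4}-\d{2}-\d{2}")

    final_output = None
    async for output in engine.generate(
            "Today's date is ",
            SamplingParams(max_tokens=16),
            request_id="guided-req-0",
            guided_options_request=guided):
        final_output = output  # the last yielded RequestOutput is complete

    print(final_output.outputs[0].text)


asyncio.run(main())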