From b3c86471c91c1411112088e91996e638771f20c9 Mon Sep 17 00:00:00 2001
From: DhruvaBansal00
Date: Fri, 6 Sep 2024 15:41:56 -0700
Subject: [PATCH 1/3] Support guided generation with AsyncLLMEngine

---
 vllm/engine/async_llm_engine.py | 35 ++++++++++++++++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 7fe8053fffb7b..037891a9ed2cf 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -23,7 +23,11 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.guided_decoding import (
+    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -1004,7 +1008,9 @@ async def generate(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                               GuidedDecodingRequest]] = None
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.

@@ -1070,6 +1076,15 @@
         >>> # Process and return the final output
         >>> ...
         """
+        if isinstance(sampling_params, list):
+            sampling_params = [
+                self._add_guided_processor(param, guided_options)
+                if isinstance(param, SamplingParams) else param
+                for param in sampling_params
+            ]
+        elif isinstance(sampling_params, SamplingParams):
+            sampling_params = self._add_guided_processor(sampling_params, guided_options)
+
         async for output in await self.add_request(
             request_id,
             inputs,
@@ -1079,6 +1094,24 @@
             prompt_adapter_request=prompt_adapter_request,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
+
+    def _add_guided_processor(
+            self,
+            params: SamplingParams,
+            guided_options: Optional[GuidedDecodingRequest] = None):
+        if guided_options:
+            if guided_options.guided_decoding_backend is None:
+                decoding_config = self.get_decoding_config()
+                guided_options.guided_decoding_backend = (
+                    decoding_config.guided_decoding_backend)
+            guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
+                guided_options.guided_decoding_backend, guided_options,
+                self.get_tokenizer_group(TokenizerGroup).tokenizer)
+            if guided_logits_processor:
+                if params.logits_processors is None:
+                    params.logits_processors = []
+                params.logits_processors.append(guided_logits_processor)
+        return params

     async def encode(
         self,
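For context between the patches, here is a rough sketch of how a caller might exercise the new guided_options_request argument once this series applies. It is not part of the patch: the model name, prompt, and the guided_choice field used to build GuidedDecodingRequest are illustrative assumptions rather than code taken from this PR.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.model_executor.guided_decoding import GuidedDecodingRequest


async def main():
    # Placeholder model; any causal LM served by vLLM would be used the same way.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # Constrain the completion to one of two strings via guided decoding
    # (guided_choice is assumed to be a field on GuidedDecodingRequest).
    guided = GuidedDecodingRequest(guided_choice=["positive", "negative"])
    final = None
    async for output in engine.generate(
            "Review: great movie. Sentiment:",
            SamplingParams(max_tokens=5, temperature=0.0),
            request_id="guided-req-0",
            guided_options_request=guided):
        final = output  # the last yielded RequestOutput holds the finished text
    print(final.outputs[0].text)


asyncio.run(main())
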
From 8b82f03e9ba6a8c4746d77b21b2c13a56be6736c Mon Sep 17 00:00:00 2001
From: Dhruva Bansal
Date: Tue, 10 Sep 2024 01:09:04 +0000
Subject: [PATCH 2/3] fix

---
 vllm/engine/async_llm_engine.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 037891a9ed2cf..23bae556d8fef 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1078,12 +1078,12 @@ async def generate(
         """
         if isinstance(sampling_params, list):
             sampling_params = [
-                self._add_guided_processor(param, guided_options)
+                self._add_guided_processor(param, guided_options_request)
                 if isinstance(param, SamplingParams) else param
                 for param in sampling_params
             ]
         elif isinstance(sampling_params, SamplingParams):
-            sampling_params = self._add_guided_processor(sampling_params, guided_options)
+            sampling_params = self._add_guided_processor(sampling_params, guided_options_request)

         async for output in await self.add_request(
             request_id,

From 55710cd14bbc02d88f010531bc7fb572f842a985 Mon Sep 17 00:00:00 2001
From: DhruvaBansal00
Date: Mon, 9 Sep 2024 18:15:30 -0700
Subject: [PATCH 3/3] Use self.engine for accessing tokenizer group

---
 vllm/engine/async_llm_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 23bae556d8fef..7041fba870f4c 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1106,7 +1106,7 @@ def _add_guided_processor(
                     decoding_config.guided_decoding_backend)
             guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
                 guided_options.guided_decoding_backend, guided_options,
-                self.get_tokenizer_group(TokenizerGroup).tokenizer)
+                self.engine.get_tokenizer_group(TokenizerGroup).tokenizer)
             if guided_logits_processor:
                 if params.logits_processors is None:
                     params.logits_processors = []
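As background for why _add_guided_processor only needs to append a callable to params.logits_processors: guided decoding here is applied entirely at sampling time by that processor, which rewrites each step's logits before a token is chosen. Below is a hand-rolled sketch of the idea, assuming vLLM's (generated_token_ids, logits) logits-processor convention; it is a toy stand-in, not the FSM/grammar-backed processors the real backends return.

from typing import List

import torch


def make_allowlist_processor(allowed_token_ids: List[int]):
    """Toy guided-decoding stand-in: only the allowed token ids stay sampleable."""
    allowed = torch.tensor(allowed_token_ids)

    def processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        mask = torch.full_like(logits, float("-inf"))
        mask[allowed] = 0.0  # keep allowed ids, push everything else to -inf
        return logits + mask

    return processor


# Attached the same way _add_guided_processor attaches the real processor:
#     if params.logits_processors is None:
#         params.logits_processors = []
#     params.logits_processors.append(make_allowlist_processor([42, 43]))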