Guided gen support #9

Closed
35 changes: 34 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -23,7 +23,11 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@@ -1004,7 +1008,9 @@ async def generate(
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

@@ -1070,6 +1076,15 @@ async def generate(
        >>> # Process and return the final output
        >>> ...
        """
        if isinstance(sampling_params, list):
            sampling_params = [
                self._add_guided_processor(param, guided_options_request)
                if isinstance(param, SamplingParams) else param
                for param in sampling_params
            ]
        elif isinstance(sampling_params, SamplingParams):
            sampling_params = self._add_guided_processor(
                sampling_params, guided_options_request)

        async for output in await self.add_request(
                request_id,
                inputs,
@@ -1079,6 +1094,24 @@
                prompt_adapter_request=prompt_adapter_request,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)

    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        # Attach a guided-decoding logits processor to the sampling params
        # when guided options are provided.
        if guided_options:
            if guided_options.guided_decoding_backend is None:
                decoding_config = self.engine.get_decoding_config()
                guided_options.guided_decoding_backend = (
                    decoding_config.guided_decoding_backend)
            guided_logits_processor = get_local_guided_decoding_logits_processor(  # noqa
                guided_options.guided_decoding_backend, guided_options,
                self.engine.get_tokenizer_group(TokenizerGroup).tokenizer)
            if guided_logits_processor:
                if params.logits_processors is None:
                    params.logits_processors = []
                params.logits_processors.append(guided_logits_processor)
        return params

    async def encode(
        self,
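For context, a minimal usage sketch (not part of the diff) of the new guided_options_request argument on AsyncLLMEngine.generate. It assumes GuidedDecodingRequest exposes a guided_choice field as in vLLM's guided_fields module; the model name, prompt, and request_id are placeholders.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest)


async def main():
    # Placeholder model; any model supported by vLLM should work here.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # Constrain generation to one of two literal choices via guided decoding
    # (assumes guided_choice is a valid GuidedDecodingRequest field).
    guided = GuidedDecodingRequest(guided_choice=["yes", "no"])
    params = SamplingParams(temperature=0.0, max_tokens=5)
    final_output = None
    async for output in engine.generate(
            "Is the sky blue? Answer yes or no:",
            params,
            request_id="guided-demo",
            guided_options_request=guided):
        final_output = output
    print(final_output.outputs[0].text)


asyncio.run(main())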