diff --git a/vllm/beam_search.py b/vllm/beam_search.py
index 04624b8b94432..1b48538734dae 100644
--- a/vllm/beam_search.py
+++ b/vllm/beam_search.py
@@ -1,5 +1,7 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
+
+from vllm.sequence import Logprob
 
 
 @dataclass
@@ -11,6 +13,7 @@ class BeamSearchSequence:
     """
     # The tokens includes the prompt.
     tokens: List[int]
+    logprobs: List[Dict[int, Logprob]]
     cum_logprob: float = 0.0
     text: Optional[str] = None
 
@@ -28,7 +31,7 @@ class BeamSearchInstance:
 
     def __init__(self, prompt_tokens: List[int]):
         self.beams: List[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens)
+            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
         ]
         self.completed: List[BeamSearchSequence] = []
 
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 16ceddf13511c..5c504e0f0217d 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -59,7 +59,7 @@ def generate(
 
     async def beam_search(
         self,
-        prompt: Union[PromptType, List[int]],
+        prompt: Union[str, List[int]],
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -71,9 +71,13 @@ async def beam_search(
         length_penalty = params.length_penalty
 
         tokenizer = await self.get_tokenizer(lora_request=None)
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
+        if isinstance(prompt, str):
+            tokenized_prompt = tokenizer.encode(prompt)
+            prompt_text = prompt
+        else:
+            tokenized_prompt = prompt
+            prompt_text = None
+        tokenized_length = len(tokenized_prompt)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
@@ -81,7 +85,11 @@ async def beam_search(
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
                                             temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        all_beams = [
+            BeamSearchSequence(tokens=tokenized_prompt,
+                               logprobs=[],
+                               cum_logprob=0)
+        ]
         completed = []
 
         for _ in range(max_tokens):
@@ -114,6 +122,7 @@ async def beam_search(
                     for token_id, logprob_obj in logprobs.items():
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
 
@@ -131,22 +140,22 @@ async def beam_search(
         best_beams = sorted_completed[:beam_width]
 
         for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+            beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
 
         beam_search_output = RequestOutput(
             request_id=request_id,
-            prompt=prompt,
+            prompt=prompt_text,
             outputs=[
                 CompletionOutput(
                     text=beam.text,
                     cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
+                    token_ids=beam.tokens[tokenized_length:],
                     index=i,
-                    logprobs=beam.cum_logprob,
+                    logprobs=beam.logprobs,
                 ) for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenizedPrompt,
+            prompt_token_ids=tokenized_prompt,
             prompt_logprobs=None)
 
         yield beam_search_output
 
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 2010381076c7d..088ec35798de8 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -433,6 +433,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                     for token_id, logprob_obj in logprobs.items():
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
 
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 15cb8d53186df..07650241cb638 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -4,7 +4,6 @@
 from typing import Sequence as GenericSequence
 from typing import Union
 
-from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -93,7 +92,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[PromptType],
+        prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
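Taken together, the diff makes each beam carry the full candidate-logprob dict from every decoding step (`logprobs`), alongside the running sum of the chosen tokens' logprobs (`cum_logprob`). Below is a minimal sketch of that invariant, using only the `BeamSearchSequence` dataclass and `Logprob` type shown above; the token ids and logprob values are made up for illustration.

```python
from vllm.beam_search import BeamSearchSequence
from vllm.sequence import Logprob

prompt_tokens = [1, 2, 3]  # hypothetical prompt token ids

# One decoding step returns a dict of candidate token -> Logprob
# (the engine requests logprobs=2 * beam_width candidates per step).
step = {42: Logprob(logprob=-0.1), 7: Logprob(logprob=-2.3)}

beam = BeamSearchSequence(
    tokens=prompt_tokens + [42],        # token 42 was chosen this step
    logprobs=[step],                    # full candidate dict kept per step
    cum_logprob=step[42].logprob,       # running sum of chosen logprobs
)

# One logprob dict per generated (non-prompt) token.
assert len(beam.logprobs) == len(beam.tokens) - len(prompt_tokens)
```

This is also why `CompletionOutput.token_ids` is now sliced to `beam.tokens[tokenized_length:]`: both it and `logprobs` cover only the generated tokens, so the two stay index-aligned.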