Skip to content

Commit

Permalink
prefill only inputs (data & preprocessor)
Browse files Browse the repository at this point in the history
  • Loading branch information
noooop committed Oct 7, 2024
1 parent cb4946d commit e80d5f2
Show file tree
Hide file tree
Showing 9 changed files with 347 additions and 0 deletions.
Empty file added tests/inputs/__init__.py
Empty file.
Empty file.
81 changes: 81 additions & 0 deletions tests/inputs/prefill_only/test_input_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# mypy: ignore-errors
import pytest

from vllm.inputs.prefill_only.data import (TextOnlyInputs, TextPrompt,
TokensPrompt, ValidationError)
from vllm.inputs.prefill_only.preprocessor import TextInputProcessor

# Module-level singleton: TextInputProcessor keeps no per-request state,
# so one shared instance can safely serve every test below.
input_processor = TextInputProcessor()


@pytest.fixture(scope="session")
def request_id():
    """Constant request id shared by every test in the session."""
    return "0"


def test_input_processor_1(request_id):
    """A bare string prompt is wrapped into a {"prompt": ...} dict."""
    text = "test"
    assert input_processor(request_id, text).inputs == {"prompt": text}


def test_input_processor_2(request_id):
    """A TextPrompt is unwrapped into a {"prompt": ...} dict."""
    text = "test"
    request = input_processor(request_id, TextPrompt(prompt=text))
    assert request.inputs == {"prompt": text}


def test_input_processor_3(request_id):
    """A TokensPrompt is unwrapped into a {"prompt_token_ids": ...} dict."""
    token_ids = [0]
    request = input_processor(request_id,
                              TokensPrompt(prompt_token_ids=token_ids))
    assert request.inputs == {"prompt_token_ids": token_ids}


def test_input_processor_4(request_id):
    """TextOnlyInputs keeps token ids and, only when set, the prompt text."""
    text = "test"
    token_ids = [0]

    # Without a prompt, only the token ids survive normalization.
    request = input_processor(request_id,
                              TextOnlyInputs(prompt_token_ids=token_ids))
    assert request.inputs == {"prompt_token_ids": token_ids}

    # With a prompt, both fields are carried through.
    request = input_processor(
        request_id, TextOnlyInputs(prompt_token_ids=token_ids, prompt=text))
    assert request.inputs == {
        "prompt_token_ids": token_ids,
        "prompt": text
    }


def test_input_processor_5(request_id):
    """A dict that already uses the recognized keys passes through as-is."""
    raw = {"prompt_token_ids": [0], "prompt": "test"}
    assert input_processor(request_id, raw).inputs == raw


def test_validation_error(request_id):
    """Unusable inputs must raise ValidationError.

    Covers an empty dict, a dict with only unrecognized keys, and
    unsupported scalar types (int, float).
    """
    for bad_inputs in ({}, {"foo": "bar"}, 0, 0.0):
        with pytest.raises(ValidationError):
            input_processor(request_id, bad_inputs)
41 changes: 41 additions & 0 deletions tests/inputs/prefill_only/test_request_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from vllm.inputs.prefill_only.data import TextOnlyInputs, TokensPrompt
from vllm.inputs.prefill_only.preprocessor import (TextInputProcessor,
TextRequestProcessor)
from vllm.inputs.prefill_only.tokenizer import Tokenizer


@pytest.fixture(scope="session")
def request_id():
    """Fixed request id reused across the whole test session."""
    return "0"


# Tokenizer checkpoints exercised by the parametrized test below.
TOKENIZER_NAMES = ["facebook/opt-125m", "gpt2"]


@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
def test_request_processor(request_id: str, tokenizer_name: str):
    """End-to-end: raw prompt / token ids -> request -> schedulable request."""
    tokenizer = Tokenizer(tokenizer_name=tokenizer_name)
    input_processor = TextInputProcessor()
    request_processor = TextRequestProcessor(tokenizer)

    # A text prompt must be tokenized by the request processor.
    text = "test"
    request = input_processor(request_id, text)
    assert request.inputs == {"prompt": text}

    schedulable = request_processor(request)
    assert isinstance(schedulable.inputs, TextOnlyInputs)
    assert len(schedulable.inputs.prompt_token_ids) > 0

    # Pre-tokenized input: the ids are carried through unchanged.
    request = input_processor(request_id,
                              TokensPrompt(prompt_token_ids=[0]))
    schedulable = request_processor(request)
    assert isinstance(schedulable.inputs, TextOnlyInputs)
    assert len(schedulable.inputs.prompt_token_ids) > 0
File renamed without changes.
Empty file.
68 changes: 68 additions & 0 deletions vllm/inputs/prefill_only/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Union


class Params:
    """Base marker type for per-request parameters."""


class Inputs:
    """Base marker type for model inputs."""


@dataclass
class TextPrompt(Inputs):
    """Schema for a raw text prompt."""

    # The input text; tokenized later, before being passed to the model.
    prompt: str


@dataclass
class TokensPrompt(Inputs):
    """Schema for a pre-tokenized prompt."""

    # Token IDs passed to the model as-is (no tokenization step needed).
    prompt_token_ids: List[int]


@dataclass
class TextOnlyInputs(Inputs):
    """Fully preprocessed text input: token ids plus optional source text."""

    # The token IDs of the prompt.
    prompt_token_ids: List[int]

    # The original prompt text behind the token IDs, if available.
    prompt: Optional[str] = None


# Union of every prompt format accepted by the input processor's __call__.
PromptInput = Union[str, Dict, TextPrompt, TokensPrompt, TextOnlyInputs]


@dataclass
class Request:
    """Base record for a submitted request."""

    # Caller-assigned identifier for this request.
    request_id: str

    # Wall-clock time (seconds since epoch) at which the request arrived.
    arrival_time: float


@dataclass
class TextRequest(Request):
    """A request whose inputs are still in normalized dict form."""

    # Normalized inputs: {"prompt": str} and/or {"prompt_token_ids": [int]}.
    inputs: Dict


class ValidationError(ValueError):
    """Raised when request inputs cannot be normalized."""


class SchedulableRequest(Request):
    """Marker for requests that are preprocessed and ready to schedule."""


@dataclass
class TextSchedulableRequest(SchedulableRequest):
    """A schedulable request carrying fully tokenized text inputs."""

    # Preprocessed inputs; prompt_token_ids is always populated here.
    inputs: TextOnlyInputs

    @property
    def num_new_tokens(self):
        """Number of prompt tokens this request will feed to the model."""
        return len(self.inputs.prompt_token_ids)
125 changes: 125 additions & 0 deletions vllm/inputs/prefill_only/preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, cast

from vllm.inputs.prefill_only.data import (Params, PromptInput, Request,
SchedulableRequest, TextOnlyInputs,
TextPrompt, TextRequest,
TextSchedulableRequest,
TokensPrompt, ValidationError)
from vllm.inputs.prefill_only.tokenizer import Tokenizer


class InputProcessor(ABC):
    """Turn (request_id, inputs, params, arrival_time) into a Request."""

    @abstractmethod
    def __call__(self,
                 request_id: str,
                 inputs: Optional[Any] = None,
                 params: Optional[Params] = None,
                 arrival_time: Optional[float] = None) -> Request:
        """Build a Request from raw caller-supplied arguments."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_engine(cls, engine):
        """Construct a processor configured from an engine instance."""
        raise NotImplementedError


class TextInputProcessor(InputProcessor):
    """Normalize any supported prompt format into a TextRequest.

    Accepted forms (see PromptInput): a raw string, TextPrompt,
    TokensPrompt, TextOnlyInputs, or a dict already containing
    "prompt" and/or "prompt_token_ids".
    """

    def __call__(self,
                 request_id: str,
                 inputs: Optional[PromptInput] = None,
                 params: Optional[Params] = None,
                 arrival_time: Optional[float] = None) -> TextRequest:
        """Build a TextRequest whose ``inputs`` is a normalized dict.

        Raises:
            ValidationError: if ``inputs`` is an unsupported type, or a
                dict lacking both "prompt" and "prompt_token_ids".
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}
        elif isinstance(inputs, TextPrompt):
            inputs = {"prompt": inputs.prompt}
        elif isinstance(inputs, TokensPrompt):
            inputs = {"prompt_token_ids": inputs.prompt_token_ids}
        elif isinstance(inputs, TextOnlyInputs):
            _inputs: Dict[str, Any] = {
                "prompt_token_ids": inputs.prompt_token_ids
            }

            # Keep the original text only when it is actually present.
            if inputs.prompt is not None:
                _inputs["prompt"] = inputs.prompt

            inputs = _inputs

        elif isinstance(inputs, dict):
            if "prompt" not in inputs and "prompt_token_ids" not in inputs:
                raise ValidationError(
                    'inputs must contain at least one of '
                    '"prompt" or "prompt_token_ids".')
            # Drop unrecognized keys so downstream code sees a fixed schema.
            inputs = {
                k: v
                for k, v in inputs.items()
                if k in {"prompt", "prompt_token_ids"}
            }
        else:
            raise ValidationError(
                f"Input does not support {type(inputs)} data type")

        # BUG FIX: the previous `if not arrival_time` treated a legitimate
        # arrival_time of 0.0 (the epoch) as missing; only substitute the
        # current time when no value was supplied at all.
        if arrival_time is None:
            arrival_time = time.time()
        request = TextRequest(request_id=str(request_id),
                              inputs=inputs,
                              arrival_time=arrival_time)
        return request

    @classmethod
    def from_engine(cls, engine):
        """The text processor is stateless, so the engine is unused."""
        return cls()


class RequestProcessor(ABC):
    """Turn a Request into a SchedulableRequest (e.g. by tokenizing)."""

    @abstractmethod
    def __call__(self, request: Request) -> SchedulableRequest:
        """Preprocess ``request`` so the scheduler can consume it."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_engine(cls, engine):
        """Construct a processor configured from an engine instance."""
        raise NotImplementedError


class TextRequestProcessor(RequestProcessor):
    """Tokenize a TextRequest (when needed) into a TextSchedulableRequest."""

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, request: Request) -> TextSchedulableRequest:
        """Produce a schedulable request with prompt_token_ids populated."""
        assert isinstance(request, TextRequest)
        request = cast(TextRequest, request)

        inputs = request.inputs

        if "prompt_token_ids" in inputs:
            # Already tokenized by the caller: pass the ids through.
            prompt_token_ids = inputs["prompt_token_ids"]
        else:
            # Text-only input: run the tokenizer now.
            prompt_token_ids = self.tokenizer.encode(inputs["prompt"])

        return TextSchedulableRequest(
            request_id=request.request_id,
            inputs=TextOnlyInputs(prompt_token_ids=prompt_token_ids,
                                  prompt=inputs.get("prompt")),
            arrival_time=request.arrival_time)

    @classmethod
    def from_engine(cls, engine):
        """Reuse the tokenizer the engine already constructed."""
        return cls(engine.tokenizer)
32 changes: 32 additions & 0 deletions vllm/inputs/prefill_only/tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from vllm.transformers_utils.tokenizer import get_tokenizer


class Tokenizer:
    """Thin wrapper around the tokenizer returned by get_tokenizer."""

    def __init__(self, tokenizer_name: str, **kwargs):
        # Keep the constructor arguments for introspection/reconstruction.
        self.tokenizer_name = tokenizer_name
        self.tokenizer_kwargs = kwargs

        self.tokenizer = get_tokenizer(tokenizer_name=self.tokenizer_name,
                                       **self.tokenizer_kwargs)

    @classmethod
    def from_engine(cls, engine):
        """Build a Tokenizer from the engine's model configuration."""
        model_config = engine.engine_config.model_config
        return cls(tokenizer_name=model_config.tokenizer,
                   tokenizer_mode=model_config.tokenizer_mode,
                   trust_remote_code=model_config.trust_remote_code,
                   revision=model_config.tokenizer_revision)

    def __call__(self, *args, **kwargs):
        # Delegate full tokenization (text -> encoded batch) directly.
        return self.tokenizer(*args, **kwargs)

    def encode(self, *args, **kwargs):
        # Delegate plain encoding (text -> token ids) directly.
        return self.tokenizer.encode(*args, **kwargs)

    @property
    def eos_token_id(self):
        """End-of-sequence token id of the wrapped tokenizer."""
        return self.tokenizer.eos_token_id

0 comments on commit e80d5f2

Please sign in to comment.