Add wrapper drafts
Rita authored and Rita committed Feb 21, 2025
1 parent 1dabe91 commit 784997f
Showing 6 changed files with 916 additions and 6 deletions.
763 changes: 760 additions & 3 deletions pdm.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -85,6 +85,7 @@ all = [
"monai>=1.3.2",
"datasets>=3.2.0",
"litellm>=1.61.8",
"vllm>=0.5.1",
]

[project.scripts]
9 changes: 7 additions & 2 deletions src/eva/language/models/__init__.py
@@ -2,6 +2,11 @@

from eva.language.models import networks, wrappers
from eva.language.models.networks import TextModule
from eva.language.models.wrappers import HuggingFaceTextModel, LiteLLMTextModel
from eva.language.models.wrappers import HuggingFaceTextModel, LiteLLMTextModel, VLLMTextModel

__all__ = ["networks", "wrappers", "TextModule", "HuggingFaceTextModel", "LiteLLMTextModel"]
__all__ = ["networks",
"wrappers",
"TextModule",
"HuggingFaceTextModel",
"LiteLLMTextModel",
'VLLMTextModel']
3 changes: 2 additions & 1 deletion src/eva/language/models/wrappers/__init__.py
@@ -2,5 +2,6 @@

from eva.language.models.wrappers.huggingface import HuggingFaceTextModel
from eva.language.models.wrappers.litellm import LiteLLMTextModel
from eva.language.models.wrappers.vllm import VLLMTextModel

__all__ = ["HuggingFaceTextModel", "LiteLLMTextModel"]
__all__ = ["HuggingFaceTextModel", "LiteLLMTextModel", "VLLMTextModel"]
56 changes: 56 additions & 0 deletions src/eva/language/models/wrappers/litellm.py
@@ -0,0 +1,56 @@
"""LLM wrapper for litellm models."""

from litellm import completion
from typing_extensions import override

from eva.core.models.wrappers import base


class LiteLLMTextModel(base.BaseModel):
"""Wrapper class for using litellm for chat-based text generation.
This wrapper uses litellm's `completion` function which accepts a list of
message dictionaries. The `generate` method converts a string prompt into a chat
message with a default role of "user", optionally prepends a system message, and
includes an API key if provided.
"""

def __init__(
self,
model_name_or_path: str,
) -> None:
"""Initializes the litellm chat model wrapper.
Args:
model_name_or_path: The model identifier (or name) for litellm (e.g.,
"openai/gpt-4o" or "anthropic/claude-3-sonnet-20240229").
"""
super().__init__()
self._model_name_or_path = model_name_or_path
self.load_model()

@override
def load_model(self) -> None:
"""Prepares the litellm model.
Note:
litellm does not require an explicit loading step; models are invoked
directly during generation. This method exists for API consistency.
"""
pass

def generate(self, prompt: str, **generate_kwargs) -> str:
"""Generates text using litellm.
Args:
prompt: A string prompt that will be converted into a "user" chat message.
generate_kwargs: Additional parameters for generation (e.g., max_tokens).
Returns:
The generated text response.
"""
messages = [{"role": "user", "content": prompt}]
response = completion(model=self._model_name_or_path, messages=messages, **generate_kwargs)
return response["choices"][0]["message"]["content"]
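
For reference, a minimal usage sketch of the new litellm wrapper (the model identifier is illustrative; any litellm-supported model works, and the matching provider API key, e.g. OPENAI_API_KEY, is assumed to be set in the environment):

from eva.language.models.wrappers import LiteLLMTextModel

# Illustrative model name; litellm resolves the provider and reads its API key
# from the environment.
model = LiteLLMTextModel(model_name_or_path="openai/gpt-4o")
answer = model.generate("List two common uses of language models.", max_tokens=64)
print(answer)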
90 changes: 90 additions & 0 deletions src/eva/language/models/wrappers/vllm.py
@@ -0,0 +1,90 @@
"""LLM wrapper for vLLM models."""

from typing import Any, Dict

from typing_extensions import override

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

from eva.core.models.wrappers import base


class VLLMTextModel(base.BaseModel):
"""
Wrapper class for using vLLM for text generation.
This wrapper loads a vLLM model, sets up the tokenizer and sampling parameters,
and uses a chat template to convert a plain string prompt into the proper input
format for vLLM generation. It then returns the generated text response.
"""

def __init__(
self,
model_name_or_path: str,
model_kwargs: Dict[str, Any] | None = None,
) -> None:
"""
Initializes the vLLM model wrapper.
Args:
model_name_or_path: The model identifier (e.g., a Hugging Face repo ID or local path).
model_kwargs: Additional keyword arguments for initializing the vLLM model.
generation_kwargs: Additional keyword arguments for the sampling parameters.
"""
super().__init__()
self._model_name_or_path = model_name_or_path
self._model_kwargs = model_kwargs or {}
self.load_model()

@override
def load_model(self) -> None:
"""
Loads the vLLM model and sets up the tokenizer and sampling parameters.
"""
self._model = LLM(model=self._model_name_or_path, **self._model_kwargs)
self._tokenizer = self._model.get_tokenizer()

    def _apply_chat_template(self, prompt: str) -> list[TokensPrompt]:
        """Converts a prompt string into token prompts using the tokenizer's chat template.

        Args:
            prompt: The input prompt as a string.

        Returns:
            A list containing a single TokensPrompt ready for generation.

        Raises:
            ValueError: If the tokenizer does not define a chat template.
        """
messages = [{"role": "user", "content": prompt}]
if self._tokenizer.chat_template is None:
raise ValueError("Tokenizer does not have a chat template.")
encoded_messages = self._tokenizer.apply_chat_template(
[messages],
tokenize=True,
add_generation_prompt=True,
)
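        # Some chat templates prepend an extra BOS token on top of the one the
        # tokenizer already adds; drop the duplicate before generation.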
if len(encoded_messages[0]) >= 2 and (
encoded_messages[0][0] == self._tokenizer.bos_token_id
and encoded_messages[0][1] == self._tokenizer.bos_token_id
):
encoded_messages[0] = encoded_messages[0][1:]
return [TokensPrompt(prompt_token_ids=encoded_messages[0])]

def generate(self, prompt: str, **generate_kwargs) -> str:
"""
Generates text for the given prompt using the vLLM model.
Args:
prompt: A string prompt for generation.
generate_kwargs: Additional parameters for generation (e.g., max_tokens).
Returns:
The generated text response.
"""
tokens_prompt = self._apply_chat_template(prompt)
outputs = self._model.generate(tokens_prompt, SamplingParams(**generate_kwargs))
return outputs[0].outputs[0].text
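
A minimal usage sketch for the vLLM wrapper (the checkpoint and settings are illustrative; it assumes a local GPU with enough memory and a model whose tokenizer ships a chat template):

from eva.language.models.wrappers import VLLMTextModel

# Illustrative checkpoint and engine settings; generate_kwargs are forwarded to
# vLLM's SamplingParams.
model = VLLMTextModel(
    model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    model_kwargs={"max_model_len": 4096},
)
answer = model.generate("List two common uses of language models.", max_tokens=64)
print(answer)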
