Commit 784997f
Rita authored and committed on Feb 21, 2025
1 parent: 1dabe91
Showing 6 changed files with 916 additions and 6 deletions.
@@ -85,6 +85,7 @@ all = [
    "monai>=1.3.2",
    "datasets>=3.2.0",
    "litellm>=1.61.8",
    "vllm>=0.5.1",
]

[project.scripts]
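The vllm pin above lands in what appears to be the optional `all` dependency group rather than the core requirements, so code that imports vLLM should expect it to be absent on a default install. A minimal sketch of the usual guard pattern (not part of this commit; the error message is illustrative):

try:
    from vllm import LLM, SamplingParams  # only present when the optional extra is installed
except ImportError as exc:
    raise ImportError(
        "vLLM is an optional dependency; install an extras set that includes 'vllm>=0.5.1'."
    ) from exc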
@@ -0,0 +1,56 @@
"""LLM wrapper for litellm models."""

from litellm import completion
from typing_extensions import override

from eva.core.models.wrappers import base


class LiteLLMTextModel(base.BaseModel):
    """Wrapper class for using litellm for chat-based text generation.

    This wrapper uses litellm's `completion` function, which accepts a list of
    message dictionaries. The `generate` method converts a string prompt into a
    chat message with the "user" role; provider credentials (API keys) are
    resolved by litellm itself, typically from environment variables.
    """

    def __init__(
        self,
        model_name_or_path: str,
    ) -> None:
        """Initializes the litellm chat model wrapper.

        Args:
            model_name_or_path: The model identifier (or name) for litellm (e.g.,
                "openai/gpt-4o" or "anthropic/claude-3-sonnet-20240229").
        """
        super().__init__()
        self._model_name_or_path = model_name_or_path
        self.load_model()

    @override
    def load_model(self) -> None:
        """Prepares the litellm model.

        Note:
            litellm does not require an explicit loading step; models are invoked
            directly during generation. This method exists for API consistency.
        """
        pass

    def generate(self, prompt: str, **generate_kwargs) -> str:
        """Generates text using litellm.

        Args:
            prompt: A string prompt that will be converted into a "user" chat message.
            generate_kwargs: Additional parameters for generation (e.g., max_tokens).

        Returns:
            The generated text response.
        """
        messages = [{"role": "user", "content": prompt}]
        response = completion(model=self._model_name_or_path, messages=messages, **generate_kwargs)
        return response["choices"][0]["message"]["content"]
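A minimal usage sketch for the wrapper above, assuming the class is imported from its module (the file path is not visible in this extract) and that the relevant provider credential, e.g. `OPENAI_API_KEY`, is already set in the environment; the model id, prompt, and generation settings are illustrative:

model = LiteLLMTextModel("openai/gpt-4o")
answer = model.generate(
    "Summarize what a chat template does in one sentence.",
    max_tokens=128,    # forwarded to litellm's `completion`
    temperature=0.0,
)
print(answer)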
@@ -0,0 +1,90 @@
"""LLM wrapper for vLLM models."""

from typing import Any, Dict

from typing_extensions import override

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

from eva.core.models.wrappers import base


class VLLMTextModel(base.BaseModel):
    """Wrapper class for using vLLM for text generation.

    This wrapper loads a vLLM model and its tokenizer, uses the tokenizer's chat
    template to convert a plain string prompt into the token-ID input format that
    vLLM expects, and returns the generated text response.
    """

    def __init__(
        self,
        model_name_or_path: str,
        model_kwargs: Dict[str, Any] | None = None,
    ) -> None:
        """Initializes the vLLM model wrapper.

        Args:
            model_name_or_path: The model identifier (e.g., a Hugging Face repo ID
                or local path).
            model_kwargs: Additional keyword arguments for initializing the vLLM model.
        """
        super().__init__()
        self._model_name_or_path = model_name_or_path
        self._model_kwargs = model_kwargs or {}
        self.load_model()

    @override
    def load_model(self) -> None:
        """Loads the vLLM model and sets up its tokenizer."""
        self._model = LLM(model=self._model_name_or_path, **self._model_kwargs)
        self._tokenizer = self._model.get_tokenizer()

    def _apply_chat_template(self, prompt: str) -> list[TokensPrompt]:
        """Converts a prompt string into token prompts using the tokenizer's chat template.

        Args:
            prompt: The input prompt as a string.

        Returns:
            A single-element list containing a TokensPrompt ready for generation.

        Raises:
            ValueError: If the tokenizer does not support a chat template.
        """
        messages = [{"role": "user", "content": prompt}]
        if self._tokenizer.chat_template is None:
            raise ValueError("Tokenizer does not have a chat template.")
        encoded_messages = self._tokenizer.apply_chat_template(
            [messages],
            tokenize=True,
            add_generation_prompt=True,
        )
        # Some chat templates already include a BOS token; if the tokenizer prepended
        # a second one, drop the duplicate so the prompt starts with a single BOS.
        if len(encoded_messages[0]) >= 2 and (
            encoded_messages[0][0] == self._tokenizer.bos_token_id
            and encoded_messages[0][1] == self._tokenizer.bos_token_id
        ):
            encoded_messages[0] = encoded_messages[0][1:]
        return [TokensPrompt(prompt_token_ids=encoded_messages[0])]

    def generate(self, prompt: str, **generate_kwargs) -> str:
        """Generates text for the given prompt using the vLLM model.

        Args:
            prompt: A string prompt for generation.
            generate_kwargs: Additional parameters for generation (e.g., max_tokens),
                forwarded to `SamplingParams`.

        Returns:
            The generated text response.
        """
        tokens_prompt = self._apply_chat_template(prompt)
        outputs = self._model.generate(tokens_prompt, SamplingParams(**generate_kwargs))
        return outputs[0].outputs[0].text
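A minimal usage sketch for the vLLM wrapper, assuming a machine with an accelerator that vLLM supports; the Hugging Face model id and sampling settings below are illustrative, not taken from this commit:

model = VLLMTextModel(
    "meta-llama/Meta-Llama-3-8B-Instruct",               # illustrative HF repo id
    model_kwargs={"dtype": "bfloat16", "max_model_len": 4096},
)
answer = model.generate(
    "Summarize what a chat template does in one sentence.",
    max_tokens=128,    # forwarded to SamplingParams
    temperature=0.0,
)
print(answer)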