revert to main
Safoora Yousefi committed Dec 18, 2024
1 parent 57ab426 commit 0544d62
Showing 1 changed file with 37 additions and 71 deletions.
108 changes: 37 additions & 71 deletions eureka_ml_insights/models/models.py
@@ -13,6 +13,7 @@

from eureka_ml_insights.secret_management import get_secret


@dataclass
class Model(ABC):
"""This class is used to define the structure of a model class.
@@ -90,15 +91,15 @@ class EndpointModel(Model):
num_retries: int = 3

@abstractmethod
def create_request(self, text_prompt, **kwargs):
def create_request(self, text_prompt, query_images=None, system_message=None):
raise NotImplementedError

@abstractmethod
def get_response(self, request):
# must return the model output and the response time
raise NotImplementedError

def generate(self, query_text, **kwargs):
def generate(self, query_text, query_images=None, system_message=None):
"""
Calls the endpoint to generate the model response.
args:
@@ -110,7 +111,7 @@ def generate(self, query_text, **kwargs):
and any other relevant information returned by the model.
"""
response_dict = {}
request = self.create_request(query_text, **kwargs)
request = self.create_request(query_text, query_images=query_images, system_message=system_message)
attempts = 0
while attempts < self.num_retries:
try:
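The body of this retry loop is collapsed in the diff view. As a rough sketch of the pattern (assumed shape, not the file's exact code), each attempt calls get_response, records the output and latency, and falls through to a warning and a retry on failure:

    import logging
    import time

    def call_with_retries(get_response, request, num_retries=3):
        # Minimal sketch of the EndpointModel.generate retry pattern; the
        # response_dict keys here are assumptions, not the file's exact schema.
        response_dict = {"is_valid": False, "model_output": None, "response_time": None}
        attempts = 0
        while attempts < num_retries:
            try:
                # get_response must return the model output and the response time.
                model_output, response_time = get_response(request)
                response_dict.update(
                    {"is_valid": True, "model_output": model_output, "response_time": response_time}
                )
                break
            except Exception as e:
                logging.warning("Attempt %d failed: %s", attempts + 1, e)
                attempts += 1
                time.sleep(1)  # assumed fixed backoff between attempts
        return response_dict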
@@ -158,17 +159,15 @@ class RestEndpointModel(EndpointModel, KeyBasedAuthMixIn):
presence_penalty: float = 0
do_sample: bool = True

def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
"""Creates a request for the model."""
messages = []
if system_message:
messages.append({"role": "system", "content": system_message})
if previous_messages:
messages.extend(previous_messages)
messages.append({"role": "user", "content": text_prompt})
def create_request(self, text_prompt, query_images=None, system_message=None):
data = {
"input_data": {
"input_string": messages,
"input_string": [
{
"role": "user",
"content": text_prompt,
}
],
"parameters": {
"temperature": self.temperature,
"top_p": self.top_p,
@@ -177,8 +176,12 @@ def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
},
}
}
if system_message:
data["input_data"]["input_string"] = [{"role": "system", "content": system_message}] + data["input_data"][
"input_string"
]
if query_images:
raise NotImplementedError("Images are not supported for RestEndpointModel endpoints yet.")
raise NotImplementedError("Images are not supported for GCR endpoints yet.")

body = str.encode(json.dumps(data))
# The azureml-model-deployment header will force the request to go to a specific deployment.
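For context, the encoded body above is typically posted to an AzureML managed online endpoint with urllib. A minimal sketch, assuming a key-based endpoint (the function name and deployment parameter are illustrative, not from the file):

    import json
    import urllib.request

    def post_azureml_request(url, api_key, body, deployment=None):
        # body is the bytes produced by str.encode(json.dumps(data)) above.
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        if deployment:
            # The azureml-model-deployment header forces routing to one deployment.
            headers["azureml-model-deployment"] = deployment
        req = urllib.request.Request(url, body, headers)
        with urllib.request.urlopen(req) as resp:
            return json.loads(resp.read())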
@@ -217,7 +220,6 @@ class ServerlessAzureRestEndpointModel(EndpointModel, KeyBasedAuthMixIn):
url: str = None
model_name: str = None
stream: bool = False
auth_scope: str = "https://cognitiveservices.azure.com/.default"

def __post_init__(self):
try:
@@ -233,7 +235,7 @@ def __post_init__(self):
}
except ValueError:
self.bearer_token_provider = get_bearer_token_provider(
DefaultAzureCredential(), self.auth_scope
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
self.headers = {
"Content-Type": "application/json",
@@ -246,7 +248,7 @@ def __post_init__(self):
}

@abstractmethod
def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
def create_request(self, text_prompt, query_images=None, system_message=None):
# Exact model parameters are model-specific.
# The method cannot be implemented unless the model being deployed is known.
raise NotImplementedError
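The __post_init__ above falls back from key-based auth to Azure Entra ID when no key is found. A minimal sketch of how the resulting bearer_token_provider is typically used to build request headers (assumed usage; the provider is a callable that returns a fresh token):

    from azure.identity import DefaultAzureCredential, get_bearer_token_provider

    token_provider = get_bearer_token_provider(
        DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    headers = {
        "Content-Type": "application/json",
        # Calling the provider returns a current Entra ID token for the scope above.
        "Authorization": f"Bearer {token_provider()}",
    }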
@@ -286,18 +288,13 @@ class LlamaServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel):
skip_special_tokens: bool = False
ignore_eos: bool = False

def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
messages = []
if system_message:
messages.append({"role": "system", "content": system_message})
if previous_messages:
messages.extend(previous_messages)
user_content = text_prompt
def create_request(self, text_prompt, query_images=None, *args, **kwargs):
user_content = {"role": "user", "content": text_prompt}
if query_images:
if len(query_images) > 1:
raise ValueError("Llama vision model does not support more than 1 image.")
encoded_images = self.base64encode(query_images)
user_content = [
user_content["content"] = [
{"type": "text", "text": text_prompt},
{
"type": "image_url",
@@ -306,11 +303,9 @@ def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
},
},
]
messages.append({"role": "user", "content": user_content})


data = {
"messages": messages,
"messages": [user_content],
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
@@ -342,17 +337,9 @@ def __post_init__(self):
self.top_p = 1
super().__post_init__()

def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
messages = []
if system_message:
messages.append({"role": "system", "content": system_message})
if previous_messages:
messages.extend(previous_messages)
if query_images:
raise NotImplementedError("Images are not supported for MistralServerlessAzureRestEndpointModel endpoints.")
messages.append({"role": "user", "content": text_prompt})
def create_request(self, text_prompt, *args, **kwargs):
data = {
"messages": messages,
"messages": [{"role": "user", "content": text_prompt}],
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
Expand All @@ -371,16 +358,14 @@ class OpenAICommonRequestResponseMixIn:
This mixin class defines the request and response handling for most OpenAI models.
"""

def create_request(self, prompt, query_images=None, system_message=None, previous_messages=None):
def create_request(self, prompt, query_images=None, system_message=None):
messages = []
if system_message:
messages.append({"role": "system", "content": system_message})
if previous_messages:
messages.extend(previous_messages)
user_content = prompt
user_content = {"role": "user", "content": prompt}
if query_images:
encoded_images = self.base64encode(query_images)
user_content = [
user_content["content"] = [
{"type": "text", "text": prompt},
{
"type": "image_url",
@@ -389,7 +374,7 @@ def get_response(self, request):
},
},
]
messages.append({"role": "user", "content": user_content})
messages.append(user_content)
return {"messages": messages}
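For a text-plus-image prompt, the request returned above takes the standard OpenAI chat-completions shape. An illustrative example (placeholder strings, base64 payload elided):

    request = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/jpeg;base64,<ENCODED_IMAGE>"},
                    },
                ],
            },
        ]
    }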

def get_response(self, request):
@@ -419,7 +404,7 @@ def get_client(self):
from openai import AzureOpenAI

token_provider = get_bearer_token_provider(
DefaultAzureCredential(), self.auth_scope
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
return AzureOpenAI(
azure_endpoint=self.url,
@@ -464,7 +449,6 @@ class AzureOpenAIModel(OpenAICommonRequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel):
presence_penalty: float = 0
seed: int = 0
api_version: str = "2023-06-01-preview"
auth_scope: str = "https://cognitiveservices.azure.com/.default"

def __post_init__(self):
self.client = self.get_client()
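get_client above authenticates with Azure Entra ID rather than an API key. A minimal sketch of the equivalent standalone client construction (the endpoint URL is a placeholder):

    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
    from openai import AzureOpenAI

    client = AzureOpenAI(
        azure_endpoint="https://<resource>.openai.azure.com",  # placeholder
        api_version="2023-06-01-preview",
        azure_ad_token_provider=get_bearer_token_provider(
            DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
        ),
    )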
@@ -489,17 +473,8 @@ def __post_init__(self):


class OpenAIO1RequestResponseMixIn:

def create_request(self, prompt, query_images=None, system_message=None, previous_messages=None):
if system_message:
# system messages are not supported for OAI reasoning models
# https://platform.openai.com/docs/guides/reasoning
logging.warning("System messages are not supported for OAI reasoning models.")
messages = []
if previous_messages:
messages.extend(previous_messages)

messages.append({"role": "user", "content": prompt})
def create_request(self, prompt, *args, **kwargs):
messages = [{"role": "user", "content": prompt}]
return {"messages": messages}

def get_response(self, request):
@@ -553,8 +528,6 @@ class AzureOpenAIO1Model(OpenAIO1RequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel):
frequency_penalty: float = 0
presence_penalty: float = 0
api_version: str = "2023-06-01-preview"
auth_scope: str = "https://cognitiveservices.azure.com/.default"


def __post_init__(self):
self.client = self.get_client()
@@ -590,13 +563,7 @@ def __post_init__(self):
def create_request(self, text_prompt, query_images=None, system_message=None):
import google.generativeai as genai

if self.model_name == "gemini-1.0-pro":
if system_message:
logging.warning("System messages are not supported for Gemini 1.0 Pro.")
self.model = genai.GenerativeModel(self.model_name)
else:
self.model = genai.GenerativeModel(self.model_name, system_instruction=system_message)

self.model = genai.GenerativeModel(self.model_name, system_instruction=system_message)
if query_images:
return [text_prompt] + query_images
else:
@@ -975,15 +942,14 @@ def __post_init__(self):
timeout=self.timeout,
)

def create_request(self, prompt, query_images=None, system_message=None, previous_messages=None):
def create_request(self, prompt, query_images=None, system_message=None):
messages = []

user_content = prompt
if previous_messages:
messages.extend(previous_messages)
user_content = {"role": "user", "content": prompt}

if query_images:
encoded_images = self.base64encode(query_images)
user_content = [
user_content["content"] = [
{"type": "text", "text": prompt},
{
"type": "image",
@@ -994,7 +960,7 @@ def create_request(self, prompt, query_images=None, system_message=None, previous_messages=None):
},
},
]
messages.append({"role": "user", "content": user_content})
messages.append(user_content)

if system_message:
return {"messages": messages, "system": system_message}
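Unlike the OpenAI format, the Anthropic messages API takes the system message as a top-level "system" field rather than as a message, as the return above shows. An illustrative request for a text-plus-image prompt (placeholder strings, base64 payload elided):

    request = {
        "system": "You are a helpful assistant.",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this picture?"},
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": "<ENCODED_IMAGE>",
                        },
                    },
                ],
            }
        ],
    }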
