diff --git a/README.md b/README.md
index 23706533..15cc2cac 100644
--- a/README.md
+++ b/README.md
@@ -10,11 +10,12 @@ Includes:
 - Deployable on any Kubernetes cluster, with its Helm chart
 - Manage users effortlessly with OpenID Connect
 - More than 150 tones and personalities (accountant, advisor, debater, excel sheet, instructor, logistician, etc.) to better help employees in their specific daily tasks
-- Plug and play with any storage system, including [Azure Cosmos DB](https://learn.microsoft.com/en-us/azure/cosmos-db/), [Redis](https://github.com/redis/redis) and [Qdrant](https://github.com/qdrant/qdrant).
+- Plug-and-play storage system, including [Azure Cosmos DB](https://learn.microsoft.com/en-us/azure/cosmos-db/), [Redis](https://github.com/redis/redis) and [Qdrant](https://github.com/qdrant/qdrant).
 - Possibility to send temporary messages, for confidentiality
 - Salable system based on stateless APIs, cache, progressive web app and events
 - Search engine for conversations, based on semantic similarity and AI embeddings
-- Unlimited conversation history
+- Unlimited conversation history and unlimited users
+- Usage tracking, to better understand how your employees use the tool
 
 ![Application screenshot](docs/main.png)
 
@@ -35,6 +36,9 @@
 store = "cosmos"
 # Enum: "redis"
 stream = "redis"
+[api]
+root_path = ""
+
 [openai]
 ada_deploy_id = "ada"
 ada_max_tokens = 2049
@@ -49,7 +53,7 @@ max_length = 1000
 
 [logging]
 app_level = "DEBUG"
-sys_level = "INFO"
+sys_level = "WARN"
 
 [oidc]
 algorithms = ["RS256"]
@@ -65,7 +69,7 @@ db = 0
 host = "localhost"
 
 [cosmos]
-# Containers "conversation" (/user_id), "message" (/conversation_id) and "user" (/dummy) must exist
+# Containers "conversation" (/user_id), "message" (/conversation_id), "user" (/dummy) and "usage" (/user_id) must exist
 url = "https://private-gpt.documents.azure.com:443"
 database = "private-gpt"
 ```
diff --git a/cicd/helm/private-gpt/templates/conversation-api-config.yaml b/cicd/helm/private-gpt/templates/conversation-api-config.yaml
index 5be68d57..1fd11eac 100644
--- a/cicd/helm/private-gpt/templates/conversation-api-config.yaml
+++ b/cicd/helm/private-gpt/templates/conversation-api-config.yaml
@@ -7,6 +7,11 @@ metadata:
     app.kubernetes.io/component: conversation-api
 data:
   config.toml: |
+    [persistence]
+    search = "qdrant"
+    store = "cosmos"
+    stream = "redis"
+
     [api]
     root_path = "/{{ include "private-gpt.fullname" . }}-conversation-api"
@@ -37,3 +42,7 @@ data:
     [redis]
     db = {{ .Values.redis.db | int }}
     host = "{{ include "common.names.fullname" .Subcharts.redis }}-master"
+
+    [cosmos]
+    url = {{ .Values.cosmos.url | quote }}
+    database = {{ .Values.cosmos.database | quote }}
diff --git a/cicd/helm/private-gpt/values.yaml b/cicd/helm/private-gpt/values.yaml
index 04ac7fb6..b2f1b3b5 100644
--- a/cicd/helm/private-gpt/values.yaml
+++ b/cicd/helm/private-gpt/values.yaml
@@ -42,6 +42,11 @@ api:
   base: null
   gpt_deploy_id: gpt-35-turbo
 
+cosmos:
+  # https://[db].documents.azure.com
+  url: null
+  database: null
+
 redis:
   auth:
     enabled: false
diff --git a/src/conversation-api/ai/contentsafety.py b/src/conversation-api/ai/contentsafety.py
new file mode 100644
index 00000000..ffbde09b
--- /dev/null
+++ b/src/conversation-api/ai/contentsafety.py
@@ -0,0 +1,79 @@
+# Import utils
+from utils import (build_logger, get_config)
+
+# Import misc
+from azure.core.credentials import AzureKeyCredential
+from fastapi import HTTPException, status
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+import azure.ai.contentsafety as azure_cs
+import azure.core.exceptions as azure_exceptions
+
+
+###
+# Init misc
+###
+
+logger = build_logger(__name__)
+
+###
+# Init Azure Content Safety
+###
+
+# Severity scores are as follows: 0 - Safe, 2 - Low, 4 - Medium, 6 - High
+# See: https://review.learn.microsoft.com/en-us/azure/cognitive-services/content-safety/concepts/harm-categories?branch=release-build-content-safety#severity-levels
+ACS_SEVERITY_THRESHOLD = 2
+ACS_API_BASE = get_config("acs", "api_base", str, required=True)
+ACS_API_TOKEN = get_config("acs", "api_token", str, required=True)
+ACS_MAX_LENGTH = get_config("acs", "max_length", int, required=True)
+logger.info(f"Connected Azure Content Safety to {ACS_API_BASE}")
+acs_client = azure_cs.ContentSafetyClient(
+    ACS_API_BASE, AzureKeyCredential(ACS_API_TOKEN)
+)
+
+
+class ContentSafety:
+    @retry(
+        reraise=True,
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(multiplier=0.5, max=30),
+    )
+    async def is_moderated(self, prompt: str) -> bool:
+        logger.debug(f"Checking moderation for text: {prompt}")
+
+        if len(prompt) > ACS_MAX_LENGTH:
+            logger.info(f"Message ({len(prompt)}) too long for moderation")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Message too long",
+            )
+
+        req = azure_cs.models.AnalyzeTextOptions(
+            text=prompt,
+            categories=[
+                azure_cs.models.TextCategory.HATE,
+                azure_cs.models.TextCategory.SELF_HARM,
+                azure_cs.models.TextCategory.SEXUAL,
+                azure_cs.models.TextCategory.VIOLENCE,
+            ],
+        )
+
+        try:
+            res = acs_client.analyze_text(req)
+        except azure_exceptions.ClientAuthenticationError as e:
+            logger.exception(e)
+            return False
+
+        is_moderated = any(
+            cat.severity >= ACS_SEVERITY_THRESHOLD
+            for cat in [
+                res.hate_result,
+                res.self_harm_result,
+                res.sexual_result,
+                res.violence_result,
+            ]
+        )
+        if is_moderated:
+            logger.info(f"Message is moderated: {prompt}")
+            logger.debug(f"Moderation result: {res}")
+
+        return is_moderated
diff --git a/src/conversation-api/ai/openai.py b/src/conversation-api/ai/openai.py
new file mode 100644
index 00000000..766d2294
--- /dev/null
+++ b/src/conversation-api/ai/openai.py
@@ -0,0 +1,134 @@
+# Import utils
+from uuid import UUID
+from utils import (build_logger, get_config, hash_token)
+
+# Import misc
+from azure.identity import DefaultAzureCredential
+from models.user import UserModel
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+from typing import Any, Dict, List, AsyncGenerator, Union
+import asyncio
+import openai
+
+
+###
+# Init misc
+###
+
+logger = build_logger(__name__)
+loop = asyncio.get_running_loop()
+
+
+###
+# Init OpenAI
+###
+
+async def refresh_oai_token_background():
+    """
+    Refresh the OpenAI token every 15 minutes.
+
+    The OpenAI SDK does not support token refresh, so we need to do it manually and pass the token to the SDK ourselves. Azure AD tokens are valid for 30 minutes, but we refresh every 15 minutes to be safe.
+
+    See: https://github.com/openai/openai-python/pull/350#issuecomment-1489813285
+    """
+    while True:
+        logger.info("Refreshing OpenAI token")
+        oai_cred = DefaultAzureCredential()
+        oai_token = oai_cred.get_token("https://cognitiveservices.azure.com/.default")
+        openai.api_key = oai_token.token
+        # Execute every 15 minutes
+        await asyncio.sleep(15 * 60)
+
+
+openai.api_base = get_config("openai", "api_base", str, required=True)
+openai.api_type = "azure_ad"
+openai.api_version = "2023-05-15"
+logger.info(f"Using Azure private service ({openai.api_base})")
+loop.create_task(refresh_oai_token_background())
+
+OAI_GPT_DEPLOY_ID = get_config("openai", "gpt_deploy_id", str, required=True)
+OAI_GPT_MAX_TOKENS = get_config("openai", "gpt_max_tokens", int, required=True)
+OAI_GPT_MODEL = get_config(
+    "openai", "gpt_model", str, default="gpt-3.5-turbo", required=True
+)
+logger.info(
+    f'Using OpenAI GPT model "{OAI_GPT_MODEL}" ({OAI_GPT_DEPLOY_ID}) with {OAI_GPT_MAX_TOKENS} tokens max'
+)
+
+OAI_ADA_DEPLOY_ID = get_config("openai", "ada_deploy_id", str, required=True)
+OAI_ADA_MAX_TOKENS = get_config("openai", "ada_max_tokens", int, required=True)
+OAI_ADA_MODEL = get_config(
+    "openai", "ada_model", str, default="text-embedding-ada-002", required=True
+)
+logger.info(
+    f'Using OpenAI ADA model "{OAI_ADA_MODEL}" ({OAI_ADA_DEPLOY_ID}) with {OAI_ADA_MAX_TOKENS} tokens max'
+)
+
+
+class OpenAI:
+    @retry(
+        reraise=True,
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(multiplier=0.5, max=30),
+    )
+    async def vector_from_text(self, prompt: str, user_id: UUID) -> List[float]:
+        logger.debug(f"Getting vector for text: {prompt}")
+        try:
+            res = openai.Embedding.create(
+                deployment_id=OAI_ADA_DEPLOY_ID,
+                input=prompt,
+                model=OAI_ADA_MODEL,
+                user=user_id.hex,
+            )
+        except openai.error.AuthenticationError as e:
+            logger.exception(e)
+            return []
+
+        return res.data[0].embedding
+
+    @retry(
+        reraise=True,
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(multiplier=0.5, max=30),
+    )
+    async def completion(self, messages: List[Dict[str, str]], current_user: UserModel) -> Union[str, None]:
+        try:
+            # Use chat completion to get a more natural response and lower the usage cost
+            completion = openai.ChatCompletion.create(
+                deployment_id=OAI_GPT_DEPLOY_ID,
+                messages=messages,
+                model=OAI_GPT_MODEL,
+                presence_penalty=1,  # Increase the model's likelihood to talk about new topics
+                user=hash_token(current_user.id.bytes).hex,
+            )
+            content = completion["choices"][0].message.content
+        except openai.error.AuthenticationError as e:
+            logger.exception(e)
+            return
+
+        return content
+
+    @retry(
+        reraise=True,
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(multiplier=0.5, max=30),
+    )
+    async def completion_stream(self, messages: List[Dict[str, str]], current_user: UserModel) -> AsyncGenerator[Any, None]:
+        try:
+            # Use chat completion to get a more natural response and lower the usage cost
+            chunks = openai.ChatCompletion.create(
+                deployment_id=OAI_GPT_DEPLOY_ID,
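+                # Same arguments as completion() above, plus stream=True below, so the
+                # SDK yields incremental deltas instead of a single response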
+                messages=messages,
+                model=OAI_GPT_MODEL,
+                presence_penalty=1,  # Increase the model's likelihood to talk about new topics
+                stream=True,
+                user=hash_token(current_user.id.bytes).hex,
+            )
+        except openai.error.AuthenticationError as e:
+            logger.exception(e)
+            return
+
+        for chunk in chunks:
+            content = chunk["choices"][0].get("delta", {}).get("content")
+            if content is not None:
+                yield content
diff --git a/src/conversation-api/main.py b/src/conversation-api/main.py
index 48ff6bcc..2c1eb605 100644
--- a/src/conversation-api/main.py
+++ b/src/conversation-api/main.py
@@ -9,8 +9,8 @@
 )
 
 # Import misc
-from azure.core.credentials import AzureKeyCredential
-from azure.identity import DefaultAzureCredential
+from ai.contentsafety import ContentSafety
+from ai.openai import OpenAI, OAI_GPT_MODEL, OAI_GPT_MAX_TOKENS, OAI_ADA_MODEL, OAI_ADA_MAX_TOKENS
 from datetime import datetime
 from fastapi import FastAPI, HTTPException, status, Request, Depends
 from fastapi.middleware.cors import CORSMiddleware
@@ -19,20 +19,17 @@
 from models.message import MessageModel, MessageRole, StoredMessageModel
 from models.prompt import StoredPromptModel, ListPromptsModel
 from models.search import SearchModel
+from models.usage import UsageModel
 from models.user import UserModel
 from persistence.isearch import SearchImplementation
 from persistence.istore import StoreImplementation
 from persistence.istream import StreamImplementation
 from sse_starlette.sse import EventSourceResponse
-from tenacity import retry, stop_after_attempt, wait_random_exponential
 from typing import Annotated, Dict, List, Optional
 from uuid import UUID
 from uuid import uuid4
 import asyncio
-import azure.ai.contentsafety as azure_cs
-import azure.core.exceptions as azure_exceptions
 import csv
-import openai
 
 
 ###
@@ -40,10 +37,12 @@
 ###
 
 logger = build_logger(__name__)
+loop = asyncio.get_running_loop()
 
 ###
 # Init persistence
 ###
+
 store_impl = get_config("persistence", "store", StoreImplementation, required=True)
 if store_impl == StoreImplementation.COSMOS:
     logger.info("Using CosmosDB store")
@@ -72,58 +71,6 @@
 else:
     raise ValueError(f"Unknown stream implementation: {stream_impl}")
 
-###
-# Init OpenAI
-###
-
-
-async def refresh_oai_token():
-    """
-    Refresh OpenAI token every 15 minutes.
-
-    The OpenAI SDK does not support token refresh, so we need to do it manually. We passe manually the token to the SDK. Azure AD tokens are valid for 30 mins, but we refresh every 15 minutes to be safe.
-
-    See: https://github.com/openai/openai-python/pull/350#issuecomment-1489813285
-    """
-    while True:
-        logger.info("Refreshing OpenAI token")
-        oai_cred = DefaultAzureCredential()
-        oai_token = oai_cred.get_token("https://cognitiveservices.azure.com/.default")
-        openai.api_key = oai_token.token
-        # Execute every 20 minutes
-        await asyncio.sleep(15 * 60)
-
-
-OAI_GPT_DEPLOY_ID = get_config("openai", "gpt_deploy_id", str, required=True)
-OAI_GPT_MAX_TOKENS = get_config("openai", "gpt_max_tokens", int, required=True)
-OAI_GPT_MODEL = get_config(
-    "openai", "gpt_model", str, default="gpt-3.5-turbo", required=True
-)
-logger.info(
-    f'Using OpenAI ADA model "{OAI_GPT_MODEL}" ({OAI_GPT_DEPLOY_ID}) with {OAI_GPT_MAX_TOKENS} tokens max'
-)
-
-openai.api_base = get_config("openai", "api_base", str, required=True)
-openai.api_type = "azure_ad"
-openai.api_version = "2023-05-15"
-logger.info(f"Using Aure private service ({openai.api_base})")
-asyncio.create_task(refresh_oai_token())
-
-###
-# Init Azure Content Safety
-###
-
-# Score are following: 0 - Safe, 2 - Low, 4 - Medium, 6 - High
-# See: https://review.learn.microsoft.com/en-us/azure/cognitive-services/content-safety/concepts/harm-categories?branch=release-build-content-safety#severity-levels
-ACS_SEVERITY_THRESHOLD = 2
-ACS_API_BASE = get_config("acs", "api_base", str, required=True)
-ACS_API_TOKEN = get_config("acs", "api_token", str, required=True)
-ACS_MAX_LENGTH = get_config("acs", "max_length", int, required=True)
-logger.info(f"Connected Azure Content Safety to {ACS_API_BASE}")
-acs_client = azure_cs.ContentSafetyClient(
-    ACS_API_BASE, AzureKeyCredential(ACS_API_TOKEN)
-)
-
 ###
 # Init FastAPI
 ###
@@ -159,6 +106,8 @@
 # Init Generative AI
 ###
 
+openai = OpenAI()
+content_safety = ContentSafety()
 
 def get_ai_prompt() -> Dict[UUID, StoredPromptModel]:
     prompts = {}
@@ -181,7 +130,7 @@
 AI_PROMPTS = get_ai_prompt()
 
 AI_CONVERSATION_DEFAULT_PROMPT = f"""
-Today, we are the {datetime.now()}.
+Today is {datetime.utcnow()}.
 
 You MUST:
 - Cite sources and examples as footnotes (example: [^1])
@@ -262,7 +211,7 @@
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
 
     user = store.user_get(sub)
-    logger.info(f"User logged in: {user}")
+    logger.info(f'User "{user.id if user else None}" logged in')
     logger.debug(f"JWT: {jwt}")
     if user:
         return user
@@ -312,7 +261,7 @@
     conversation_id: Optional[UUID] = None,
     prompt_id: Optional[UUID] = None,
 ) -> GetConversationModel:
-    if await is_moderated(content):
+    if await content_safety.is_moderated(content):
         logger.info(f"Message content is moderated: {content}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -341,31 +290,17 @@
         message = StoredMessageModel(
             content=content,
             conversation_id=conversation_id,
-            created_at=datetime.now(),
+            created_at=datetime.utcnow(),
             id=uuid4(),
             role=MessageRole.USER,
             secret=secret,
             token=uuid4(),
         )
 
-        # Validate message length
-        tokens_nb = oai_tokens_nb(
-            message.content
-            + "".join([m.content for m in store.message_list(message.conversation_id)]),
-            OAI_GPT_MODEL,
-        )
-
-        logger.debug(f"{tokens_nb} tokens in the conversation")
-        if tokens_nb > OAI_GPT_MAX_TOKENS:
-            logger.info(f"Message ({tokens_nb}) too long for conversation")
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail="Conversation history is too long",
-            )
+        tokens_nb = await _validate_message_length(message=message)
 
         # Update conversation
         store.message_set(message)
-        index.message_index(message, current_user.id)
         conversation = store.conversation_get(conversation_id, current_user.id)
         if not conversation:
             logger.warn("ACID error: conversation not found after testing existence")
@@ -373,6 +308,19 @@
                 status_code=status.HTTP_404_NOT_FOUND,
                 detail="Conversation not found",
             )
+        await _message_index(message, current_user, conversation.prompt)
+
+        # Build usage
+        usage = UsageModel(
+            ai_model=OAI_GPT_MODEL,
+            conversation_id=conversation_id,
+            created_at=datetime.utcnow(),
+            id=uuid4(),
+            tokens=tokens_nb,
+            user_id=current_user.id,
+            prompt_name=conversation.prompt.name if conversation.prompt else None,
+        )
+        store.usage_set(usage)
     else:
         # Test prompt ID if provided
         if prompt_id and prompt_id not in AI_PROMPTS:
@@ -381,49 +329,49 @@
                 detail="Prompt ID not found",
             )
 
-        # Validate message length
-        tokens_nb = oai_tokens_nb(content, OAI_GPT_MODEL)
-        logger.debug(f"{tokens_nb} tokens in the conversation")
-        if tokens_nb > OAI_GPT_MAX_TOKENS:
-            logger.info(f"Message ({tokens_nb}) too long for conversation")
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail="Conversation history is too long",
-            )
+        tokens_nb = await _validate_message_length(content=content)
 
         # Build conversation
         conversation = StoredConversationModel(
-            created_at=datetime.now(),
+            created_at=datetime.utcnow(),
             id=uuid4(),
             prompt=AI_PROMPTS[prompt_id] if prompt_id else None,
             user_id=current_user.id,
         )
        store.conversation_set(conversation)
 
+        # Build usage
+        usage = UsageModel(
+            ai_model=OAI_GPT_MODEL,
+            conversation_id=conversation.id,
+            created_at=datetime.utcnow(),
+            id=uuid4(),
+            tokens=tokens_nb,
+            user_id=current_user.id,
+            prompt_name=conversation.prompt.name if conversation.prompt else None,
+        )
+        store.usage_set(usage)
+
        # Build message
         message = StoredMessageModel(
             content=content,
             conversation_id=conversation.id,
-            created_at=datetime.now(),
+            created_at=datetime.utcnow(),
             id=uuid4(),
             role=MessageRole.USER,
             secret=secret,
             token=uuid4(),
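+            # This token is also the Redis stream key the SSE endpoint uses to deliver the assistant's reply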
         )
         store.message_set(message)
-        index.message_index(message, current_user.id)
+        await _message_index(message, current_user, conversation.prompt)
 
     messages = store.message_list(conversation.id)
 
     if conversation.title is None:
-        asyncio.get_running_loop().run_in_executor(
-            None, lambda: guess_title(conversation, messages, current_user)
-        )
+        loop.create_task(_guess_title_background(conversation, messages, current_user))
 
     # Execute the message completion
-    asyncio.get_running_loop().run_in_executor(
-        None, lambda: completion_from_conversation(conversation, messages, current_user)
-    )
+    loop.create_task(_generate_completion_background(conversation, messages, current_user))
 
     return GetConversationModel(
         **conversation.dict(),
@@ -433,10 +381,10 @@
 
 @api.get("/message/{id}")
 async def message_get(id: UUID, token: UUID, req: Request) -> EventSourceResponse:
-    return EventSourceResponse(read_message_sse(req, token))
+    return EventSourceResponse(_read_message_sse(req, token))
 
 
-async def read_message_sse(req: Request, message_id: UUID):
+async def _read_message_sse(req: Request, message_id: UUID):
     def clean():
         logger.info(f"Cleared message cache (message_id={message_id})")
         stream.clean(message_id)
@@ -463,15 +411,10 @@ async def loop_func() -> bool:
 async def message_search(
     q: str, current_user: Annotated[UserModel, Depends(get_current_user)]
 ) -> SearchModel:
-    return index.message_search(q, current_user.id)
+    return await index.message_search(q, current_user.id)
 
 
-@retry(
-    reraise=True,
-    stop=stop_after_attempt(3),
-    wait=wait_random_exponential(multiplier=0.5, max=30),
-)
-def completion_from_conversation(
+async def _generate_completion_background(
     conversation: StoredConversationModel,
     messages: List[MessageModel],
     current_user: UserModel,
@@ -496,51 +439,79 @@
 
     logger.debug(f"Completion messages: {completion_messages}")
 
-    try:
-        # Use chat completion to get a more natural response and lower the usage cost
-        chunks = openai.ChatCompletion.create(
-            deployment_id=OAI_GPT_DEPLOY_ID,
-            messages=completion_messages,
-            model=OAI_GPT_MODEL,
-            presence_penalty=1,  # Increase the model's likelihood to talk about new topics
-            stream=True,
-            user=hash_token(current_user.id.bytes).hex,
-        )
-    except openai.error.AuthenticationError as e:
-        logger.exception(e)
-        return
-
     content_full = ""
-    for chunk in chunks:
-        content = chunk["choices"][0].get("delta", {}).get("content")
-        if content is not None:
-            logger.debug(f"Completion result: {content}")
-            # Add content to the redis stream cache_key
-            stream.push(content, last_message.token)
-            content_full += content
+    async for content in openai.completion_stream(completion_messages, current_user):
+        logger.debug(f"Completion result: {content}")
+        # Add content to the redis stream cache_key
+        stream.push(content, last_message.token)
+        content_full += content
 
     # First, store the updated conversation in Redis
     res_message = StoredMessageModel(
         content=content_full,
         conversation_id=conversation.id,
-        created_at=datetime.now(),
+        created_at=datetime.utcnow(),
         id=uuid4(),
         role=MessageRole.ASSISTANT,
         secret=last_message.secret,
     )
     store.message_set(res_message)
-    index.message_index(res_message, current_user.id)
+    await _message_index(res_message, current_user, conversation.prompt)
 
     # Then, send the end of stream message
     stream.push(STREAM_STOPWORD, last_message.token)
 
 
-@retry(
-    reraise=True,
-    stop=stop_after_attempt(3),
-    wait=wait_random_exponential(multiplier=0.5, max=30),
-)
-def guess_title(
+async def _message_index(message: StoredMessageModel, current_user: UserModel, prompt: Optional[StoredPromptModel]) -> None:
+    tokens_nb = oai_tokens_nb(message.content, OAI_ADA_MODEL)
+    if tokens_nb > OAI_ADA_MAX_TOKENS:
+        logger.info(f"Message ({tokens_nb}) too long for indexing")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Conversation history is too long",
+        )
+
+    usage = UsageModel(
+        ai_model=OAI_ADA_MODEL,
+        conversation_id=message.conversation_id,
+        created_at=datetime.utcnow(),
+        id=uuid4(),
+        tokens=tokens_nb,
+        user_id=current_user.id,
+        prompt_name=prompt.name if prompt else None,
+    )
+    store.usage_set(usage)
+    await index.message_index(message, current_user.id)
+
+
+async def _validate_message_length(
+    message: Optional[StoredMessageModel] = None,
+    content: Optional[str] = None,
+) -> int:
+    if content:
+        tokens_nb = oai_tokens_nb(content, OAI_GPT_MODEL)
+    elif message:
+        tokens_nb = oai_tokens_nb(
+            message.content
+            + "".join([m.content for m in store.message_list(message.conversation_id)]),
+            OAI_GPT_MODEL,
+        )
+    else:
+        raise ValueError('Either message or content must be provided to "_validate_message_length"')
+
+    logger.debug(f"{tokens_nb} tokens in the conversation")
+
+    if tokens_nb > OAI_GPT_MAX_TOKENS:
+        logger.info(f"Message ({tokens_nb}) too long for conversation")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Conversation history is too long",
+        )
+
+    return tokens_nb
+
+
+async def _guess_title_background(
     conversation: StoredConversationModel,
     messages: List[MessageModel],
     current_user: UserModel,
@@ -556,67 +527,8 @@
 
     logger.debug(f"Completion messages: {completion_messages}")
 
-    try:
-        # Use chat completion to get a more natural response and lower the usage cost
-        completion = openai.ChatCompletion.create(
-            deployment_id=OAI_GPT_DEPLOY_ID,
-            messages=completion_messages,
-            model=OAI_GPT_MODEL,
-            presence_penalty=1,  # Increase the model's likelihood to talk about new topics
-            user=hash_token(current_user.id.bytes).hex,
-        )
-        content = completion["choices"][0].message.content
-    except openai.error.AuthenticationError as e:
-        logger.exception(e)
-        return
+    content = await openai.completion(completion_messages, current_user)
 
     # Store the updated conversation in Redis
     conversation.title = content
     store.conversation_set(conversation)
-
-
-@retry(
-    reraise=True,
-    stop=stop_after_attempt(3),
-    wait=wait_random_exponential(multiplier=0.5, max=30),
-)
-async def is_moderated(prompt: str) -> bool:
-    logger.debug(f"Checking moderation for text: {prompt}")
-
-    if len(prompt) > ACS_MAX_LENGTH:
-        logger.info(f"Message ({len(prompt)}) too long for moderation")
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Message too long",
-        )
-
-    req = azure_cs.models.AnalyzeTextOptions(
-        text=prompt,
-        categories=[
-            azure_cs.models.TextCategory.HATE,
-            azure_cs.models.TextCategory.SELF_HARM,
-            azure_cs.models.TextCategory.SEXUAL,
-            azure_cs.models.TextCategory.VIOLENCE,
-        ],
-    )
-
-    try:
-        res = acs_client.analyze_text(req)
-    except azure_exceptions.ClientAuthenticationError as e:
-        logger.exception(e)
-        return False
-
-    is_moderated = any(
-        cat.severity >= ACS_SEVERITY_THRESHOLD
-        for cat in [
-            res.hate_result,
-            res.self_harm_result,
-            res.sexual_result,
-            res.violence_result,
-        ]
-    )
-    if is_moderated:
-        logger.info(f"Message is moderated: {prompt}")
-        logger.debug(f"Moderation result: {res}")
-
-    return is_moderated
diff --git a/src/conversation-api/models/conversation.py b/src/conversation-api/models/conversation.py
index 6ea5ae8e..09f3db4e 100644
--- a/src/conversation-api/models/conversation.py
+++ b/src/conversation-api/models/conversation.py
@@ -10,7 +10,7 @@ class BaseConversationModel(BaseModel):
     created_at: datetime
     id: UUID
     title: Optional[str] = None
-    user_id: UUID
+    user_id: UUID  # Partition key
 
 
 class StoredConversationModel(BaseConversationModel):
diff --git a/src/conversation-api/models/message.py b/src/conversation-api/models/message.py
index ebb73411..d3b123fc 100644
--- a/src/conversation-api/models/message.py
+++ b/src/conversation-api/models/message.py
@@ -21,7 +21,7 @@ class MessageModel(BaseModel):
 
 
 class StoredMessageModel(MessageModel):
-    conversation_id: UUID
+    conversation_id: UUID  # Partition key
 
 
 class IndexMessageModel(BaseModel):
diff --git a/src/conversation-api/models/usage.py b/src/conversation-api/models/usage.py
new file mode 100644
index 00000000..aec7bbb4
--- /dev/null
+++ b/src/conversation-api/models/usage.py
@@ -0,0 +1,14 @@
+from datetime import datetime
+from typing import Optional
+from pydantic import BaseModel
+from uuid import UUID
+
+
+class UsageModel(BaseModel):
+    ai_model: str
+    conversation_id: UUID
+    created_at: datetime
+    id: UUID
+    prompt_name: Optional[str] = None
+    tokens: int
+    user_id: UUID  # Partition key
diff --git a/src/conversation-api/persistence/cosmos.py b/src/conversation-api/persistence/cosmos.py
index a1eca0a9..9c42f019 100644
--- a/src/conversation-api/persistence/cosmos.py
+++ b/src/conversation-api/persistence/cosmos.py
@@ -3,13 +3,13 @@
 
 # Import misc
 from .istore import IStore
-from azure.cosmos import CosmosClient, PartitionKey, ConsistencyLevel
-from azure.cosmos.database import DatabaseProxy
-from azure.cosmos.exceptions import CosmosHttpResponseError, CosmosResourceExistsError
+from azure.cosmos import CosmosClient, ConsistencyLevel
+from azure.cosmos.exceptions import CosmosHttpResponseError
 from azure.identity import DefaultAzureCredential
 from datetime import datetime
 from models.conversation import StoredConversationModel, StoredConversationModel
 from models.message import MessageModel, IndexMessageModel, StoredMessageModel
+from models.usage import UsageModel
 from models.user import UserModel
 from typing import (Any, Dict, List, Union)
 from uuid import UUID
@@ -30,6 +30,7 @@
 conversation_client = database.get_container_client("conversation")
 message_client = database.get_container_client("message")
 user_client = database.get_container_client("user")
+usage_client = database.get_container_client("usage")
 logger.info(f'Connected to Cosmos DB at "{DB_URL}"')
 
 
@@ -67,7 +68,7 @@ def conversation_set(self, conversation: StoredConversationModel) -> None:
         conversation_client.upsert_item(body=self._sanitize_before_insert(conversation.dict()))
 
     def conversation_list(self, user_id: UUID) -> List[StoredConversationModel]:
-        query = f"SELECT * FROM c WHERE c.user_id = '{user_id}'"
+        query = f"SELECT * FROM c WHERE c.user_id = '{user_id}' ORDER BY c.created_at DESC"
         items = conversation_client.query_items(query=query, enable_cross_partition_query=True)
         return [StoredConversationModel(**item) for item in items]
 
@@ -102,14 +103,21 @@ def message_set(self, message: StoredMessageModel) -> None:
         })
 
     def message_list(self, conversation_id: UUID) -> List[MessageModel]:
-        query = f"SELECT * FROM c WHERE c.conversation_id = '{conversation_id}'"
+        query = f"SELECT * FROM c WHERE c.conversation_id = '{conversation_id}' ORDER BY c.created_at ASC"
         items = message_client.query_items(query=query, enable_cross_partition_query=True)
         return [MessageModel(**item) for item in items]
 
-    def _sanitize_before_insert(self, item: dict) -> Dict[str, Union[str, int, float, bool]]:
+    def usage_set(self, usage: UsageModel) -> None:
+        usage_client.upsert_item(body=self._sanitize_before_insert(usage.dict()))
+
+    def _sanitize_before_insert(self, item: dict) -> Dict[str, Any]:
         for key, value in item.items():
             if isinstance(value, UUID):
                 item[key] = str(value)
             elif isinstance(value, datetime):
                 item[key] = value.isoformat()
+            elif isinstance(value, dict):
+                item[key] = self._sanitize_before_insert(value)
+            elif isinstance(value, list):
+                item[key] = [self._sanitize_before_insert(i) if isinstance(i, dict) else i for i in value]
         return item
diff --git a/src/conversation-api/persistence/isearch.py b/src/conversation-api/persistence/isearch.py
index f399ae04..fdb69111 100644
--- a/src/conversation-api/persistence/isearch.py
+++ b/src/conversation-api/persistence/isearch.py
@@ -15,9 +15,9 @@ def __init__(self, store: IStore):
         self.store = store
 
     @abstractmethod
-    def message_search(self, query: str, user_id: UUID) -> SearchModel[MessageModel]:
+    async def message_search(self, query: str, user_id: UUID) -> SearchModel[MessageModel]:
         pass
 
     @abstractmethod
-    def message_index(self, message: StoredMessageModel, user_id: UUID) -> None:
+    async def message_index(self, message: StoredMessageModel, user_id: UUID) -> None:
         pass
diff --git a/src/conversation-api/persistence/istore.py b/src/conversation-api/persistence/istore.py
index 0433db74..5bd4a76b 100644
--- a/src/conversation-api/persistence/istore.py
+++ b/src/conversation-api/persistence/istore.py
@@ -3,6 +3,7 @@
 from models.conversation import GetConversationModel, StoredConversationModel
 from models.message import MessageModel, IndexMessageModel, StoredMessageModel
 from models.user import UserModel
+from models.usage import UsageModel
 from typing import List, Union
 from uuid import UUID
 
@@ -58,3 +59,7 @@ def message_set(self, message: StoredMessageModel) -> None:
     @abstractmethod
     def message_list(self, conversation_id: UUID) -> List[MessageModel]:
         pass
+
+    @abstractmethod
+    def usage_set(self, usage: UsageModel) -> None:
+        pass
diff --git a/src/conversation-api/persistence/qdrant.py b/src/conversation-api/persistence/qdrant.py
index a70c768c..77817b61 100644
--- a/src/conversation-api/persistence/qdrant.py
+++ b/src/conversation-api/persistence/qdrant.py
@@ -4,21 +4,20 @@
 # Import misc
 from .isearch import ISearch
 from .istore import IStore
+from ai.openai import OpenAI
 from datetime import datetime
 from models.message import MessageModel, IndexMessageModel, StoredMessageModel
 from models.search import SearchModel, SearchStatsModel, SearchAnswerModel
 from qdrant_client import QdrantClient
-from tenacity import retry, stop_after_attempt, wait_random_exponential
-from typing import List
 from uuid import UUID
 import asyncio
-import openai
 import qdrant_client.http.models as qmodels
 import textwrap
 import time
 
 logger = build_logger(__name__)
+openai = OpenAI()
 QD_COLLECTION = "messages"
 QD_DIMENSION = 1536
 QD_HOST = get_config("qd", "host", str, required=True)
@@ -27,21 +26,12 @@
 client = QdrantClient(host=QD_HOST, port=6333)
 logger.info(f'Connected to Qdrant at "{QD_HOST}:{QD_PORT}"')
 
-OAI_ADA_DEPLOY_ID = get_config("openai", "ada_deploy_id", str, required=True)
-OAI_ADA_MAX_TOKENS = get_config("openai", "ada_max_tokens", int, required=True)
-OAI_ADA_MODEL = get_config(
-    "openai", "ada_model", str, default="text-embedding-ada-002", required=True
-)
-logger.info(
-    f'Using OpenAI ADA model "{OAI_ADA_MODEL}" ({OAI_ADA_DEPLOY_ID}) with {OAI_ADA_MAX_TOKENS} tokens max'
-)
-
 
 class QdrantSearch(ISearch):
     def __init__(self, store: IStore):
         super().__init__(store)
-        self._loop = asyncio.new_event_loop()
+        self._loop = asyncio.get_running_loop()
 
         # Ensure collection exists
         try:
@@ -55,14 +45,14 @@
             ),
         )
 
-    def message_search(self, q: str, user_id: UUID) -> SearchModel[MessageModel]:
+    async def message_search(self, q: str, user_id: UUID) -> SearchModel[MessageModel]:
         logger.debug(f"Searching for: {q}")
 
         start = time.monotonic()
-        vector = self._vector_from_text(
+        vector = await openai.vector_from_text(
            textwrap.dedent(
                 f"""
-                Today, we are the {datetime.now()}. {q.capitalize()}
+                Today is {datetime.utcnow()}. {q.capitalize()}
             """
             ),
             user_id,
@@ -103,20 +93,18 @@
             stats=SearchStatsModel(total=total, time=time.monotonic() - start),
         )
 
-    def message_index(
+    async def message_index(
         self, message: StoredMessageModel, user_id: UUID
     ) -> None:
         logger.debug(f"Indexing message: {message.id}")
-        self._loop.run_in_executor(
-            None, lambda: self._index_worker(message, user_id)
-        )
+        self._loop.create_task(self._index_background(message, user_id))
 
-    def _index_worker(
+    async def _index_background(
         self, message: StoredMessageModel, user_id: UUID
     ) -> None:
         logger.debug(f"Starting indexing worker for message: {message.id}")
-        vector = self._vector_from_text(message.content, user_id)
+        vector = await openai.vector_from_text(message.content, user_id)
         index = IndexMessageModel(
             conversation_id=message.conversation_id,
             id=message.id,
@@ -131,23 +119,3 @@
                 vectors=[vector],
             ),
         )
-
-    @retry(
-        reraise=True,
-        stop=stop_after_attempt(3),
-        wait=wait_random_exponential(multiplier=0.5, max=30),
-    )
-    def _vector_from_text(self, prompt: str, user_id: UUID) -> List[float]:
-        logger.debug(f"Getting vector for text: {prompt}")
-        try:
-            res = openai.Embedding.create(
-                deployment_id=OAI_ADA_DEPLOY_ID,
-                input=prompt,
-                model=OAI_ADA_MODEL,
-                user=user_id.hex,
-            )
-        except openai.error.AuthenticationError as e:
-            logger.exception(e)
-            return []
-
-        return res.data[0].embedding
diff --git a/src/conversation-api/persistence/redis.py b/src/conversation-api/persistence/redis.py
index 4f2e0c74..98c33e1c 100644
--- a/src/conversation-api/persistence/redis.py
+++ b/src/conversation-api/persistence/redis.py
@@ -7,6 +7,7 @@
 from models.conversation import StoredConversationModel, StoredConversationModel
 from models.message import MessageModel, IndexMessageModel, StoredMessageModel
 from models.user import UserModel
+from models.usage import UsageModel
 from redis import Redis
 from typing import (Any, AsyncGenerator, Callable, Awaitable, List, Literal, Optional, Union)
 from uuid import UUID
@@ -18,11 +19,12 @@
 
 # Configuration
 CONVERSATION_PREFIX = "conversation"
-MESSAGE_PREFIX = "message"
 DB_HOST = get_config("redis", "host", str, required=True)
 DB_PORT = 6379
+MESSAGE_PREFIX = "message"
 STREAM_PREFIX = "stream"
 STREAM_STOPWORD = "STOP"
+USAGE_PREFIX = "usage"
 USER_PREFIX = "user"
 
 # Redis client
@@ -129,6 +131,13 @@
         messages.sort(key=lambda x: x.created_at)
         return messages
 
+    def usage_set(self, usage: UsageModel) -> None:
+        # Include the usage ID in the key, so a new record does not overwrite the previous one
+        client.set(self._usage_cache_key(usage), usage.json())
+
+    def _usage_cache_key(self, usage: UsageModel) -> str:
+        return f"{USAGE_PREFIX}:{usage.user_id.hex}:{usage.id.hex}"
+
     def _conversation_cache_key(
         self, user_id: UUID, conversation_id: Optional[UUID] = None
     ) -> str:
@@ -188,7 +196,8 @@ async def get(
             if message_loop:
                 yield message_loop
 
-            await asyncio.sleep(0.25)
+            # Up to 8 messages per second: enough for a good user experience, without hammering the event loop
+            await asyncio.sleep(0.125)
 
         # Send the end of stream message
         yield STREAM_STOPWORD
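
For reference, the usage records introduced above can be aggregated per user and per AI model. A minimal sketch (not part of the diff itself), assuming the Cosmos DB endpoint and database shown in the README, the "usage" container partitioned on /user_id, and an illustrative helper name tokens_by_model:

    # Aggregate total tokens per AI model for one user, from the "usage" container
    from azure.cosmos import CosmosClient
    from azure.identity import DefaultAzureCredential

    client = CosmosClient(
        "https://private-gpt.documents.azure.com:443", DefaultAzureCredential()
    )
    usage_client = client.get_database_client("private-gpt").get_container_client("usage")

    def tokens_by_model(user_id: str) -> dict:
        # One usage document is written per OpenAI call (GPT completion or ADA embedding)
        query = (
            "SELECT c.ai_model, SUM(c.tokens) AS total FROM c "
            f"WHERE c.user_id = '{user_id}' GROUP BY c.ai_model"
        )
        items = usage_client.query_items(query=query, enable_cross_partition_query=True)
        return {item["ai_model"]: item["total"] for item in items}

Because user_id is the partition key of the "usage" container, such a query is served from a single partition.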