speed up local embedding
jonaskahn committed Sep 22, 2024
1 parent 3c909b5, commit a8d5f0a
Showing 6 changed files with 45 additions and 16 deletions.
engine/engine/database/models.py (2 changes: 1 addition & 1 deletion)

@@ -59,7 +59,7 @@ class Chat(Model):
     question = TextField(null=False)
     refined_question = TextField(null=False)
     answer = TextField(null=False)
-    context = TextField(null=False, default=""),
+    relevant_docs = TextField(null=False, default="")
     prompt = TextField(null=False, default="")
     provider = TextField(null=False)
 
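Worth noting: the removed line ended with a stray trailing comma, so the class attribute context was actually bound to a one-element tuple rather than a TextField; the rename to relevant_docs also fixes that latent bug.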
engine/engine/services/ai_service.py (25 changes: 22 additions & 3 deletions)

@@ -2,6 +2,8 @@
 import random
 import tempfile
 from collections import Counter
+from concurrent.futures import as_completed
+from concurrent.futures.thread import ThreadPoolExecutor
 from uuid import uuid4
 
 import anthropic
@@ -286,11 +288,25 @@ def embed_document_with_mistral(text: str, max_tokens=8000) -> tuple[list[str],
                        text in texts]
 
     @staticmethod
-    def embed_document_with_local(text: str, max_tokens=16000) -> tuple[list[str], list[list[float]]]:
+    def embed_document_with_local(text: str, max_tokens=512) -> tuple[list[str], list[list[float]]]:
         texts = AiService.__chunk_text(text, max_tokens)
+        embedding_texts = []
+        embedding_vectors = []
+
+        with ThreadPoolExecutor(max_workers=3) as executor:
+            futures = [executor.submit(AiService.__internal_local_embed_text, text) for text in texts]
+
+            for future in as_completed(futures):
+                t1, e1 = future.result()
+                embedding_texts.append(t1)
+                embedding_vectors.append(e1)
+
+        return embedding_texts, embedding_vectors
+
+    @staticmethod
+    def __internal_local_embed_text(text: str) -> tuple[str, list[float]]:
         encoder = AiService.__get_local_embedding_encoder()
-        return texts, [encoder.encode([text], normalize_embeddings=True, convert_to_numpy=True).tolist()[0] for text in
-                       texts]
+        return text, encoder.encode([text], normalize_embeddings=True, convert_to_numpy=True).tolist()[0]
 
     @staticmethod
     def store_embeddings(table: str, ids: list[str], texts: list[str], embeddings: list[list[float]]):
@@ -333,6 +349,9 @@ def query_embeddings(table: str, queries: list[list[list[float]]], thresholds: l
         Note:
             The method ensures that duplicate documents are not returned, even if they match multiple queries or thresholds.
         """
+        if not queries:
+            return 0, []
+
         if thresholds is None:
             thresholds = [env.QUERY_SIMILAR_THRESHOLD]
         collection = chromadb_client.get_or_create_collection(table)
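Note on the change above: the speedup comes from fanning chunk encoding out over a small thread pool instead of encoding every chunk in one serial pass. Below is a minimal, self-contained sketch of the same pattern, assuming a hypothetical embed_chunk that stands in for the real local encoder:

# Sketch of the fan-out pattern the commit adopts; embed_chunk is a
# hypothetical stand-in for the local embedding model.
from concurrent.futures import as_completed
from concurrent.futures.thread import ThreadPoolExecutor


def embed_chunk(chunk: str) -> tuple[str, list[float]]:
    # Placeholder embedding; a real implementation would call the encoder here.
    return chunk, [float(len(chunk))]


def embed_all(chunks: list[str]) -> tuple[list[str], list[list[float]]]:
    texts: list[str] = []
    vectors: list[list[float]] = []
    with ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(embed_chunk, chunk) for chunk in chunks]
        for future in as_completed(futures):
            text, vector = future.result()  # results arrive in completion order
            texts.append(text)
            vectors.append(vector)
    return texts, vectors


if __name__ == "__main__":
    print(embed_all(["first chunk", "second chunk", "third chunk"]))

One consequence worth flagging: as_completed yields futures in completion order, so each text stays paired with its own vector, but the returned lists no longer follow the original chunk order. Callers that need positional order would iterate the futures list in submission order (or use executor.map) instead.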
engine/engine/services/chat_service.py (10 changes: 5 additions & 5 deletions)

@@ -209,13 +209,13 @@ async def __ask_with_rag(question: str, video: Video, chats: list[Chat], provide
         previous_questions = "\n".join([chat.question for chat in previous_chats])
         context_document = ChatService.__get_relevant_doc(provider=provider, model=model, video=video, question=question, previous_questions=previous_questions)
         logger.debug(f"Relevant docs: {context_document}")
-        context = ASKING_PROMPT_WITH_RAG.format(**{"context": context_document}) if context_document else None
+        prompt_context = ASKING_PROMPT_WITH_RAG.format(**{"context": context_document}) if context_document else None
 
-        if not context and env.RAG_AUTO_SWITCH in ["on", "yes", "enabled"]:
+        if not prompt_context and env.RAG_AUTO_SWITCH in ["on", "yes", "enabled"]:
             logger.debug("RAG is required, but none relevant information found, auto switch")
             return await ChatService.__ask_without_rag(question=question, video=video, chats=chats, provider=provider, model=model)
 
-        awareness_context = context if context else "No video information related, just answer me in your ability"
+        awareness_context = prompt_context if prompt_context else "No video information related, just answer me in your ability"
         match provider:
             case "gemini":
                 result = await ChatService.__ask_gemini_with_rag(model=model, question=question, context=awareness_context, chats=chats)
@@ -235,8 +235,8 @@ async def __ask_with_rag(question: str, video: Video, chats: list[Chat], provide
             question=question,
             refined_question="Not Implemented Yet",
             answer=result,
-            context=context_document if context else "No context doc found",
-            prompt=awareness_context if awareness_context else "No prompt found",
+            relevant_docs=context_document if context_document else "No context doc found",
+            prompt=prompt_context if prompt_context else "No prompt found",
             provider=provider
         )
         chat.save()
engine/engine/services/video_service.py (20 changes: 15 additions & 5 deletions)

@@ -13,7 +13,7 @@
 from engine.database.models import Video, VideoChapter
 from engine.database.specs import sqlite_client, chromadb_client
 from engine.services.ai_service import AiService
-from engine.supports import constants
+from engine.supports import constants, env
 from engine.supports.errors import VideoError, AiError
 from engine.supports.prompts import SUMMARY_PROMPT
 
@@ -85,11 +85,17 @@ async def __internal_analysis(video):
         VideoService.__prepare_video_transcript(video, video_chapters)
         logger.debug(f"[{trace_id}] finish prepare video chapters")
 
-        logger.debug(f"[{trace_id}] start embedding video transcript")
-        video.total_parts = await VideoService.__analysis_chapters(video_chapters, video.embedding_provider)
+        if video.transcript_tokens > env.TOKEN_CONTEXT_THRESHOLD:
+            logger.debug(f"[{trace_id}] start embedding video transcript")
+            video.total_parts = await VideoService.__analysis_chapters(video_chapters, video.embedding_provider)
+            logger.debug(f"[{trace_id}] finish embedding video transcript")
+        else:
+            logger.debug("explicitly skip to analysis video")
+            video.total_parts = len(video_chapters)
+
         video.analysis_state = constants.ANALYSIS_STAGE_COMPLETED
         VideoService.save(video, video_chapters)
-        logger.debug(f"[{trace_id}] finish embedding video transcript")
+
         logger.debug(f"finish analysis video: {video.title}")
     except Exception as e:
         VideoService.__update_analysis_content_state(video, constants.ANALYSIS_STAGE_INITIAL)
@@ -168,7 +174,7 @@ def __update_analysis_summary_state(video: Video, state: int):

     @staticmethod
     async def __analysis_chapters(video_chapters: list[VideoChapter], provider: str) -> int:
-        with ThreadPoolExecutor(max_workers=len(video_chapters)) as executor:
+        with ThreadPoolExecutor(max_workers=5) as executor:
             if provider == "gemini":
                 futures = [executor.submit(VideoService.__analysis_video_with_gemini, chapter) for chapter in
                            video_chapters]
@@ -272,6 +278,10 @@ async def analysis_summary_video(vid: int, model: str, provider: str):
             raise VideoError("video is not found")
         if video.analysis_summary_state in [constants.ANALYSIS_STAGE_COMPLETED, constants.ANALYSIS_STAGE_PROCESSING]:
             return
+        if video.transcript_tokens <= env.TOKEN_CONTEXT_THRESHOLD:
+            logger.debug("explicitly skip to analysis video summary")
+            VideoService.__update_analysis_summary_state(video, constants.ANALYSIS_STAGE_COMPLETED)
+            return
         logger.debug("start analysis summary video")
         VideoService.__update_analysis_summary_state(video, constants.ANALYSIS_STAGE_PROCESSING)
         try:
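The threshold gate added above means embedding and summary analysis now run only for transcripts larger than TOKEN_CONTEXT_THRESHOLD, presumably because shorter transcripts fit into the model's context directly. A rough sketch of that decision, with a simplified stand-in for the real Video model:

# Sketch of the new gate; the dataclass is a hypothetical stand-in for the
# real peewee Video model.
import os
from dataclasses import dataclass

TOKEN_CONTEXT_THRESHOLD = int(os.getenv("AT_TOKEN_CONTEXT_THRESHOLD", 8192))


@dataclass
class Video:
    transcript_tokens: int
    chapter_count: int


def needs_embedding(video: Video) -> bool:
    # Only long transcripts are chunked and embedded; short ones skip the
    # pipeline and report one part per chapter, as in __internal_analysis.
    return video.transcript_tokens > TOKEN_CONTEXT_THRESHOLD


if __name__ == "__main__":
    print(needs_embedding(Video(transcript_tokens=3000, chapter_count=4)))    # False
    print(needs_embedding(Video(transcript_tokens=20000, chapter_count=12)))  # True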
engine/engine/supports/env.py (2 changes: 1 addition & 1 deletion)

@@ -13,7 +13,7 @@
 AUDIO_CHUNK_RECOGNIZE_THRESHOLD: int = int(os.getenv("AT_AUDIO_CHUNK_RECOGNIZE_THRESHOLD", 120))
 AUDIO_CHUNK_CHAPTER_DURATION: int = int(os.getenv("AT_AUDIO_CHUNK_CHAPTER_DURATION", 600))
 QUERY_SIMILAR_THRESHOLD: float = float(os.getenv("AT_QUERY_SIMILAR_THRESHOLD", 0.4))
-TOKEN_CONTEXT_THRESHOLD: int = int(os.getenv("AT_TOKEN_CONTEXT_THRESHOLD", 2048))
+TOKEN_CONTEXT_THRESHOLD: int = int(os.getenv("AT_TOKEN_CONTEXT_THRESHOLD", 8192))
 AUDIO_ENHANCE_ENABLED: str = os.getenv("AT_AUDIO_ENHANCE_ENABLED", "off")
 RAG_QUERY_IMPLEMENTATION: str = os.getenv("AT_RAG_QUERY_IMPLEMENTATION", "multiquery")
 RAG_AUTO_SWITCH: str = os.getenv("AT_RAG_AUTO_SWITCH", "on")
engine/engine/supports/prompts.py (2 changes: 1 addition & 1 deletion)

@@ -52,7 +52,7 @@
 No yapping!!!
 """
 
-ASKING_PROMPT_WITH_RAG = """Here is related video information for question you can reference:
+ASKING_PROMPT_WITH_RAG = """Here is related video information can be used for question that you can reference:
 {context}
 """
 
