From 2fe490b165b40ab74494d360720a6e5ae80b1f63 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Thu, 19 Dec 2024 15:13:13 +0100 Subject: [PATCH 01/19] Initial setup of faq ingestion and deletion --- app/common/PipelineEnum.py | 1 + app/domain/data/faq_dto.py | 13 ++ .../ingestion/deletionPipelineExecutionDto.py | 8 + .../ingestion_pipeline_execution_dto.py | 9 ++ app/pipeline/faq_ingestion_pipeline.py | 151 ++++++++++++++++++ app/vector_database/faq_schema.py | 96 +++++++++++ app/web/routers/pipelines.py | 10 ++ app/web/routers/webhooks.py | 92 ++++++++++- .../status/faq_ingestion_status_callback.py | 45 ++++++ docker/pyris-dev.yml | 2 + 10 files changed, 425 insertions(+), 2 deletions(-) create mode 100644 app/domain/data/faq_dto.py create mode 100644 app/pipeline/faq_ingestion_pipeline.py create mode 100644 app/vector_database/faq_schema.py create mode 100644 app/web/status/faq_ingestion_status_callback.py diff --git a/app/common/PipelineEnum.py b/app/common/PipelineEnum.py index fc439a65..6d951f6e 100644 --- a/app/common/PipelineEnum.py +++ b/app/common/PipelineEnum.py @@ -14,4 +14,5 @@ class PipelineEnum(str, Enum): IRIS_SUMMARY_PIPELINE = "IRIS_SUMMARY_PIPELINE" IRIS_LECTURE_RETRIEVAL_PIPELINE = "IRIS_LECTURE_RETRIEVAL_PIPELINE" IRIS_LECTURE_INGESTION = "IRIS_LECTURE_INGESTION" + IRIS_FAQ_INGESTION = "IRIS_FAQ_INGESTION" NOT_SET = "NOT_SET" diff --git a/app/domain/data/faq_dto.py b/app/domain/data/faq_dto.py new file mode 100644 index 00000000..c4ba550e --- /dev/null +++ b/app/domain/data/faq_dto.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, Field + + +class FaqDTO(BaseModel): + faq_id: int = Field(alias="faqId") + course_id: int = Field(alias="courseId") + questionTitle: str = Field(alias="questionTitle") + questionAnswer: str = Field(alias="questionAnswer"), + course_name: str = Field(default="", alias="courseName") + course_description: str = Field(default="", alias="courseDescription") + + + diff --git 
a/app/domain/ingestion/deletionPipelineExecutionDto.py b/app/domain/ingestion/deletionPipelineExecutionDto.py index 1cec7cdd..84445616 100644 --- a/app/domain/ingestion/deletionPipelineExecutionDto.py +++ b/app/domain/ingestion/deletionPipelineExecutionDto.py @@ -3,6 +3,7 @@ from pydantic import Field from app.domain import PipelineExecutionDTO, PipelineExecutionSettingsDTO +from app.domain.data.faq_dto import FaqDTO from app.domain.data.lecture_unit_dto import LectureUnitDTO from app.domain.status.stage_dto import StageDTO @@ -13,3 +14,10 @@ class LecturesDeletionExecutionDto(PipelineExecutionDTO): initial_stages: Optional[List[StageDTO]] = Field( default=None, alias="initialStages" ) + +class FaqDeletionExecutionDto(PipelineExecutionDTO): + faq: FaqDTO = Field(..., alias="pyrisFaqWebhookDTO") + settings: Optional[PipelineExecutionSettingsDTO] + initial_stages: Optional[List[StageDTO]] = Field( + default=None, alias="initialStages" + ) \ No newline at end of file diff --git a/app/domain/ingestion/ingestion_pipeline_execution_dto.py b/app/domain/ingestion/ingestion_pipeline_execution_dto.py index 12f3205f..213be158 100644 --- a/app/domain/ingestion/ingestion_pipeline_execution_dto.py +++ b/app/domain/ingestion/ingestion_pipeline_execution_dto.py @@ -3,6 +3,7 @@ from pydantic import Field from app.domain import PipelineExecutionDTO, PipelineExecutionSettingsDTO +from app.domain.data.faq_dto import FaqDTO from app.domain.data.lecture_unit_dto import LectureUnitDTO from app.domain.status.stage_dto import StageDTO @@ -13,3 +14,11 @@ class IngestionPipelineExecutionDto(PipelineExecutionDTO): initial_stages: Optional[List[StageDTO]] = Field( default=None, alias="initialStages" ) + +class FaqIngestionPipelineExecutionDto(PipelineExecutionDTO): + faq: FaqDTO = Field(..., alias="pyrisFaqWebhookDTO") + settings: Optional[PipelineExecutionSettingsDTO] + initial_stages: Optional[List[StageDTO]] = Field( + default=None, alias="initialStages" + ) + diff --git 
a/app/pipeline/faq_ingestion_pipeline.py b/app/pipeline/faq_ingestion_pipeline.py new file mode 100644 index 00000000..7b265264 --- /dev/null +++ b/app/pipeline/faq_ingestion_pipeline.py @@ -0,0 +1,151 @@ +import logging +import threading +from asyncio.log import logger +from typing import Optional, List, Dict +from langchain_core.output_parsers import StrOutputParser +from openai import OpenAI +from weaviate import WeaviateClient +from weaviate.classes.query import Filter +from . import Pipeline +from ..domain.data.faq_dto import FaqDTO + +from app.domain.ingestion.ingestion_pipeline_execution_dto import ( + FaqIngestionPipelineExecutionDto, +) +from ..llm.langchain import IrisLangchainChatModel +from ..vector_database.faq_schema import FaqSchema, init_faq_schema +from ..ingestion.abstract_ingestion import AbstractIngestion +from ..llm import ( + BasicRequestHandler, + CompletionArguments, + CapabilityRequestHandler, + RequirementList, +) +from ..web.status.faq_ingestion_status_callback import FaqIngestionStatus + +batch_update_lock = threading.Lock() + +class FaqIngestionPipeline(AbstractIngestion, Pipeline): + + def __init__( + self, + client: WeaviateClient, + dto: Optional[FaqIngestionPipelineExecutionDto], + callback: FaqIngestionStatus, + ): + super().__init__() + self.collection = init_faq_schema(client) + self.dto = dto + self.llm_vision = BasicRequestHandler("azure-gpt-4-omni") + self.llm_chat = BasicRequestHandler("azure-gpt-35-turbo") + self.llm_embedding = BasicRequestHandler("embedding-small") + self.callback = callback + request_handler = CapabilityRequestHandler( + requirements=RequirementList( + gpt_version_equivalent=3.5, + context_length=16385, + privacy_compliance=True, + ) + ) + completion_args = CompletionArguments(temperature=0.2, max_tokens=2000) + self.llm = IrisLangchainChatModel( + request_handler=request_handler, completion_args=completion_args + ) + self.pipeline = self.llm | StrOutputParser() + self.tokens = [] + + def __call__(self) 
-> bool: + try: + self.callback.in_progress("Deleting old faq from database...") + self.delete_faq( + self.dto.faq.faq_id, + self.dto.faq.course_id, + ) + self.callback.done("Old faq removed") + self.callback.in_progress("Ingesting faqs into database...") + self.batch_update(self.dto.faq) + self.callback.done("Faq Ingestion Finished", tokens=self.tokens) + logger.info( + f"Faq ingestion pipeline finished Successfully for faq: {self.dto.faq.faq_id}" + ) + return True + except Exception as e: + logger.error(f"Error updating faq: {e}") + self.callback.error( + f"Failed to faq into the database: {e}", + exception=e, + tokens=self.tokens, + ) + return False + + def batch_update(self, faq: FaqDTO): + """ + Batch update the faq into the database + This method is thread-safe and can only be executed by one thread at a time. + Weaviate limitation. + """ + global batch_update_lock + with batch_update_lock: + with self.collection.batch.rate_limit(requests_per_minute=600) as batch: + try: + # this needs to be working, otherwise its working just fine + # embed_chunk = self.llm_embedding.embed( + # f"{faq.questionTitle} : {faq.questionAnswer}" + # ) + + embed_chunk = [0.125, -1.179,-14.4,7.6,7.97] + faq_dict = faq.model_dump() + batch.add_object(properties=faq_dict, vector=embed_chunk) + for item in self.collection.iterator(): + logging.info(item) + + + except Exception as e: + logger.error(f"Error updating faq: {e}") + self.callback.error( + f"Failed to ingest faqs into the database: {e}", + exception=e, + tokens=self.tokens, + ) + + def delete_old_faqs( + self, faqs: list[FaqDTO] + ): + """ + Delete the faq from the database + """ + try: + for faq in faqs: + if self.delete_faq(faq.faq_id, faq.course_id): + logger.info("Faq deleted successfully") + else: + logger.error("Failed to delete faq") + self.callback.done("Old faqs removed") + except Exception as e: + logger.error(f"Error deleting faqs: {e}") + self.callback.error("Error while removing old faqs") + return False + + def 
delete_faq(self, faq_id, course_id): + """ + Delete the faq from the database + """ + try: + self.collection.data.delete_many( + where=Filter.by_property(FaqSchema.FAQ_ID.value).equal(faq_id) + & Filter.by_property(FaqSchema.COURSE_ID.value).equal(course_id) + + ) + logging.info(f"successfully deleted faq with id {faq_id}") + return True + except Exception as e: + logger.error(f"Error deleting faq: {e}", exc_info=True) + return False + + + def chunk_data(self, path: str) -> List[Dict[str, str]]: + """ + Faqs are so small, they do not need to be chunked into smaller parts + """ + return + diff --git a/app/vector_database/faq_schema.py b/app/vector_database/faq_schema.py new file mode 100644 index 00000000..a81f3e40 --- /dev/null +++ b/app/vector_database/faq_schema.py @@ -0,0 +1,96 @@ +import logging +from enum import Enum + +from weaviate.classes.config import Property +from weaviate import WeaviateClient +from weaviate.collections import Collection +from weaviate.collections.classes.config import Configure, VectorDistances, DataType + + +class FaqSchema(Enum): + """ + Schema for the faqs + """ + + COLLECTION_NAME = "Faqs" + COURSE_NAME = "course_name" + COURSE_DESCRIPTION = "course_description" + COURSE_LANGUAGE = "course_language" + COURSE_ID = "course_id" + FAQ_ID = "faq_id" + QUESTION_TITLE = "question_title" + QUESTION_Answer = "question_answer" + + +def init_faq_schema(client: WeaviateClient) -> Collection: + """ + Initialize the schema for the faqs + """ + if client.collections.exists(FaqSchema.COLLECTION_NAME.value): + collection = client.collections.get(FaqSchema.COLLECTION_NAME.value) + properties = collection.config.get(simple=True).properties + + # Check and add 'course_language' property if missing + if not any( + property.name == FaqSchema.COURSE_LANGUAGE.value + for property in collection.config.get(simple=False).properties + ): + collection.config.add_property( + Property( + name=FaqSchema.COURSE_LANGUAGE.value, + description="The language of the 
COURSE", + data_type=DataType.TEXT, + index_searchable=False, + ) + ) + return collection + + return client.collections.create( + name=FaqSchema.COLLECTION_NAME.value, + vectorizer_config=Configure.Vectorizer.none(), + vector_index_config=Configure.VectorIndex.hnsw( + distance_metric=VectorDistances.COSINE + ), + properties=[ + Property( + name=FaqSchema.COURSE_ID.value, + description="The ID of the course", + data_type=DataType.INT, + index_searchable=False, + ), + Property( + name=FaqSchema.COURSE_NAME.value, + description="The name of the course", + data_type=DataType.TEXT, + index_searchable=False, + ), + Property( + name=FaqSchema.COURSE_DESCRIPTION.value, + description="The description of the COURSE", + data_type=DataType.TEXT, + index_searchable=False, + ), + Property( + name=FaqSchema.COURSE_LANGUAGE.value, + description="The language of the COURSE", + data_type=DataType.TEXT, + index_searchable=False, + ), + Property( + name=FaqSchema.FAQ_ID.value, + description="The ID of the Faq", + data_type=DataType.INT, + index_searchable=False, + ), + Property( + name=FaqSchema.QUESTION_TITLE.value, + description="The title of the faq", + data_type=DataType.TEXT, + ), + Property( + name=FaqSchema.QUESTION_Answer.value, + description="The answer of the faq", + data_type=DataType.TEXT, + ), + ], + ) diff --git a/app/web/routers/pipelines.py b/app/web/routers/pipelines.py index fbc3c9f3..5f96b7aa 100644 --- a/app/web/routers/pipelines.py +++ b/app/web/routers/pipelines.py @@ -294,5 +294,15 @@ def get_pipeline(feature: str): description="Default lecture chat variant.", ) ] + + case "FAQ_INGESTION": + return [ + FeatureDTO( + id="default", + name="Default Variant", + description="Default faq ingestion variant.", + ) + ] + case _: return Response(status_code=status.HTTP_400_BAD_REQUEST) diff --git a/app/web/routers/webhooks.py b/app/web/routers/webhooks.py index 739a9bbb..d11db586 100644 --- a/app/web/routers/webhooks.py +++ b/app/web/routers/webhooks.py @@ -1,3 +1,4 @@ 
+import logging import traceback from asyncio.log import logger from threading import Thread, Semaphore @@ -7,13 +8,15 @@ from fastapi import APIRouter, status, Depends from app.dependencies import TokenValidator from app.domain.ingestion.ingestion_pipeline_execution_dto import ( - IngestionPipelineExecutionDto, + IngestionPipelineExecutionDto, FaqIngestionPipelineExecutionDto, ) +from ..status.faq_ingestion_status_callback import FaqIngestionStatus from ..status.ingestion_status_callback import IngestionStatusCallback from ..status.lecture_deletion_status_callback import LecturesDeletionStatusCallback from ...domain.ingestion.deletionPipelineExecutionDto import ( - LecturesDeletionExecutionDto, + LecturesDeletionExecutionDto, FaqDeletionExecutionDto, ) +from ...pipeline.faq_ingestion_pipeline import FaqIngestionPipeline from ...pipeline.lecture_ingestion_pipeline import LectureIngestionPipeline from ...vector_database.database import VectorDatabase @@ -40,6 +43,7 @@ def run_lecture_update_pipeline_worker(dto: IngestionPipelineExecutionDto): client=client, dto=dto, callback=callback ) pipeline() + except Exception as e: logger.error(f"Error Ingestion pipeline: {e}") logger.error(traceback.format_exc()) @@ -67,6 +71,60 @@ def run_lecture_deletion_pipeline_worker(dto: LecturesDeletionExecutionDto): logger.error(traceback.format_exc()) +def run_faq_update_pipeline_worker(dto: FaqIngestionPipelineExecutionDto): + """ + Run the exercise chat pipeline in a separate thread + """ + with semaphore: + try: + callback = FaqIngestionStatus( + run_id=dto.settings.authentication_token, + base_url=dto.settings.artemis_base_url, + initial_stages=dto.initial_stages, + faq_id=dto.faq.faq_id, + ) + db = VectorDatabase() + client = db.get_client() + pipeline = FaqIngestionPipeline( + client=client, dto=dto, callback=callback + ) + pipeline() + + + except Exception as e: + logger.error(f"Error Faq Ingestion pipeline: {e}") + logger.error(traceback.format_exc()) + capture_exception(e) + 
finally: + semaphore.release() + + +def run_faq_delete_pipeline_worker(dto: IngestionPipelineExecutionDto): + """ + Run the faq deletion in a separate thread + """ + with semaphore: + try: + callback = FaqIngestionStatus( + run_id=dto.settings.authentication_token, + base_url=dto.settings.artemis_base_url, + initial_stages=dto.initial_stages, + faq_id=dto.faq.faq_id, + ) + db = VectorDatabase() + client = db.get_client() + # Hier würd dann die Methode zum entfernen aus der Datenbank kommen + pipeline = FaqIngestionPipeline(client=client, dto=None, callback=callback) + pipeline.delete_faq(dto.faq.faq_id, dto.faq.course_id) + + + except Exception as e: + logger.error(f"Error Ingestion pipeline: {e}") + logger.error(traceback.format_exc()) + capture_exception(e) + finally: + semaphore.release() + @router.post( "/lectures/fullIngestion", status_code=status.HTTP_202_ACCEPTED, @@ -91,3 +149,33 @@ def lecture_deletion_webhook(dto: LecturesDeletionExecutionDto): """ thread = Thread(target=run_lecture_deletion_pipeline_worker, args=(dto,)) thread.start() + +@router.post( + "/faqs", + status_code=status.HTTP_202_ACCEPTED, + dependencies=[Depends(TokenValidator())], +) +def faq_ingestion_webhook(dto: FaqIngestionPipelineExecutionDto): + """ + Webhook endpoint to trigger the faq ingestion pipeline + """ + thread = Thread(target=run_faq_update_pipeline_worker, args=(dto,)) + thread.start() + return + +@router.post( + "/faqs/delete", + status_code=status.HTTP_202_ACCEPTED, + dependencies=[Depends(TokenValidator())], + ) +def faq_deletion_webhook(dto: FaqDeletionExecutionDto): + """ + Webhook endpoint to trigger the faq deletion pipeline + """ + logging.info(dto) + logging.info("Starting faq deletion") + thread = Thread(target=run_faq_delete_pipeline_worker, args=(dto,)) + thread.start() + return + + diff --git a/app/web/status/faq_ingestion_status_callback.py b/app/web/status/faq_ingestion_status_callback.py new file mode 100644 index 00000000..ad15c9c1 --- /dev/null +++ 
b/app/web/status/faq_ingestion_status_callback.py @@ -0,0 +1,45 @@ +from typing import List + +from .status_update import StatusCallback +from ...domain.ingestion.ingestion_status_update_dto import IngestionStatusUpdateDTO +from ...domain.status.stage_state_dto import StageStateEnum +from ...domain.status.stage_dto import StageDTO +import logging + +logger = logging.getLogger(__name__) + + +class FaqIngestionStatus(StatusCallback): + """ + Callback class for updating the status of a Faq ingestion Pipeline run. + """ + + def __init__( + self, + run_id: str, + base_url: str, + initial_stages: List[StageDTO] = None, + faq_id: int = None, + ): + url = f"{base_url}/api/public/pyris/webhooks/ingestion/runs/{run_id}/status" + + current_stage_index = len(initial_stages) if initial_stages else 0 + stages = initial_stages or [] + stages += [ + StageDTO( + weight=10, state=StageStateEnum.NOT_STARTED, name="Old faq removal" + ), + StageDTO( + weight=30, + state=StageStateEnum.NOT_STARTED, + name="Faq Interpretation", + ), + StageDTO( + weight=60, + state=StageStateEnum.NOT_STARTED, + name="Faq ingestion", + ), + ] + status = IngestionStatusUpdateDTO(stages=stages, id=faq_id) + stage = stages[current_stage_index] + super().__init__(url, run_id, status, stage, current_stage_index) diff --git a/docker/pyris-dev.yml b/docker/pyris-dev.yml index 7d1a956d..e50c17f8 100644 --- a/docker/pyris-dev.yml +++ b/docker/pyris-dev.yml @@ -14,6 +14,8 @@ services: - ../llm_config.local.yml:/config/llm_config.yml:ro networks: - pyris + ports: + - 8000:8000 weaviate: extends: From 19c0d811ec8c654e4511892ba2fed08ca3fe343e Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Wed, 25 Dec 2024 10:09:32 +0100 Subject: [PATCH 02/19] Initial setup of faq retrieval --- app/common/PipelineEnum.py | 1 + app/domain/data/faq_dto.py | 4 +- app/pipeline/chat/course_chat_pipeline.py | 58 ++- app/pipeline/faq_ingestion_pipeline.py | 16 +- app/pipeline/prompts/faq_retrieval_prompts.py | 10 + 
app/retrieval/faq_retrieval.py | 390 ++++++++++++++++++ app/vector_database/database.py | 3 + app/web/routers/webhooks.py | 1 - .../status/faq_ingestion_status_callback.py | 3 +- 9 files changed, 470 insertions(+), 16 deletions(-) create mode 100644 app/pipeline/prompts/faq_retrieval_prompts.py create mode 100644 app/retrieval/faq_retrieval.py diff --git a/app/common/PipelineEnum.py b/app/common/PipelineEnum.py index 6d951f6e..bcab28c7 100644 --- a/app/common/PipelineEnum.py +++ b/app/common/PipelineEnum.py @@ -15,4 +15,5 @@ class PipelineEnum(str, Enum): IRIS_LECTURE_RETRIEVAL_PIPELINE = "IRIS_LECTURE_RETRIEVAL_PIPELINE" IRIS_LECTURE_INGESTION = "IRIS_LECTURE_INGESTION" IRIS_FAQ_INGESTION = "IRIS_FAQ_INGESTION" + IRIS_FAQ_RETRIEVAL_PIPELINE = "IRIS_FAQ_RETRIEVAL_PIPELINE" NOT_SET = "NOT_SET" diff --git a/app/domain/data/faq_dto.py b/app/domain/data/faq_dto.py index c4ba550e..e68716af 100644 --- a/app/domain/data/faq_dto.py +++ b/app/domain/data/faq_dto.py @@ -4,8 +4,8 @@ class FaqDTO(BaseModel): faq_id: int = Field(alias="faqId") course_id: int = Field(alias="courseId") - questionTitle: str = Field(alias="questionTitle") - questionAnswer: str = Field(alias="questionAnswer"), + question_title: str = Field(alias="questionTitle") + question_answer: str = Field(alias="questionAnswer") course_name: str = Field(default="", alias="courseName") course_description: str = Field(default="", alias="courseDescription") diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index d934222f..feda706e 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -42,8 +42,10 @@ ) from ...domain import CourseChatPipelineExecutionDTO from app.common.PipelineEnum import PipelineEnum +from ...retrieval.faq_retrieval import FaqRetrieval from ...retrieval.lecture_retrieval import LectureRetrieval from ...vector_database.database import VectorDatabase +from ...vector_database.faq_schema import 
FaqSchema from ...vector_database.lecture_schema import LectureSchema from ...web.status.status_update import ( CourseChatStatusCallback, @@ -105,7 +107,8 @@ def __init__( self.callback = callback self.db = VectorDatabase() - self.retriever = LectureRetrieval(self.db.client) + self.lecture_retriever = LectureRetrieval(self.db.client) + self.faq_retriever = FaqRetrieval(self.db.client) self.suggestion_pipeline = InteractionSuggestionPipeline(variant="course") self.citation_pipeline = CitationPipeline() @@ -273,7 +276,7 @@ def lecture_content_retrieval() -> str: Only use this once. """ self.callback.in_progress("Retrieving lecture content ...") - self.retrieved_paragraphs = self.retriever( + self.retrieved_paragraphs = self.lecture_retriever( chat_history=history, student_query=query.contents[0].text_content, result_limit=5, @@ -293,6 +296,35 @@ def lecture_content_retrieval() -> str: result += lct return result + def faq_content_retrieval() -> str: + """ + Retrieve content from indexed faqs. + This will run a RAG retrieval based on the chat history on the indexed faqs and return the + most relevant paragraphs. + Use this if you think it can be useful to answer the student's question with a faq, or if the student explicitly asks + an organizational question about the course. + Only use this once. 
+ """ + self.callback.in_progress("Retrieving faq content ...") + self.retrieved_paragraphs = self.faq_retriever( + chat_history=history, + student_query=query.contents[0].text_content, + result_limit=5, + course_name=dto.course.name, + course_id=dto.course.id, + base_url=dto.settings.artemis_base_url, + ) + + result = "" + for faq in self.retrieved_faqs: + res = "FAQ Title: {}, FAQ Answer: {}, ID: {}".format( + faq.get(FaqSchema.QUESTION_TITLE.value), + faq.get(FaqSchema.QUESTION_Answer.value), + faq.get(FaqSchema.FAQ_ID.value), + ) + result += res + return result + if dto.user.id % 3 < 2: iris_initial_system_prompt = tell_iris_initial_system_prompt begin_agent_prompt = tell_begin_agent_prompt @@ -391,6 +423,9 @@ def lecture_content_retrieval() -> str: if self.should_allow_lecture_tool(dto.course.id): tool_list.append(lecture_content_retrieval) + if self.should_allow_faq_tool(dto.course.id): + tool_list.append(faq_content_retrieval) + tools = generate_structured_tools_from_functions(tool_list) # No idea why we need this extra contrary to exercise chat agent in this case, but solves the issue. 
params.update({"tools": tools}) @@ -465,6 +500,25 @@ def should_allow_lecture_tool(self, course_id: int) -> bool: return len(result.objects) > 0 return False + def should_allow_faq_tool(self, course_id: int) -> bool: + """ + Checks if there are indexed faqs for the given course + + :param course_id: The course ID + :return: True if there are indexed lectures for the course, False otherwise + """ + if course_id: + # Fetch the first object that matches the course ID with the language property + result = self.db.faqs.query.fetch_objects( + filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal( + course_id + ), + limit=1, + return_properties=[FaqSchema.COURSE_NAME.value], + ) + return len(result.objects) > 0 + return False + def datetime_to_string(dt: Optional[datetime]) -> str: if dt is None: diff --git a/app/pipeline/faq_ingestion_pipeline.py b/app/pipeline/faq_ingestion_pipeline.py index 7b265264..28d9fc5c 100644 --- a/app/pipeline/faq_ingestion_pipeline.py +++ b/app/pipeline/faq_ingestion_pipeline.py @@ -34,15 +34,14 @@ def __init__( callback: FaqIngestionStatus, ): super().__init__() + self.client = client self.collection = init_faq_schema(client) self.dto = dto - self.llm_vision = BasicRequestHandler("azure-gpt-4-omni") - self.llm_chat = BasicRequestHandler("azure-gpt-35-turbo") self.llm_embedding = BasicRequestHandler("embedding-small") self.callback = callback request_handler = CapabilityRequestHandler( requirements=RequirementList( - gpt_version_equivalent=3.5, + gpt_version_equivalent=4.25, context_length=16385, privacy_compliance=True, ) @@ -88,14 +87,13 @@ def batch_update(self, faq: FaqDTO): with batch_update_lock: with self.collection.batch.rate_limit(requests_per_minute=600) as batch: try: - # this needs to be working, otherwise its working just fine - # embed_chunk = self.llm_embedding.embed( - # f"{faq.questionTitle} : {faq.questionAnswer}" - # ) - - embed_chunk = [0.125, -1.179,-14.4,7.6,7.97] + embed_chunk = self.llm_embedding.embed( + 
f"{faq.question_title} : {faq.question_answer}" + ) faq_dict = faq.model_dump() + batch.add_object(properties=faq_dict, vector=embed_chunk) + for item in self.collection.iterator(): logging.info(item) diff --git a/app/pipeline/prompts/faq_retrieval_prompts.py b/app/pipeline/prompts/faq_retrieval_prompts.py new file mode 100644 index 00000000..1e5148d2 --- /dev/null +++ b/app/pipeline/prompts/faq_retrieval_prompts.py @@ -0,0 +1,10 @@ +faq_retriever_initial_prompt = """ +You write good and performant vector database queries, in particular for Weaviate, +from chat histories between an AI tutor and a student. +The query should be designed to retrieve context information from indexed faqs so the AI tutor +can use the context information to give a better answer. Apply accepted norms when querying vector databases. +Query the database so it returns answers for the latest student query. +A good vector database query is formulated in natural language, just like a student would ask a question. +It is not an instruction to the database, but a question to the database. +The chat history between the AI tutor and the student is provided to you in the next messages. 
+""" diff --git a/app/retrieval/faq_retrieval.py b/app/retrieval/faq_retrieval.py new file mode 100644 index 00000000..39512580 --- /dev/null +++ b/app/retrieval/faq_retrieval.py @@ -0,0 +1,390 @@ +import logging +from asyncio.log import logger +from typing import List + +from langsmith import traceable +from weaviate import WeaviateClient +from weaviate.classes.query import Filter + +from app.common.token_usage_dto import TokenUsageDTO +from app.common.PipelineEnum import PipelineEnum +from .lecture_retrieval import _add_last_four_messages_to_prompt +from ..common.pyris_message import PyrisMessage +from ..llm.langchain import IrisLangchainChatModel +from ..pipeline import Pipeline + +from app.llm import ( + BasicRequestHandler, + CompletionArguments, + CapabilityRequestHandler, + RequirementList, +) +from app.pipeline.shared.reranker_pipeline import RerankerPipeline +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ( + ChatPromptTemplate, + SystemMessagePromptTemplate, +) + +from ..pipeline.prompts.faq_retrieval_prompts import faq_retriever_initial_prompt +from ..pipeline.prompts.lecture_retrieval_prompts import ( + assessment_prompt, + assessment_prompt_final, + rewrite_student_query_prompt, + write_hypothetical_answer_prompt, + rewrite_student_query_prompt_with_exercise_context, write_hypothetical_answer_with_exercise_context_prompt, +) +import concurrent.futures + +from ..vector_database.faq_schema import FaqSchema, init_faq_schema + + +def merge_retrieved_chunks( + basic_retrieved_faq_chunks, hyde_retrieved_faq_chunks +) -> List[dict]: + """ + Merge the retrieved chunks from the basic and hyde retrieval methods. This function ensures that for any + duplicate IDs, the properties from hyde_retrieved_faq_chunks will overwrite those from + basic_retrieved_faq_chunks. 
+ """ + merged_chunks = {} + for chunk in basic_retrieved_faq_chunks: + merged_chunks[chunk["id"]] = chunk["properties"] + + for chunk in hyde_retrieved_faq_chunks: + merged_chunks[chunk["id"]] = chunk["properties"] + + return [properties for uuid, properties in merged_chunks.items()] + + +class FaqRetrieval(Pipeline): + """ + Class for retrieving faq data from the database. + """ + + tokens: List[TokenUsageDTO] + + def __init__(self, client: WeaviateClient, **kwargs): + super().__init__(implementation_id="faq_retrieval_pipeline") + request_handler = CapabilityRequestHandler( + requirements=RequirementList( + gpt_version_equivalent=4.25, + context_length=16385, + privacy_compliance=True, + ) + ) + completion_args = CompletionArguments(temperature=0, max_tokens=2000) + self.llm = IrisLangchainChatModel( + request_handler=request_handler, completion_args=completion_args + ) + self.llm_embedding = BasicRequestHandler("embedding-small") + self.pipeline = self.llm | StrOutputParser() + self.collection = init_faq_schema(client) + self.reranker_pipeline = RerankerPipeline() + self.tokens = [] + + @traceable(name="Full Faq Retrieval") + def __call__( + self, + chat_history: list[PyrisMessage], + student_query: str, + result_limit: int, + course_name: str = None, + course_id: int = None, + problem_statement: str = None, + exercise_title: str = None, + base_url: str = None, + + ) -> List[dict]: + """ + Retrieve faq data from the database. 
+ """ + course_language = self.fetch_course_language(course_id) + + response, response_hyde = self.run_parallel_rewrite_tasks( + chat_history=chat_history, + student_query=student_query, + result_limit=result_limit, + course_language=course_language, + course_name=course_name, + course_id=course_id + ) + + logging.info(f"FAQ retrival response, {response}") + + basic_retrieved_faqs: list[dict[str, dict]] = [ + {"id": obj.uuid.int, "properties": obj.properties} + for obj in response.objects + ] + hyde_retrieved_faqs: list[dict[str, dict]] = [ + {"id": obj.uuid.int, "properties": obj.properties} + for obj in response_hyde.objects + ] + merged_chunks = merge_retrieved_chunks( + basic_retrieved_faqs, hyde_retrieved_faqs + ) + if len(merged_chunks) != 0: + selected_chunks_index = self.reranker_pipeline( + paragraphs=merged_chunks, query=student_query, chat_history=chat_history + ) + if selected_chunks_index: + return [merged_chunks[int(i)] for i in selected_chunks_index] + return [] + + @traceable(name="Basic Faq Retrieval") + def basic_faq_retrieval( + self, + chat_history: list[PyrisMessage], + student_query: str, + result_limit: int, + course_name: str = None, + course_id: int = None, + ) -> list[dict[str, dict]]: + """ + Basic retrieval for pipelines that need performance and fast answers. 
+ """ + if not self.assess_question(chat_history, student_query): + return [] + + rewritten_query = self.rewrite_student_query( + chat_history, student_query, "course_language", course_name + ) + response = self.search_in_db( + query=rewritten_query, + hybrid_factor=0.9, + result_limit=result_limit, + course_id=course_id, + ) + + basic_retrieved_faq_chunks: list[dict[str, dict]] = [ + {"id": obj.uuid.int, "properties": obj.properties} + for obj in response.objects + ] + return basic_retrieved_faq_chunks + + @traceable(name="Retrieval: Question Assessment") + def assess_question( + self, chat_history: list[PyrisMessage], student_query: str + ) -> bool: + prompt = ChatPromptTemplate.from_messages( + [ + ("system", assessment_prompt), + ] + ) + prompt = _add_last_four_messages_to_prompt(prompt, chat_history) + prompt += ChatPromptTemplate.from_messages( + [ + ("user", student_query), + ] + ) + prompt += ChatPromptTemplate.from_messages( + [ + ("system", assessment_prompt_final), + ] + ) + + try: + response = (prompt | self.pipeline).invoke({}) + logger.info(f"Response from assessment pipeline: {response}") + return response == "YES" + except Exception as e: + raise e + + @traceable(name="Retrieval: Rewrite Student Query") + def rewrite_student_query( + self, + chat_history: list[PyrisMessage], + student_query: str, + course_language: str, + course_name: str, + ) -> str: + """ + Rewrite the student query. 
+ """ + prompt = ChatPromptTemplate.from_messages( + [ + ("system", faq_retriever_initial_prompt), + ] + ) + prompt = _add_last_four_messages_to_prompt(prompt, chat_history) + prompt += SystemMessagePromptTemplate.from_template( + rewrite_student_query_prompt + ) + prompt_val = prompt.format_messages( + course_language=course_language, + course_name=course_name, + student_query=student_query, + ) + prompt = ChatPromptTemplate.from_messages(prompt_val) + try: + response = (prompt | self.pipeline).invoke({}) + token_usage = self.llm.tokens + token_usage.pipeline = PipelineEnum.IRIS_FAQ_RETRIEVAL_PIPELINE + self.tokens.append(self.llm.tokens) + logger.info(f"Response from exercise chat pipeline: {response}") + return response + except Exception as e: + raise e + + @traceable(name="Retrieval: Rewrite Elaborated Query") + def rewrite_elaborated_query( + self, + chat_history: list[PyrisMessage], + student_query: str, + course_language: str, + course_name: str, + ) -> str: + """ + Rewrite the student query to generate fitting faq content and embed it. + To extract more relevant content from the vector database. 
+ """ + prompt = ChatPromptTemplate.from_messages( + [ + ("system", write_hypothetical_answer_prompt), + ] + ) + prompt = _add_last_four_messages_to_prompt(prompt, chat_history) + prompt += ChatPromptTemplate.from_messages( + [ + ("user", student_query), + ] + ) + prompt_val = prompt.format_messages( + course_language=course_language, + course_name=course_name, + ) + prompt = ChatPromptTemplate.from_messages(prompt_val) + try: + response = (prompt | self.pipeline).invoke({}) + token_usage = self.llm.tokens + token_usage.pipeline = PipelineEnum.IRIS_FAQ_RETRIEVAL_PIPELINE + self.tokens.append(self.llm.tokens) + logger.info(f"Response from faq retrival pipeline: {response}") + return response + except Exception as e: + raise e + + + @traceable(name="Retrieval: Search in DB") + def search_in_db( + self, + query: str, + hybrid_factor: float, + result_limit: int, + course_id: int = None, + ): + """ + Search the database for the given query. + """ + logger.info(f"Searching in the database for query: {query}") + # Initialize filter to None by default + filter_weaviate = None + + # Check if course_id is provided + if course_id: + # Create a filter for course_id + filter_weaviate = Filter.by_property(FaqSchema.COURSE_ID.value).equal( + course_id + ) + + + vec = self.llm_embedding.embed(query) + return_value = self.collection.query.hybrid( + query=query, + alpha=hybrid_factor, + vector=vec, + return_properties=[ + FaqSchema.COURSE_ID.value, + FaqSchema.FAQ_ID.value, + FaqSchema.QUESTION_TITLE.value, + FaqSchema.QUESTION_Answer.value, + ], + limit=result_limit, + filters=filter_weaviate, + ) + + logger.info(f"Search in the database response: {return_value}") + + return return_value + + @traceable(name="Retrieval: Run Parallel Rewrite Tasks") + def run_parallel_rewrite_tasks( + self, + chat_history: list[PyrisMessage], + student_query: str, + result_limit: int, + course_language: str, + course_name: str = None, + course_id: int = None, + ): + + with 
concurrent.futures.ThreadPoolExecutor() as executor: + # Schedule the rewrite tasks to run in parallel + rewritten_query_future = executor.submit( + self.rewrite_student_query, + chat_history, + student_query, + course_language, + course_name, + ) + hypothetical_answer_query_future = executor.submit( + self.rewrite_elaborated_query, + chat_history, + student_query, + course_language, + course_name, + ) + + # Get the results once both tasks are complete + rewritten_query = rewritten_query_future.result() + hypothetical_answer_query = hypothetical_answer_query_future.result() + + # Execute the database search tasks + with concurrent.futures.ThreadPoolExecutor() as executor: + response_future = executor.submit( + self.search_in_db, + query=rewritten_query, + hybrid_factor=0.9, + result_limit=result_limit, + course_id=course_id, + ) + response_hyde_future = executor.submit( + self.search_in_db, + query=hypothetical_answer_query, + hybrid_factor=0.9, + result_limit=result_limit, + course_id=course_id, + ) + + # Get the results once both tasks are complete + response = response_future.result() + response_hyde = response_hyde_future.result() + + return response, response_hyde + + def fetch_course_language(self, course_id): + """ + Fetch the language of the course based on the course ID. + If no specific language is set, it defaults to English. 
+ """ + course_language = "english" + + if course_id: + # Fetch the first object that matches the course ID with the language property + result = self.collection.query.fetch_objects( + filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal( + course_id + ), + limit=1, # We only need one object to check and retrieve the language + return_properties=[FaqSchema.COURSE_LANGUAGE.value], + ) + + # Check if the result has objects and retrieve the language + if result.objects: + fetched_language = result.objects[0].properties.get( + FaqSchema.COURSE_LANGUAGE.value + ) + if fetched_language: + course_language = fetched_language + + return course_language diff --git a/app/vector_database/database.py b/app/vector_database/database.py index e4dedcd0..5c3ed7cf 100644 --- a/app/vector_database/database.py +++ b/app/vector_database/database.py @@ -1,5 +1,7 @@ import logging import weaviate + +from .faq_schema import init_faq_schema from .lecture_schema import init_lecture_schema from weaviate.classes.query import Filter from app.config import settings @@ -27,6 +29,7 @@ def __init__(self): logger.info("Weaviate client initialized") self.client = VectorDatabase._client_instance self.lectures = init_lecture_schema(self.client) + self.faqs = init_faq_schema(self.client) def delete_collection(self, collection_name): """ diff --git a/app/web/routers/webhooks.py b/app/web/routers/webhooks.py index d11db586..0c0f496d 100644 --- a/app/web/routers/webhooks.py +++ b/app/web/routers/webhooks.py @@ -172,7 +172,6 @@ def faq_deletion_webhook(dto: FaqDeletionExecutionDto): """ Webhook endpoint to trigger the faq deletion pipeline """ - logging.info(dto) logging.info("Starting faq deletion") thread = Thread(target=run_faq_delete_pipeline_worker, args=(dto,)) thread.start() diff --git a/app/web/status/faq_ingestion_status_callback.py b/app/web/status/faq_ingestion_status_callback.py index ad15c9c1..e7472f77 100644 --- a/app/web/status/faq_ingestion_status_callback.py +++ 
b/app/web/status/faq_ingestion_status_callback.py @@ -21,8 +21,7 @@ def __init__( initial_stages: List[StageDTO] = None, faq_id: int = None, ): - url = f"{base_url}/api/public/pyris/webhooks/ingestion/runs/{run_id}/status" - + url = f"{base_url}/api/public/pyris/webhooks/ingestion/faqs/runs/{run_id}/status" current_stage_index = len(initial_stages) if initial_stages else 0 stages = initial_stages or [] stages += [ From 933c045ef29a6f54767f75ee9771ffdd46046d63 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Thu, 26 Dec 2024 17:10:31 +0100 Subject: [PATCH 03/19] Further faq retrieval. --- app/pipeline/chat/course_chat_pipeline.py | 23 ++++++++------ app/pipeline/shared/citation_pipeline.py | 38 ++++++++++++++++++++--- app/retrieval/faq_retrieval.py | 19 +++++++----- app/vector_database/faq_schema.py | 4 +-- 4 files changed, 60 insertions(+), 24 deletions(-) diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index feda706e..8adce6cd 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -20,7 +20,7 @@ InteractionSuggestionPipeline, ) from .lecture_chat_pipeline import LectureChatPipeline -from ..shared.citation_pipeline import CitationPipeline +from ..shared.citation_pipeline import CitationPipeline, InformationType from ..shared.utils import generate_structured_tools_from_functions from ...common.message_converters import convert_iris_message_to_langchain_message from ...common.pyris_message import PyrisMessage @@ -82,6 +82,7 @@ class CourseChatPipeline(Pipeline): variant: str event: str | None retrieved_paragraphs: List[dict] = None + retrieved_faqs: List[dict] = None def __init__( self, @@ -299,14 +300,14 @@ def lecture_content_retrieval() -> str: def faq_content_retrieval() -> str: """ Retrieve content from indexed faqs. 
+ Use this if you think the question is a common question, or it can be useful to answer the student's question + with a faq, or if the student explicitly asks an organizational question about the course This will run a RAG retrieval based on the chat history on the indexed faqs and return the - most relevant paragraphs. - Use this if you think it can be useful to answer the student's question with a faq, or if the student explicitly asks - an organizational question about the course. + most relevant faqs. Only use this once. """ self.callback.in_progress("Retrieving faq content ...") - self.retrieved_paragraphs = self.faq_retriever( + self.retrieved_faqs = self.faq_retriever( chat_history=history, student_query=query.contents[0].text_content, result_limit=5, @@ -317,12 +318,12 @@ def faq_content_retrieval() -> str: result = "" for faq in self.retrieved_faqs: - res = "FAQ Title: {}, FAQ Answer: {}, ID: {}".format( + res = ("FAQ Question: {}, FAQ Answer: {}").format( faq.get(FaqSchema.QUESTION_TITLE.value), - faq.get(FaqSchema.QUESTION_Answer.value), - faq.get(FaqSchema.FAQ_ID.value), + faq.get(FaqSchema.QUESTION_ANSWER.value), ) result += res + logging.info(f"result from faq retrieval: {result}") return result if dto.user.id % 3 < 2: @@ -446,9 +447,13 @@ def faq_content_retrieval() -> str: if self.retrieved_paragraphs: self.callback.in_progress("Augmenting response ...") - out = self.citation_pipeline(self.retrieved_paragraphs, out) + out = self.citation_pipeline(self.retrieved_paragraphs, out, InformationType.PARAGRAPHS) self.tokens.extend(self.citation_pipeline.tokens) + if self.retrieved_faqs: + self.callback.in_progress("Augmenting response ...") + out = self.citation_pipeline(self.retrieved_faqs, out, InformationType.FAQS) + self.callback.done("Response created", final_result=out, tokens=self.tokens) # try: diff --git a/app/pipeline/shared/citation_pipeline.py b/app/pipeline/shared/citation_pipeline.py index 22e13360..bf23450f 100644 --- 
a/app/pipeline/shared/citation_pipeline.py +++ b/app/pipeline/shared/citation_pipeline.py @@ -1,5 +1,7 @@ +import logging import os from asyncio.log import logger +from enum import Enum from typing import List, Union from langchain_core.output_parsers import StrOutputParser @@ -10,9 +12,14 @@ from app.common.PipelineEnum import PipelineEnum from app.llm.langchain import IrisLangchainChatModel from app.pipeline import Pipeline +from app.vector_database.faq_schema import FaqSchema from app.vector_database.lecture_schema import LectureSchema +class InformationType(str, Enum): + PARAGRAPHS = "PARAGRAPHS" + FAQS = "FAQS" + class CitationPipeline(Pipeline): """A generic reranker pipeline that can be used to rerank a list of documents based on a question""" @@ -47,7 +54,8 @@ def __repr__(self): def __str__(self): return f"{self.__class__.__name__}(llm={self.llm})" - def create_formatted_string(self, paragraphs): + + def create_formatted_lecture_string(self, paragraphs): """ Create a formatted string from the data """ @@ -64,19 +72,40 @@ def create_formatted_string(self, paragraphs): return formatted_string.replace("{", "{{").replace("}", "}}") + def create_formatted_faq_string(self, faqs): + """ + Create a formatted string from the data + """ + formatted_string = "" + for i, faq in enumerate(faqs): + faq = "Question: {}, Answer: {}".format( + faq.get(FaqSchema.QUESTION_ANSWER.value), + faq.get(FaqSchema.QUESTION_TITLE.value), + ) + formatted_string += faq + + return formatted_string.replace("{", "{{").replace("}", "}}") + + def __call__( self, - paragraphs: Union[List[dict], List[str]], + information: Union[List[dict], List[str]], answer: str, + information_type: InformationType = InformationType.PARAGRAPHS, **kwargs, ) -> str: """ Runs the pipeline - :param paragraphs: List of paragraphs which can be list of dicts or list of strings + :param information: List of information which can be list of dicts or list of strings. 
Used to augment the response :param query: The query :return: Selected file content """ - paras = self.create_formatted_string(paragraphs) + paras = "" + + if information_type == InformationType.FAQS: + paras = self.create_formatted_faq_string(information) + if information_type == InformationType.PARAGRAPHS: + paras = self.create_formatted_lecture_string(information) try: self.default_prompt = PromptTemplate( @@ -89,7 +118,6 @@ def __call__( self._append_tokens(self.llm.tokens, PipelineEnum.IRIS_CITATION_PIPELINE) if response == "!NONE!": return answer - print(response) return response except Exception as e: logger.error("citation pipeline failed", e) diff --git a/app/retrieval/faq_retrieval.py b/app/retrieval/faq_retrieval.py index 39512580..64af0e39 100644 --- a/app/retrieval/faq_retrieval.py +++ b/app/retrieval/faq_retrieval.py @@ -123,13 +123,16 @@ def __call__( merged_chunks = merge_retrieved_chunks( basic_retrieved_faqs, hyde_retrieved_faqs ) - if len(merged_chunks) != 0: - selected_chunks_index = self.reranker_pipeline( - paragraphs=merged_chunks, query=student_query, chat_history=chat_history - ) - if selected_chunks_index: - return [merged_chunks[int(i)] for i in selected_chunks_index] - return [] + + logging.info(f"merged_chunks, {merged_chunks}") + return merged_chunks + #if len(merged_chunks) != 0: + # selected_chunks_index = self.reranker_pipeline( + # paragraphs=merged_chunks, query=student_query, chat_history=chat_history + # ) + # if selected_chunks_index: + # return [merged_chunks[int(i)] for i in selected_chunks_index] + #return [] @traceable(name="Basic Faq Retrieval") def basic_faq_retrieval( @@ -297,7 +300,7 @@ def search_in_db( FaqSchema.COURSE_ID.value, FaqSchema.FAQ_ID.value, FaqSchema.QUESTION_TITLE.value, - FaqSchema.QUESTION_Answer.value, + FaqSchema.QUESTION_ANSWER.value, ], limit=result_limit, filters=filter_weaviate, diff --git a/app/vector_database/faq_schema.py b/app/vector_database/faq_schema.py index a81f3e40..325e8f97 100644 --- 
a/app/vector_database/faq_schema.py +++ b/app/vector_database/faq_schema.py @@ -19,7 +19,7 @@ class FaqSchema(Enum): COURSE_ID = "course_id" FAQ_ID = "faq_id" QUESTION_TITLE = "question_title" - QUESTION_Answer = "question_answer" + QUESTION_ANSWER = "question_answer" def init_faq_schema(client: WeaviateClient) -> Collection: @@ -88,7 +88,7 @@ def init_faq_schema(client: WeaviateClient) -> Collection: data_type=DataType.TEXT, ), Property( - name=FaqSchema.QUESTION_Answer.value, + name=FaqSchema.QUESTION_ANSWER.value, description="The answer of the faq", data_type=DataType.TEXT, ), From 54f5167b7780180610660a0dc5b0b25f56136e91 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Fri, 27 Dec 2024 16:06:35 +0100 Subject: [PATCH 04/19] Working FAQ retrieval --- app/pipeline/chat/course_chat_pipeline.py | 15 ++++++++++----- app/pipeline/prompts/faq_retrieval_prompts.py | 9 +++++++++ app/retrieval/faq_retrieval.py | 12 +++++++----- app/web/status/faq_ingestion_status_callback.py | 1 + 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index 8adce6cd..e16eb2b2 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -10,7 +10,7 @@ from langchain_core.messages import SystemMessage from langchain_core.output_parsers import JsonOutputParser from langchain_core.prompts import ( - ChatPromptTemplate, + ChatPromptTemplate, SystemMessagePromptTemplate, ) from langchain_core.runnables import Runnable from langsmith import traceable @@ -301,9 +301,10 @@ def faq_content_retrieval() -> str: """ Retrieve content from indexed faqs. Use this if you think the question is a common question, or it can be useful to answer the student's question - with a faq, or if the student explicitly asks an organizational question about the course + with a faq, or if the student explicitly asks an organizational question about the course. 
This will run a RAG retrieval based on the chat history on the indexed faqs and return the - most relevant faqs. + most relevant faq. Every FAQ has the following format: [FAQ ID: ..., FAQ Question: ..., FAQ Answer: ...] + You will get the question and answer of the faq. Make sure to answer the question with the answer of your selected FAQ. Only use this once. """ self.callback.in_progress("Retrieving faq content ...") @@ -318,7 +319,8 @@ def faq_content_retrieval() -> str: result = "" for faq in self.retrieved_faqs: - res = ("FAQ Question: {}, FAQ Answer: {}").format( + res = "[FAQ ID: {}, FAQ Question: {}, FAQ Answer: {}]".format( + faq.get(FaqSchema.FAQ_ID.value), faq.get(FaqSchema.QUESTION_TITLE.value), faq.get(FaqSchema.QUESTION_ANSWER.value), ) @@ -427,6 +429,8 @@ def faq_content_retrieval() -> str: if self.should_allow_faq_tool(dto.course.id): tool_list.append(faq_content_retrieval) + + tools = generate_structured_tools_from_functions(tool_list) # No idea why we need this extra contrary to exercise chat agent in this case, but solves the issue. params.update({"tools": tools}) @@ -453,7 +457,6 @@ def faq_content_retrieval() -> str: if self.retrieved_faqs: self.callback.in_progress("Augmenting response ...") out = self.citation_pipeline(self.retrieved_faqs, out, InformationType.FAQS) - self.callback.done("Response created", final_result=out, tokens=self.tokens) # try: @@ -530,3 +533,5 @@ def datetime_to_string(dt: Optional[datetime]) -> str: return "No date provided" else: return dt.strftime("%Y-%m-%d %H:%M:%S") + + diff --git a/app/pipeline/prompts/faq_retrieval_prompts.py b/app/pipeline/prompts/faq_retrieval_prompts.py index 1e5148d2..0b272f0a 100644 --- a/app/pipeline/prompts/faq_retrieval_prompts.py +++ b/app/pipeline/prompts/faq_retrieval_prompts.py @@ -8,3 +8,12 @@ It is not an instruction to the database, but a question to the database. The chat history between the AI tutor and the student is provided to you in the next messages. 
""" + +write_hypothetical_answer_prompt = """ +A student has sent a query in the context the course {course_name}. +The chat history between the AI tutor and the student is provided to you in the next messages. +Please provide a response in {course_language}. +You should create a response that looks like a faq answer. +Craft your response to closely reflect the style and content of typical university lecture materials. +Do not exceed 350 words. Add keywords and phrases that are relevant to student intent. +""" diff --git a/app/retrieval/faq_retrieval.py b/app/retrieval/faq_retrieval.py index 64af0e39..20af67e2 100644 --- a/app/retrieval/faq_retrieval.py +++ b/app/retrieval/faq_retrieval.py @@ -26,13 +26,11 @@ SystemMessagePromptTemplate, ) -from ..pipeline.prompts.faq_retrieval_prompts import faq_retriever_initial_prompt +from ..pipeline.prompts.faq_retrieval_prompts import faq_retriever_initial_prompt, write_hypothetical_answer_prompt from ..pipeline.prompts.lecture_retrieval_prompts import ( assessment_prompt, assessment_prompt_final, rewrite_student_query_prompt, - write_hypothetical_answer_prompt, - rewrite_student_query_prompt_with_exercise_context, write_hypothetical_answer_with_exercise_context_prompt, ) import concurrent.futures @@ -126,6 +124,7 @@ def __call__( logging.info(f"merged_chunks, {merged_chunks}") return merged_chunks + #if len(merged_chunks) != 0: # selected_chunks_index = self.reranker_pipeline( # paragraphs=merged_chunks, query=student_query, chat_history=chat_history @@ -224,7 +223,7 @@ def rewrite_student_query( token_usage = self.llm.tokens token_usage.pipeline = PipelineEnum.IRIS_FAQ_RETRIEVAL_PIPELINE self.tokens.append(self.llm.tokens) - logger.info(f"Response from exercise chat pipeline: {response}") + logger.info(f"Response from faq retrieval pipeline: {response}") return response except Exception as e: raise e @@ -257,6 +256,7 @@ def rewrite_elaborated_query( course_name=course_name, ) prompt = 
ChatPromptTemplate.from_messages(prompt_val) + logging.info(f"Prompt for elaborated query: {prompt}") try: response = (prompt | self.pipeline).invoke({}) token_usage = self.llm.tokens @@ -340,7 +340,9 @@ def run_parallel_rewrite_tasks( # Get the results once both tasks are complete rewritten_query = rewritten_query_future.result() + logging.info(f"Rewritten query: {rewritten_query}") hypothetical_answer_query = hypothetical_answer_query_future.result() + logging.info(f"Hypothetical answer query: {hypothetical_answer_query}") # Execute the database search tasks with concurrent.futures.ThreadPoolExecutor() as executor: @@ -390,4 +392,4 @@ def fetch_course_language(self, course_id): if fetched_language: course_language = fetched_language - return course_language + return course_language \ No newline at end of file diff --git a/app/web/status/faq_ingestion_status_callback.py b/app/web/status/faq_ingestion_status_callback.py index e7472f77..b60322fd 100644 --- a/app/web/status/faq_ingestion_status_callback.py +++ b/app/web/status/faq_ingestion_status_callback.py @@ -22,6 +22,7 @@ def __init__( faq_id: int = None, ): url = f"{base_url}/api/public/pyris/webhooks/ingestion/faqs/runs/{run_id}/status" + current_stage_index = len(initial_stages) if initial_stages else 0 stages = initial_stages or [] stages += [ From 3b12cd6ab3472dedf0bfd96d8e58d086706bf15c Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Mon, 6 Jan 2025 17:17:18 +0100 Subject: [PATCH 05/19] Removed logging --- app/pipeline/chat/course_chat_pipeline.py | 1 - app/pipeline/faq_ingestion_pipeline.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index e16eb2b2..89755e5a 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -325,7 +325,6 @@ def faq_content_retrieval() -> str: faq.get(FaqSchema.QUESTION_ANSWER.value), ) result += res - logging.info(f"result from faq 
retrieval: {result}") return result if dto.user.id % 3 < 2: diff --git a/app/pipeline/faq_ingestion_pipeline.py b/app/pipeline/faq_ingestion_pipeline.py index 28d9fc5c..074485cd 100644 --- a/app/pipeline/faq_ingestion_pipeline.py +++ b/app/pipeline/faq_ingestion_pipeline.py @@ -94,9 +94,6 @@ def batch_update(self, faq: FaqDTO): batch.add_object(properties=faq_dict, vector=embed_chunk) - for item in self.collection.iterator(): - logging.info(item) - except Exception as e: logger.error(f"Error updating faq: {e}") From 101747724dee7793bdfcb437b8902cbc532cead2 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Wed, 8 Jan 2025 15:14:47 +0100 Subject: [PATCH 06/19] Removed logging, added Links for FAQ answer, updated prompts --- app/pipeline/chat/course_chat_pipeline.py | 17 ++++++----- app/pipeline/prompts/faq_citation_prompt.txt | 32 ++++++++++++++++++++ app/pipeline/shared/citation_pipeline.py | 18 ++++++++--- app/retrieval/faq_retrieval.py | 15 +-------- 4 files changed, 55 insertions(+), 27 deletions(-) create mode 100644 app/pipeline/prompts/faq_citation_prompt.txt diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index 89755e5a..b53ffa8e 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -299,13 +299,14 @@ def lecture_content_retrieval() -> str: def faq_content_retrieval() -> str: """ - Retrieve content from indexed faqs. - Use this if you think the question is a common question, or it can be useful to answer the student's question - with a faq, or if the student explicitly asks an organizational question about the course. - This will run a RAG retrieval based on the chat history on the indexed faqs and return the - most relevant faq. Every FAQ has the following format: [FAQ ID: ..., FAQ Question: ..., FAQ Answer: ...] - You will get the question and answer of the faq. Make sure to answer the question with the answer of your selected FAQ. - Only use this once. 
+ Use this tool to retrieve information from indexed FAQs. + It is suitable when no other tool fits, you think it is a common question or the question is frequently asked, + or the question could be effectively answered by an FAQ. Also use this if the question is explicitly organizational and course-related. + An organizational question about the course might be "What is the course structure?" or "How do I enroll?" or exam related content like "When is the exam". + The tool performs a RAG retrieval based on the chat history to find the most relevant FAQs. Each FAQ follows this format: + FAQ ID, FAQ Question, FAQ Answer. + Respond to the query concisely and solely using the answer from the relevant FAQs. This tool should only be used once per query. + """ self.callback.in_progress("Retrieving faq content ...") self.retrieved_faqs = self.faq_retriever( @@ -455,7 +456,7 @@ def faq_content_retrieval() -> str: if self.retrieved_faqs: self.callback.in_progress("Augmenting response ...") - out = self.citation_pipeline(self.retrieved_faqs, out, InformationType.FAQS) + out = self.citation_pipeline(self.retrieved_faqs, out, InformationType.FAQS, base_url=dto.settings.artemis_base_url) self.callback.done("Response created", final_result=out, tokens=self.tokens) # try: diff --git a/app/pipeline/prompts/faq_citation_prompt.txt b/app/pipeline/prompts/faq_citation_prompt.txt new file mode 100644 index 00000000..681bb05c --- /dev/null +++ b/app/pipeline/prompts/faq_citation_prompt.txt @@ -0,0 +1,32 @@ +In the paragraphs below you are provided with an answer to a question. Underneath the answer you will find the faqs that the answer was based on. +Add citations of the faqs to the answer. Cite the faqs in brackets after the sentence where the information is used in the answer. +At the end of the answer, list each source with its corresponding number and provide the FAQ Question title, and a clickable link in this format: [1] "FAQ Question title". 
+Do not include the actual faqs, only the citations at the end. +Please do not use the FAQ ID as the citation number, instead, use the order of the citations in the answer. +Only include the citations of the faqs that are relevant to the answer. +If the answer actually does not contain any information from the faqs, please do not include any citations and return '!NONE!'. +But if the answer contains information from the paragraphs, ALWAYS include citations. + +Here is an example how to rewrite the answer with citations (ONLY ADD CITATION IF THE PROVIDED FAQS ARE RELEVANT TO THE ANSWER): +" +Lorem ipsum dolor sit amet, consectetur adipiscing elit [1]. Ded do eiusmod tempor incididunt ut labore et dolore magna aliqua [2]. + +[1] FAQ question title 1. +[2] FAQ question title 2. +" + +Note: If there is no link available, please do not include the link in the citation. For example, if citation 1 does not have a link, it should look like this: +[1] "FAQ question title" +but if citation 2 has a link, it should look like this: +[2] "FAQ question title" + +Here are the answer and the faqs: + +Answer without citations: +{Answer} + +Faqs with their FAQ ID, CourseId, FAQ Question title and FAQ Question Answer and the Link to the FAQ: +{Paragraphs} + +Answer with citations (ensure empty line between the message and the citations): +If the answer actually does not contain any information from the paragraphs, please do not include any citations and return '!NONE!'. 
diff --git a/app/pipeline/shared/citation_pipeline.py b/app/pipeline/shared/citation_pipeline.py index bf23450f..411f6b57 100644 --- a/app/pipeline/shared/citation_pipeline.py +++ b/app/pipeline/shared/citation_pipeline.py @@ -44,7 +44,10 @@ def __init__(self): dirname = os.path.dirname(__file__) prompt_file_path = os.path.join(dirname, "..", "prompts", "citation_prompt.txt") with open(prompt_file_path, "r") as file: - self.prompt_str = file.read() + self.lecture_prompt_str = file.read() + prompt_file_path = os.path.join(dirname, "..", "prompts", "faq_citation_prompt.txt") + with open(prompt_file_path, "r") as file: + self.faq_prompt_str = file.read() self.pipeline = self.llm | StrOutputParser() self.tokens = [] @@ -72,15 +75,18 @@ def create_formatted_lecture_string(self, paragraphs): return formatted_string.replace("{", "{{").replace("}", "}}") - def create_formatted_faq_string(self, faqs): + def create_formatted_faq_string(self, faqs, base_url): """ Create a formatted string from the data """ formatted_string = "" for i, faq in enumerate(faqs): - faq = "Question: {}, Answer: {}".format( - faq.get(FaqSchema.QUESTION_ANSWER.value), + faq = "FAQ ID {}, CourseId {} , FAQ Question title {} and FAQ Question Answer {} and FAQ link {}".format( + faq.get(FaqSchema.FAQ_ID.value), + faq.get(FaqSchema.COURSE_ID.value), faq.get(FaqSchema.QUESTION_TITLE.value), + faq.get(FaqSchema.QUESTION_ANSWER.value), + f"{base_url}/courses/{faq.get(FaqSchema.COURSE_ID.value)}/faq/?faqId={faq.get(FaqSchema.FAQ_ID.value)}" ) formatted_string += faq @@ -103,9 +109,11 @@ def __call__( paras = "" if information_type == InformationType.FAQS: - paras = self.create_formatted_faq_string(information) + paras = self.create_formatted_faq_string(information, kwargs.get("base_url")) + self.prompt_str = self.faq_prompt_str if information_type == InformationType.PARAGRAPHS: paras = self.create_formatted_lecture_string(information) + self.prompt_str = self.lecture_prompt_str try: self.default_prompt = 
PromptTemplate( diff --git a/app/retrieval/faq_retrieval.py b/app/retrieval/faq_retrieval.py index 20af67e2..d689286f 100644 --- a/app/retrieval/faq_retrieval.py +++ b/app/retrieval/faq_retrieval.py @@ -125,13 +125,6 @@ def __call__( logging.info(f"merged_chunks, {merged_chunks}") return merged_chunks - #if len(merged_chunks) != 0: - # selected_chunks_index = self.reranker_pipeline( - # paragraphs=merged_chunks, query=student_query, chat_history=chat_history - # ) - # if selected_chunks_index: - # return [merged_chunks[int(i)] for i in selected_chunks_index] - #return [] @traceable(name="Basic Faq Retrieval") def basic_faq_retrieval( @@ -223,7 +216,6 @@ def rewrite_student_query( token_usage = self.llm.tokens token_usage.pipeline = PipelineEnum.IRIS_FAQ_RETRIEVAL_PIPELINE self.tokens.append(self.llm.tokens) - logger.info(f"Response from faq retrieval pipeline: {response}") return response except Exception as e: raise e @@ -256,13 +248,12 @@ def rewrite_elaborated_query( course_name=course_name, ) prompt = ChatPromptTemplate.from_messages(prompt_val) - logging.info(f"Prompt for elaborated query: {prompt}") + try: response = (prompt | self.pipeline).invoke({}) token_usage = self.llm.tokens token_usage.pipeline = PipelineEnum.IRIS_FAQ_RETRIEVAL_PIPELINE self.tokens.append(self.llm.tokens) - logger.info(f"Response from faq retrival pipeline: {response}") return response except Exception as e: raise e @@ -306,8 +297,6 @@ def search_in_db( filters=filter_weaviate, ) - logger.info(f"Search in the database response: {return_value}") - return return_value @traceable(name="Retrieval: Run Parallel Rewrite Tasks") @@ -340,9 +329,7 @@ def run_parallel_rewrite_tasks( # Get the results once both tasks are complete rewritten_query = rewritten_query_future.result() - logging.info(f"Rewritten query: {rewritten_query}") hypothetical_answer_query = hypothetical_answer_query_future.result() - logging.info(f"Hypothetical answer query: {hypothetical_answer_query}") # Execute the database 
search tasks with concurrent.futures.ThreadPoolExecutor() as executor: From e8f611c24c2417e909ede96915ef14132e56c192 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Mon, 13 Jan 2025 10:56:52 +0100 Subject: [PATCH 07/19] Added language --- app/retrieval/faq_retrieval.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/app/retrieval/faq_retrieval.py b/app/retrieval/faq_retrieval.py index d689286f..830414a1 100644 --- a/app/retrieval/faq_retrieval.py +++ b/app/retrieval/faq_retrieval.py @@ -1,5 +1,4 @@ import logging -from asyncio.log import logger from typing import List from langsmith import traceable @@ -37,6 +36,9 @@ from ..vector_database.faq_schema import FaqSchema, init_faq_schema +logger = logging.getLogger(__name__) + + def merge_retrieved_chunks( basic_retrieved_faq_chunks, hyde_retrieved_faq_chunks ) -> List[dict]: @@ -134,6 +136,7 @@ def basic_faq_retrieval( result_limit: int, course_name: str = None, course_id: int = None, + course_language: str = None, ) -> list[dict[str, dict]]: """ Basic retrieval for pipelines that need performance and fast answers. 
@@ -142,7 +145,7 @@ def basic_faq_retrieval( return [] rewritten_query = self.rewrite_student_query( - chat_history, student_query, "course_language", course_name + chat_history, student_query, course_language, course_name ) response = self.search_in_db( query=rewritten_query, From edfbfa245fb8a5365f4a1c25f0cc5065c55d578f Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Mon, 13 Jan 2025 13:41:46 +0100 Subject: [PATCH 08/19] Increased faq limit --- app/pipeline/chat/course_chat_pipeline.py | 2 +- app/web/routers/webhooks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index b53ffa8e..eab549ec 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -312,7 +312,7 @@ def faq_content_retrieval() -> str: self.retrieved_faqs = self.faq_retriever( chat_history=history, student_query=query.contents[0].text_content, - result_limit=5, + result_limit=10, course_name=dto.course.name, course_id=dto.course.id, base_url=dto.settings.artemis_base_url, diff --git a/app/web/routers/webhooks.py b/app/web/routers/webhooks.py index 0c0f496d..79ac4522 100644 --- a/app/web/routers/webhooks.py +++ b/app/web/routers/webhooks.py @@ -99,7 +99,7 @@ def run_faq_update_pipeline_worker(dto: FaqIngestionPipelineExecutionDto): semaphore.release() -def run_faq_delete_pipeline_worker(dto: IngestionPipelineExecutionDto): +def run_faq_delete_pipeline_worker(dto: FaqDeletionExecutionDto): """ Run the faq deletion in a separate thread """ From 27abf69deb43d059d93af32f9b05c6caba8850fe Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Sat, 18 Jan 2025 21:32:47 +0100 Subject: [PATCH 09/19] Reformat --- app/domain/data/faq_dto.py | 3 -- .../ingestion/deletionPipelineExecutionDto.py | 3 +- .../ingestion_pipeline_execution_dto.py | 4 +-- app/pipeline/chat/course_chat_pipeline.py | 28 +++++++++++-------- .../chat/exercise_chat_agent_pipeline.py | 4 ++- 
app/pipeline/faq_ingestion_pipeline.py | 13 +++------ app/pipeline/shared/citation_pipeline.py | 16 +++++++---- app/retrieval/faq_retrieval.py | 17 +++++------ app/web/routers/webhooks.py | 19 ++++++------- .../status/faq_ingestion_status_callback.py | 4 ++- 10 files changed, 56 insertions(+), 55 deletions(-) diff --git a/app/domain/data/faq_dto.py b/app/domain/data/faq_dto.py index e68716af..12d3fd7a 100644 --- a/app/domain/data/faq_dto.py +++ b/app/domain/data/faq_dto.py @@ -8,6 +8,3 @@ class FaqDTO(BaseModel): question_answer: str = Field(alias="questionAnswer") course_name: str = Field(default="", alias="courseName") course_description: str = Field(default="", alias="courseDescription") - - - diff --git a/app/domain/ingestion/deletionPipelineExecutionDto.py b/app/domain/ingestion/deletionPipelineExecutionDto.py index 84445616..0824f1d6 100644 --- a/app/domain/ingestion/deletionPipelineExecutionDto.py +++ b/app/domain/ingestion/deletionPipelineExecutionDto.py @@ -15,9 +15,10 @@ class LecturesDeletionExecutionDto(PipelineExecutionDTO): default=None, alias="initialStages" ) + class FaqDeletionExecutionDto(PipelineExecutionDTO): faq: FaqDTO = Field(..., alias="pyrisFaqWebhookDTO") settings: Optional[PipelineExecutionSettingsDTO] initial_stages: Optional[List[StageDTO]] = Field( default=None, alias="initialStages" - ) \ No newline at end of file + ) diff --git a/app/domain/ingestion/ingestion_pipeline_execution_dto.py b/app/domain/ingestion/ingestion_pipeline_execution_dto.py index 213be158..d97fda45 100644 --- a/app/domain/ingestion/ingestion_pipeline_execution_dto.py +++ b/app/domain/ingestion/ingestion_pipeline_execution_dto.py @@ -15,10 +15,10 @@ class IngestionPipelineExecutionDto(PipelineExecutionDTO): default=None, alias="initialStages" ) + class FaqIngestionPipelineExecutionDto(PipelineExecutionDTO): faq: FaqDTO = Field(..., alias="pyrisFaqWebhookDTO") settings: Optional[PipelineExecutionSettingsDTO] initial_stages: Optional[List[StageDTO]] = Field( 
default=None, alias="initialStages" - ) - + ) diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index 10e05390..c07923e0 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -10,7 +10,8 @@ from langchain_core.messages import SystemMessage from langchain_core.output_parsers import JsonOutputParser from langchain_core.prompts import ( - ChatPromptTemplate, SystemMessagePromptTemplate, + ChatPromptTemplate, + SystemMessagePromptTemplate, ) from langchain_core.runnables import Runnable from langsmith import traceable @@ -103,14 +104,16 @@ def __init__( requirements=RequirementList( gpt_version_equivalent=4.5, ) - ), completion_args=completion_args + ), + completion_args=completion_args, ) self.llm_small = IrisLangchainChatModel( request_handler=CapabilityRequestHandler( requirements=RequirementList( gpt_version_equivalent=4.25, ) - ), completion_args=completion_args + ), + completion_args=completion_args, ) self.callback = callback @@ -436,8 +439,6 @@ def faq_content_retrieval() -> str: if self.should_allow_faq_tool(dto.course.id): tool_list.append(faq_content_retrieval) - - tools = generate_structured_tools_from_functions(tool_list) # No idea why we need this extra contrary to exercise chat agent in this case, but solves the issue. 
params.update({"tools": tools}) @@ -458,12 +459,19 @@ def faq_content_retrieval() -> str: if self.retrieved_paragraphs: self.callback.in_progress("Augmenting response ...") - out = self.citation_pipeline(self.retrieved_paragraphs, out, InformationType.PARAGRAPHS) + out = self.citation_pipeline( + self.retrieved_paragraphs, out, InformationType.PARAGRAPHS + ) self.tokens.extend(self.citation_pipeline.tokens) if self.retrieved_faqs: self.callback.in_progress("Augmenting response ...") - out = self.citation_pipeline(self.retrieved_faqs, out, InformationType.FAQS, base_url=dto.settings.artemis_base_url) + out = self.citation_pipeline( + self.retrieved_faqs, + out, + InformationType.FAQS, + base_url=dto.settings.artemis_base_url, + ) self.callback.done("Response created", final_result=out, tokens=self.tokens) # try: @@ -525,9 +533,7 @@ def should_allow_faq_tool(self, course_id: int) -> bool: if course_id: # Fetch the first object that matches the course ID with the language property result = self.db.faqs.query.fetch_objects( - filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal( - course_id - ), + filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal(course_id), limit=1, return_properties=[FaqSchema.COURSE_NAME.value], ) @@ -540,5 +546,3 @@ def datetime_to_string(dt: Optional[datetime]) -> str: return "No date provided" else: return dt.strftime("%Y-%m-%d %H:%M:%S") - - diff --git a/app/pipeline/chat/exercise_chat_agent_pipeline.py b/app/pipeline/chat/exercise_chat_agent_pipeline.py index ff9e86da..676f96c6 100644 --- a/app/pipeline/chat/exercise_chat_agent_pipeline.py +++ b/app/pipeline/chat/exercise_chat_agent_pipeline.py @@ -533,7 +533,9 @@ def lecture_content_retrieval() -> str: ] ) - guide_response = (self.prompt | self.llm_small | StrOutputParser()).invoke( + guide_response = ( + self.prompt | self.llm_small | StrOutputParser() + ).invoke( { "response": out, } diff --git a/app/pipeline/faq_ingestion_pipeline.py b/app/pipeline/faq_ingestion_pipeline.py 
index 074485cd..70330ebf 100644 --- a/app/pipeline/faq_ingestion_pipeline.py +++ b/app/pipeline/faq_ingestion_pipeline.py @@ -10,7 +10,7 @@ from ..domain.data.faq_dto import FaqDTO from app.domain.ingestion.ingestion_pipeline_execution_dto import ( - FaqIngestionPipelineExecutionDto, + FaqIngestionPipelineExecutionDto, ) from ..llm.langchain import IrisLangchainChatModel from ..vector_database.faq_schema import FaqSchema, init_faq_schema @@ -25,6 +25,7 @@ batch_update_lock = threading.Lock() + class FaqIngestionPipeline(AbstractIngestion, Pipeline): def __init__( @@ -94,7 +95,6 @@ def batch_update(self, faq: FaqDTO): batch.add_object(properties=faq_dict, vector=embed_chunk) - except Exception as e: logger.error(f"Error updating faq: {e}") self.callback.error( @@ -103,9 +103,7 @@ def batch_update(self, faq: FaqDTO): tokens=self.tokens, ) - def delete_old_faqs( - self, faqs: list[FaqDTO] - ): + def delete_old_faqs(self, faqs: list[FaqDTO]): """ Delete the faq from the database """ @@ -129,7 +127,6 @@ def delete_faq(self, faq_id, course_id): self.collection.data.delete_many( where=Filter.by_property(FaqSchema.FAQ_ID.value).equal(faq_id) & Filter.by_property(FaqSchema.COURSE_ID.value).equal(course_id) - ) logging.info(f"successfully deleted faq with id {faq_id}") return True @@ -137,10 +134,8 @@ def delete_faq(self, faq_id, course_id): logger.error(f"Error deleting faq: {e}", exc_info=True) return False - def chunk_data(self, path: str) -> List[Dict[str, str]]: """ - Faqs are so small, they do not need to be chunked into smaller parts + Faqs are so small, they do not need to be chunked into smaller parts """ return - diff --git a/app/pipeline/shared/citation_pipeline.py b/app/pipeline/shared/citation_pipeline.py index 411f6b57..3ce9572d 100644 --- a/app/pipeline/shared/citation_pipeline.py +++ b/app/pipeline/shared/citation_pipeline.py @@ -16,6 +16,7 @@ from app.vector_database.lecture_schema import LectureSchema + class InformationType(str, Enum): PARAGRAPHS = 
"PARAGRAPHS" FAQS = "FAQS" @@ -45,7 +46,9 @@ def __init__(self): prompt_file_path = os.path.join(dirname, "..", "prompts", "citation_prompt.txt") with open(prompt_file_path, "r") as file: self.lecture_prompt_str = file.read() - prompt_file_path = os.path.join(dirname, "..", "prompts", "faq_citation_prompt.txt") + prompt_file_path = os.path.join( + dirname, "..", "prompts", "faq_citation_prompt.txt" + ) with open(prompt_file_path, "r") as file: self.faq_prompt_str = file.read() self.pipeline = self.llm | StrOutputParser() @@ -57,7 +60,6 @@ def __repr__(self): def __str__(self): return f"{self.__class__.__name__}(llm={self.llm})" - def create_formatted_lecture_string(self, paragraphs): """ Create a formatted string from the data @@ -68,7 +70,8 @@ def create_formatted_lecture_string(self, paragraphs): paragraph.get(LectureSchema.LECTURE_NAME.value), paragraph.get(LectureSchema.LECTURE_UNIT_NAME.value), paragraph.get(LectureSchema.PAGE_NUMBER.value), - paragraph.get(LectureSchema.LECTURE_UNIT_LINK.value) or "No link available", + paragraph.get(LectureSchema.LECTURE_UNIT_LINK.value) + or "No link available", paragraph.get(LectureSchema.PAGE_TEXT_CONTENT.value), ) formatted_string += lct @@ -86,13 +89,12 @@ def create_formatted_faq_string(self, faqs, base_url): faq.get(FaqSchema.COURSE_ID.value), faq.get(FaqSchema.QUESTION_TITLE.value), faq.get(FaqSchema.QUESTION_ANSWER.value), - f"{base_url}/courses/{faq.get(FaqSchema.COURSE_ID.value)}/faq/?faqId={faq.get(FaqSchema.FAQ_ID.value)}" + f"{base_url}/courses/{faq.get(FaqSchema.COURSE_ID.value)}/faq/?faqId={faq.get(FaqSchema.FAQ_ID.value)}", ) formatted_string += faq return formatted_string.replace("{", "{{").replace("}", "}}") - def __call__( self, information: Union[List[dict], List[str]], @@ -109,7 +111,9 @@ def __call__( paras = "" if information_type == InformationType.FAQS: - paras = self.create_formatted_faq_string(information, kwargs.get("base_url")) + paras = self.create_formatted_faq_string( + information, 
kwargs.get("base_url") + ) self.prompt_str = self.faq_prompt_str if information_type == InformationType.PARAGRAPHS: paras = self.create_formatted_lecture_string(information) diff --git a/app/retrieval/faq_retrieval.py b/app/retrieval/faq_retrieval.py index 830414a1..da7f4900 100644 --- a/app/retrieval/faq_retrieval.py +++ b/app/retrieval/faq_retrieval.py @@ -25,7 +25,10 @@ SystemMessagePromptTemplate, ) -from ..pipeline.prompts.faq_retrieval_prompts import faq_retriever_initial_prompt, write_hypothetical_answer_prompt +from ..pipeline.prompts.faq_retrieval_prompts import ( + faq_retriever_initial_prompt, + write_hypothetical_answer_prompt, +) from ..pipeline.prompts.lecture_retrieval_prompts import ( assessment_prompt, assessment_prompt_final, @@ -94,7 +97,6 @@ def __call__( problem_statement: str = None, exercise_title: str = None, base_url: str = None, - ) -> List[dict]: """ Retrieve faq data from the database. @@ -107,7 +109,7 @@ def __call__( result_limit=result_limit, course_language=course_language, course_name=course_name, - course_id=course_id + course_id=course_id, ) logging.info(f"FAQ retrival response, {response}") @@ -127,7 +129,6 @@ def __call__( logging.info(f"merged_chunks, {merged_chunks}") return merged_chunks - @traceable(name="Basic Faq Retrieval") def basic_faq_retrieval( self, @@ -261,7 +262,6 @@ def rewrite_elaborated_query( except Exception as e: raise e - @traceable(name="Retrieval: Search in DB") def search_in_db( self, @@ -284,7 +284,6 @@ def search_in_db( course_id ) - vec = self.llm_embedding.embed(query) return_value = self.collection.query.hybrid( query=query, @@ -367,9 +366,7 @@ def fetch_course_language(self, course_id): if course_id: # Fetch the first object that matches the course ID with the language property result = self.collection.query.fetch_objects( - filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal( - course_id - ), + filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal(course_id), limit=1, # We only need 
one object to check and retrieve the language return_properties=[FaqSchema.COURSE_LANGUAGE.value], ) @@ -382,4 +379,4 @@ def fetch_course_language(self, course_id): if fetched_language: course_language = fetched_language - return course_language \ No newline at end of file + return course_language diff --git a/app/web/routers/webhooks.py b/app/web/routers/webhooks.py index 79ac4522..0cf259c5 100644 --- a/app/web/routers/webhooks.py +++ b/app/web/routers/webhooks.py @@ -8,13 +8,15 @@ from fastapi import APIRouter, status, Depends from app.dependencies import TokenValidator from app.domain.ingestion.ingestion_pipeline_execution_dto import ( - IngestionPipelineExecutionDto, FaqIngestionPipelineExecutionDto, + IngestionPipelineExecutionDto, + FaqIngestionPipelineExecutionDto, ) from ..status.faq_ingestion_status_callback import FaqIngestionStatus from ..status.ingestion_status_callback import IngestionStatusCallback from ..status.lecture_deletion_status_callback import LecturesDeletionStatusCallback from ...domain.ingestion.deletionPipelineExecutionDto import ( - LecturesDeletionExecutionDto, FaqDeletionExecutionDto, + LecturesDeletionExecutionDto, + FaqDeletionExecutionDto, ) from ...pipeline.faq_ingestion_pipeline import FaqIngestionPipeline from ...pipeline.lecture_ingestion_pipeline import LectureIngestionPipeline @@ -85,12 +87,9 @@ def run_faq_update_pipeline_worker(dto: FaqIngestionPipelineExecutionDto): ) db = VectorDatabase() client = db.get_client() - pipeline = FaqIngestionPipeline( - client=client, dto=dto, callback=callback - ) + pipeline = FaqIngestionPipeline(client=client, dto=dto, callback=callback) pipeline() - except Exception as e: logger.error(f"Error Faq Ingestion pipeline: {e}") logger.error(traceback.format_exc()) @@ -117,7 +116,6 @@ def run_faq_delete_pipeline_worker(dto: FaqDeletionExecutionDto): pipeline = FaqIngestionPipeline(client=client, dto=None, callback=callback) pipeline.delete_faq(dto.faq.faq_id, dto.faq.course_id) - except Exception 
as e: logger.error(f"Error Ingestion pipeline: {e}") logger.error(traceback.format_exc()) @@ -125,6 +123,7 @@ def run_faq_delete_pipeline_worker(dto: FaqDeletionExecutionDto): finally: semaphore.release() + @router.post( "/lectures/fullIngestion", status_code=status.HTTP_202_ACCEPTED, @@ -150,6 +149,7 @@ def lecture_deletion_webhook(dto: LecturesDeletionExecutionDto): thread = Thread(target=run_lecture_deletion_pipeline_worker, args=(dto,)) thread.start() + @router.post( "/faqs", status_code=status.HTTP_202_ACCEPTED, @@ -163,11 +163,12 @@ def faq_ingestion_webhook(dto: FaqIngestionPipelineExecutionDto): thread.start() return + @router.post( "/faqs/delete", status_code=status.HTTP_202_ACCEPTED, dependencies=[Depends(TokenValidator())], - ) +) def faq_deletion_webhook(dto: FaqDeletionExecutionDto): """ Webhook endpoint to trigger the faq deletion pipeline @@ -176,5 +177,3 @@ def faq_deletion_webhook(dto: FaqDeletionExecutionDto): thread = Thread(target=run_faq_delete_pipeline_worker, args=(dto,)) thread.start() return - - diff --git a/app/web/status/faq_ingestion_status_callback.py b/app/web/status/faq_ingestion_status_callback.py index b60322fd..8642fb5e 100644 --- a/app/web/status/faq_ingestion_status_callback.py +++ b/app/web/status/faq_ingestion_status_callback.py @@ -21,7 +21,9 @@ def __init__( initial_stages: List[StageDTO] = None, faq_id: int = None, ): - url = f"{base_url}/api/public/pyris/webhooks/ingestion/faqs/runs/{run_id}/status" + url = ( + f"{base_url}/api/public/pyris/webhooks/ingestion/faqs/runs/{run_id}/status" + ) current_stage_index = len(initial_stages) if initial_stages else 0 stages = initial_stages or [] From 82a55c37f3ad67b01c57e0ff6b5cf70ccced35f6 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Fri, 24 Jan 2025 13:32:34 +0100 Subject: [PATCH 10/19] Fix coderabit --- app/pipeline/chat/course_chat_pipeline.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/app/pipeline/chat/course_chat_pipeline.py 
b/app/pipeline/chat/course_chat_pipeline.py index c07923e0..7ec87328 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -11,7 +11,6 @@ from langchain_core.output_parsers import JsonOutputParser from langchain_core.prompts import ( ChatPromptTemplate, - SystemMessagePromptTemplate, ) from langchain_core.runnables import Runnable from langsmith import traceable @@ -311,12 +310,13 @@ def faq_content_retrieval() -> str: """ Use this tool to retrieve information from indexed FAQs. It is suitable when no other tool fits, you think it is a common question or the question is frequently asked, - or the question could be effectively answered by an FAQ. Also use this if the question is explicitly organizational and course-related. - An organizational question about the course might be "What is the course structure?" or "How do I enroll?" or exam related content like "When is the exam". - The tool performs a RAG retrieval based on the chat history to find the most relevant FAQs. Each FAQ follows this format: - FAQ ID, FAQ Question, FAQ Answer. - Respond to the query concisely and solely using the answer from the relevant FAQs. This tool should only be used once per query. - + or the question could be effectively answered by an FAQ. Also use this if the question is explicitly + organizational and course-related. An organizational question about the course might be + "What is the course structure?" or "How do I enroll?" or exam related content like "When is the exam". + The tool performs a RAG retrieval based on the chat history to find the most relevant FAQs. + Each FAQ follows this format: FAQ ID, FAQ Question, FAQ Answer. + Respond to the query concisely and solely using the answer from the relevant FAQs. + This tool should only be used once per query. 
""" self.callback.in_progress("Retrieving faq content ...") self.retrieved_faqs = self.faq_retriever( From f1621aec375d4fd1723e5a561282a2443910ce70 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Fri, 24 Jan 2025 13:39:25 +0100 Subject: [PATCH 11/19] Fix docs --- app/pipeline/chat/course_chat_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index 7ec87328..6d9b9894 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -528,7 +528,7 @@ def should_allow_faq_tool(self, course_id: int) -> bool: Checks if there are indexed faqs for the given course :param course_id: The course ID - :return: True if there are indexed lectures for the course, False otherwise + :return: True if there are indexed faqs for the course, False otherwise """ if course_id: # Fetch the first object that matches the course ID with the language property From dcb3e15923df30d16783b5672226c832e224ba32 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Mon, 27 Jan 2025 13:48:19 +0100 Subject: [PATCH 12/19] Fixed the linter checks --- app/pipeline/chat/course_chat_pipeline.py | 2 +- app/pipeline/faq_ingestion_pipeline.py | 3 +-- app/pipeline/shared/citation_pipeline.py | 4 ++-- app/retrieval/faq_retrieval.py | 3 --- app/vector_database/faq_schema.py | 3 --- app/web/routers/webhooks.py | 1 - 6 files changed, 4 insertions(+), 12 deletions(-) diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index 6d9b9894..a3f227e9 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -309,7 +309,7 @@ def lecture_content_retrieval() -> str: def faq_content_retrieval() -> str: """ Use this tool to retrieve information from indexed FAQs. 
- It is suitable when no other tool fits, you think it is a common question or the question is frequently asked, + It is suitable when no other tool fits, it is a common question or the question is frequently asked, or the question could be effectively answered by an FAQ. Also use this if the question is explicitly organizational and course-related. An organizational question about the course might be "What is the course structure?" or "How do I enroll?" or exam related content like "When is the exam". diff --git a/app/pipeline/faq_ingestion_pipeline.py b/app/pipeline/faq_ingestion_pipeline.py index 70330ebf..0f41ca11 100644 --- a/app/pipeline/faq_ingestion_pipeline.py +++ b/app/pipeline/faq_ingestion_pipeline.py @@ -3,7 +3,6 @@ from asyncio.log import logger from typing import Optional, List, Dict from langchain_core.output_parsers import StrOutputParser -from openai import OpenAI from weaviate import WeaviateClient from weaviate.classes.query import Filter from . import Pipeline @@ -128,7 +127,7 @@ def delete_faq(self, faq_id, course_id): where=Filter.by_property(FaqSchema.FAQ_ID.value).equal(faq_id) & Filter.by_property(FaqSchema.COURSE_ID.value).equal(course_id) ) - logging.info(f"successfully deleted faq with id {faq_id}") + logger.info(f"successfully deleted faq with id {faq_id}") return True except Exception as e: logger.error(f"Error deleting faq: {e}", exc_info=True) diff --git a/app/pipeline/shared/citation_pipeline.py b/app/pipeline/shared/citation_pipeline.py index 3ce9572d..d89efbd5 100644 --- a/app/pipeline/shared/citation_pipeline.py +++ b/app/pipeline/shared/citation_pipeline.py @@ -1,4 +1,3 @@ -import logging import os from asyncio.log import logger from enum import Enum @@ -104,8 +103,9 @@ def __call__( ) -> str: """ Runs the pipeline - :param information: List of information which can be list of dicts or list of strings. 
Used to augment the response + :param information: List of info as list of dicts or strings to augment response :param query: The query + :param information_type: The type of information provided. can be either lectures or faqs :return: Selected file content """ paras = "" diff --git a/app/retrieval/faq_retrieval.py b/app/retrieval/faq_retrieval.py index da7f4900..6a1fe16b 100644 --- a/app/retrieval/faq_retrieval.py +++ b/app/retrieval/faq_retrieval.py @@ -112,8 +112,6 @@ def __call__( course_id=course_id, ) - logging.info(f"FAQ retrival response, {response}") - basic_retrieved_faqs: list[dict[str, dict]] = [ {"id": obj.uuid.int, "properties": obj.properties} for obj in response.objects @@ -126,7 +124,6 @@ def __call__( basic_retrieved_faqs, hyde_retrieved_faqs ) - logging.info(f"merged_chunks, {merged_chunks}") return merged_chunks @traceable(name="Basic Faq Retrieval") diff --git a/app/vector_database/faq_schema.py b/app/vector_database/faq_schema.py index 325e8f97..abf97023 100644 --- a/app/vector_database/faq_schema.py +++ b/app/vector_database/faq_schema.py @@ -1,4 +1,3 @@ -import logging from enum import Enum from weaviate.classes.config import Property @@ -28,8 +27,6 @@ def init_faq_schema(client: WeaviateClient) -> Collection: """ if client.collections.exists(FaqSchema.COLLECTION_NAME.value): collection = client.collections.get(FaqSchema.COLLECTION_NAME.value) - properties = collection.config.get(simple=True).properties - # Check and add 'course_language' property if missing if not any( property.name == FaqSchema.COURSE_LANGUAGE.value diff --git a/app/web/routers/webhooks.py b/app/web/routers/webhooks.py index 0cf259c5..b14f07b5 100644 --- a/app/web/routers/webhooks.py +++ b/app/web/routers/webhooks.py @@ -173,7 +173,6 @@ def faq_deletion_webhook(dto: FaqDeletionExecutionDto): """ Webhook endpoint to trigger the faq deletion pipeline """ - logging.info("Starting faq deletion") thread = Thread(target=run_faq_delete_pipeline_worker, args=(dto,)) 
thread.start() return From bf523423e397449a245db2678e2f34b5d23a8465 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Mon, 27 Jan 2025 13:56:15 +0100 Subject: [PATCH 13/19] Fixed the linter checks --- app/pipeline/faq_ingestion_pipeline.py | 1 - app/web/routers/webhooks.py | 1 - 2 files changed, 2 deletions(-) diff --git a/app/pipeline/faq_ingestion_pipeline.py b/app/pipeline/faq_ingestion_pipeline.py index 0f41ca11..d8316caf 100644 --- a/app/pipeline/faq_ingestion_pipeline.py +++ b/app/pipeline/faq_ingestion_pipeline.py @@ -1,4 +1,3 @@ -import logging import threading from asyncio.log import logger from typing import Optional, List, Dict diff --git a/app/web/routers/webhooks.py b/app/web/routers/webhooks.py index b14f07b5..3b49f3d1 100644 --- a/app/web/routers/webhooks.py +++ b/app/web/routers/webhooks.py @@ -1,4 +1,3 @@ -import logging import traceback from asyncio.log import logger from threading import Thread, Semaphore From 6163e6fa6e8ada434656fbad7317b74c24c049b5 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Wed, 29 Jan 2025 15:28:50 +0100 Subject: [PATCH 14/19] Refactored FAQ retrival pipeline to reduce code duplication --- app/pipeline/faq_ingestion_pipeline.py | 4 +- app/retrieval/basic_retrieval.py | 298 +++++++++++++++++++++ app/retrieval/faq_retrieval.py | 344 ++----------------------- 3 files changed, 318 insertions(+), 328 deletions(-) create mode 100644 app/retrieval/basic_retrieval.py diff --git a/app/pipeline/faq_ingestion_pipeline.py b/app/pipeline/faq_ingestion_pipeline.py index d8316caf..ecfb9d1a 100644 --- a/app/pipeline/faq_ingestion_pipeline.py +++ b/app/pipeline/faq_ingestion_pipeline.py @@ -5,6 +5,7 @@ from weaviate import WeaviateClient from weaviate.classes.query import Filter from . 
import Pipeline +from .lecture_ingestion_pipeline import batch_update_lock from ..domain.data.faq_dto import FaqDTO from app.domain.ingestion.ingestion_pipeline_execution_dto import ( @@ -21,7 +22,8 @@ ) from ..web.status.faq_ingestion_status_callback import FaqIngestionStatus -batch_update_lock = threading.Lock() +# we use the same lock as the lecture ingestion pipeline +batch_update_lock = batch_update_lock class FaqIngestionPipeline(AbstractIngestion, Pipeline): diff --git a/app/retrieval/basic_retrieval.py b/app/retrieval/basic_retrieval.py new file mode 100644 index 00000000..132e7a40 --- /dev/null +++ b/app/retrieval/basic_retrieval.py @@ -0,0 +1,298 @@ +from abc import abstractmethod, ABC +from typing import List, Optional +from langsmith import traceable +from weaviate import WeaviateClient +from weaviate.classes.query import Filter +from app.common.token_usage_dto import TokenUsageDTO +from app.common.PipelineEnum import PipelineEnum +from ..common.message_converters import convert_iris_message_to_langchain_message +from ..common.pyris_message import PyrisMessage +from ..llm.langchain import IrisLangchainChatModel +from ..pipeline import Pipeline +from app.llm import ( + BasicRequestHandler, + CompletionArguments, + CapabilityRequestHandler, + RequirementList, +) +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate +import concurrent.futures +import logging + +logger = logging.getLogger(__name__) + + +def merge_retrieved_chunks( + basic_retrieved_faq_chunks, hyde_retrieved_faq_chunks +) -> List[dict]: + """ + Merge the retrieved chunks from the basic and hyde retrieval methods. This function ensures that for any + duplicate IDs, the properties from hyde_retrieved_faq_chunks will overwrite those from + basic_retrieved_faq_chunks. 
+ """ + merged_chunks = {} + for chunk in basic_retrieved_faq_chunks: + merged_chunks[chunk["id"]] = chunk["properties"] + + for chunk in hyde_retrieved_faq_chunks: + merged_chunks[chunk["id"]] = chunk["properties"] + + return [properties for uuid, properties in merged_chunks.items()] + + +def _add_last_four_messages_to_prompt( + prompt, + chat_history: List[PyrisMessage], +): + """ + Adds the chat history and user question to the prompt + :param chat_history: The chat history + :param user_question: The user question + :return: The prompt with the chat history + """ + if chat_history is not None and len(chat_history) > 0: + num_messages_to_take = min(len(chat_history), 4) + last_messages = chat_history[-num_messages_to_take:] + chat_history_messages = [ + convert_iris_message_to_langchain_message(message) + for message in last_messages + ] + prompt += chat_history_messages + return prompt + + +class BaseRetrieval(Pipeline, ABC): + """ + Base class for retrieval pipelines. + """ + + tokens: List[TokenUsageDTO] + + @abstractmethod + def __call__(self, *args, **kwargs): + """Muss in der konkreten Implementierung überschrieben werden""" + pass + + def __init__(self, client: WeaviateClient, schema_init_func, **kwargs): + super().__init__( + implementation_id=kwargs.get("implementation_id", "base_retrieval_pipeline") + ) + request_handler = CapabilityRequestHandler( + requirements=RequirementList( + gpt_version_equivalent=4.25, + context_length=16385, + privacy_compliance=True, + ) + ) + completion_args = CompletionArguments(temperature=0, max_tokens=2000) + self.llm = IrisLangchainChatModel( + request_handler=request_handler, completion_args=completion_args + ) + self.llm_embedding = BasicRequestHandler("embedding-small") + self.pipeline = self.llm | StrOutputParser() + self.collection = schema_init_func(client) + self.tokens = [] + + @traceable(name="Retrieval: Question Assessment") + def assess_question( + self, + chat_history: list[PyrisMessage], + student_query: 
str, + assessment_prompt: str, + assessment_prompt_final: str, + ) -> bool: + prompt = ChatPromptTemplate.from_messages( + [ + ("system", assessment_prompt), + ] + ) + prompt = _add_last_four_messages_to_prompt(prompt, chat_history) + prompt += ChatPromptTemplate.from_messages( + [ + ("user", student_query), + ] + ) + prompt += ChatPromptTemplate.from_messages( + [ + ("system", assessment_prompt_final), + ] + ) + + try: + response = (prompt | self.pipeline).invoke({}) + logger.info(f"Response from assessment pipeline: {response}") + return response == "YES" + except Exception as e: + raise e + + @traceable(name="Retrieval: Rewrite Student Query") + def rewrite_student_query( + self, + chat_history: list[PyrisMessage], + student_query: str, + course_language: str, + course_name: str, + initial_prompt: str, + rewrite_prompt: str, + pipeline_enum: PipelineEnum, + ) -> str: + """ + Rewrite the student query. + """ + prompt = ChatPromptTemplate.from_messages( + [ + ("system", initial_prompt), + ] + ) + prompt = _add_last_four_messages_to_prompt(prompt, chat_history) + prompt += SystemMessagePromptTemplate.from_template(rewrite_prompt) + prompt_val = prompt.format_messages( + course_language=course_language, + course_name=course_name, + student_query=student_query, + ) + prompt = ChatPromptTemplate.from_messages(prompt_val) + try: + response = (prompt | self.pipeline).invoke({}) + token_usage = self.llm.tokens + token_usage.pipeline = pipeline_enum + self.tokens.append(self.llm.tokens) + return response + except Exception as e: + raise e + + @traceable(name="Retrieval: Search in DB") + def search_in_db( + self, + query: str, + hybrid_factor: float, + result_limit: int, + schema_properties: List[str], + course_id: Optional[int] = None, + base_url: Optional[str] = None, + course_id_property: str = "course_id", + base_url_property: str = "base_url", + ): + """ + Search the database for the given query. 
+ """ + logger.info(f"Searching in the database for query: {query}") + filter_weaviate = None + + if course_id: + filter_weaviate = Filter.by_property(course_id_property).equal(course_id) + if base_url: + filter_weaviate &= Filter.by_property(base_url_property).equal(base_url) + + vec = self.llm_embedding.embed(query) + return self.collection.query.hybrid( + query=query, + alpha=hybrid_factor, + vector=vec, + return_properties=schema_properties, + limit=result_limit, + filters=filter_weaviate, + ) + + @traceable(name="Retrieval: Run Parallel Rewrite Tasks") + def run_parallel_rewrite_tasks( + self, + chat_history: list[PyrisMessage], + student_query: str, + result_limit: int, + course_language: str, + initial_prompt: str, + rewrite_prompt: str, + hypothetical_answer_prompt: str, + pipeline_enum: PipelineEnum, + course_name: Optional[str] = None, + course_id: Optional[int] = None, + base_url: Optional[str] = None, + problem_statement: Optional[str] = None, + exercise_title: Optional[str] = None, + ): + """ + Run the rewrite tasks in parallel. 
+ """ + with concurrent.futures.ThreadPoolExecutor() as executor: + rewritten_query_future = executor.submit( + self.rewrite_student_query, + chat_history, + student_query, + course_language, + course_name, + initial_prompt, + rewrite_prompt, + pipeline_enum, + ) + hypothetical_answer_query_future = executor.submit( + self.rewrite_student_query, + chat_history, + student_query, + course_language, + course_name, + initial_prompt, + hypothetical_answer_prompt, + pipeline_enum, + ) + + rewritten_query = rewritten_query_future.result() + hypothetical_answer_query = hypothetical_answer_query_future.result() + + with concurrent.futures.ThreadPoolExecutor() as executor: + response_future = executor.submit( + self.search_in_db, + query=rewritten_query, + hybrid_factor=0.9, + result_limit=result_limit, + schema_properties=self.get_schema_properties(), + course_id=course_id, + base_url=base_url, + ) + response_hyde_future = executor.submit( + self.search_in_db, + query=hypothetical_answer_query, + hybrid_factor=0.9, + result_limit=result_limit, + schema_properties=self.get_schema_properties(), + course_id=course_id, + base_url=base_url, + ) + + response = response_future.result() + response_hyde = response_hyde_future.result() + + return response, response_hyde + + @abstractmethod + def get_schema_properties(self) -> List[str]: + """ + Abstract method to be implemented by subclasses to return the schema properties. + """ + raise NotImplementedError + + def fetch_course_language( + self, course_id: int, course_language_property: str = "course_language" + ) -> str: + """ + Fetch the language of the course based on the course ID. + If no specific language is set, it defaults to English. 
+ """ + course_language = "english" + + if course_id: + result = self.collection.query.fetch_objects( + filters=Filter.by_property("course_id").equal(course_id), + limit=1, + return_properties=[course_language_property], + ) + + if result.objects: + fetched_language = result.objects[0].properties.get( + course_language_property + ) + if fetched_language: + course_language = fetched_language + + return course_language diff --git a/app/retrieval/faq_retrieval.py b/app/retrieval/faq_retrieval.py index 6a1fe16b..c2f9a455 100644 --- a/app/retrieval/faq_retrieval.py +++ b/app/retrieval/faq_retrieval.py @@ -1,90 +1,35 @@ import logging from typing import List - from langsmith import traceable from weaviate import WeaviateClient -from weaviate.classes.query import Filter - -from app.common.token_usage_dto import TokenUsageDTO from app.common.PipelineEnum import PipelineEnum -from .lecture_retrieval import _add_last_four_messages_to_prompt +from .basic_retrieval import BaseRetrieval, merge_retrieved_chunks from ..common.pyris_message import PyrisMessage -from ..llm.langchain import IrisLangchainChatModel -from ..pipeline import Pipeline - -from app.llm import ( - BasicRequestHandler, - CompletionArguments, - CapabilityRequestHandler, - RequirementList, -) -from app.pipeline.shared.reranker_pipeline import RerankerPipeline -from langchain_core.output_parsers import StrOutputParser -from langchain_core.prompts import ( - ChatPromptTemplate, - SystemMessagePromptTemplate, -) - from ..pipeline.prompts.faq_retrieval_prompts import ( faq_retriever_initial_prompt, write_hypothetical_answer_prompt, ) from ..pipeline.prompts.lecture_retrieval_prompts import ( - assessment_prompt, - assessment_prompt_final, rewrite_student_query_prompt, ) -import concurrent.futures - from ..vector_database.faq_schema import FaqSchema, init_faq_schema - logger = logging.getLogger(__name__) -def merge_retrieved_chunks( - basic_retrieved_faq_chunks, hyde_retrieved_faq_chunks -) -> List[dict]: - """ - 
Merge the retrieved chunks from the basic and hyde retrieval methods. This function ensures that for any - duplicate IDs, the properties from hyde_retrieved_faq_chunks will overwrite those from - basic_retrieved_faq_chunks. - """ - merged_chunks = {} - for chunk in basic_retrieved_faq_chunks: - merged_chunks[chunk["id"]] = chunk["properties"] - - for chunk in hyde_retrieved_faq_chunks: - merged_chunks[chunk["id"]] = chunk["properties"] - - return [properties for uuid, properties in merged_chunks.items()] - - -class FaqRetrieval(Pipeline): - """ - Class for retrieving faq data from the database. - """ - - tokens: List[TokenUsageDTO] - +class FaqRetrieval(BaseRetrieval): def __init__(self, client: WeaviateClient, **kwargs): - super().__init__(implementation_id="faq_retrieval_pipeline") - request_handler = CapabilityRequestHandler( - requirements=RequirementList( - gpt_version_equivalent=4.25, - context_length=16385, - privacy_compliance=True, - ) + super().__init__( + client, init_faq_schema, implementation_id="faq_retrieval_pipeline" ) - completion_args = CompletionArguments(temperature=0, max_tokens=2000) - self.llm = IrisLangchainChatModel( - request_handler=request_handler, completion_args=completion_args - ) - self.llm_embedding = BasicRequestHandler("embedding-small") - self.pipeline = self.llm | StrOutputParser() - self.collection = init_faq_schema(client) - self.reranker_pipeline = RerankerPipeline() - self.tokens = [] + + def get_schema_properties(self) -> List[str]: + return [ + FaqSchema.COURSE_ID.value, + FaqSchema.FAQ_ID.value, + FaqSchema.QUESTION_TITLE.value, + FaqSchema.QUESTION_ANSWER.value, + ] @traceable(name="Full Faq Retrieval") def __call__( @@ -98,9 +43,6 @@ def __call__( exercise_title: str = None, base_url: str = None, ) -> List[dict]: - """ - Retrieve faq data from the database. 
- """ course_language = self.fetch_course_language(course_id) response, response_hyde = self.run_parallel_rewrite_tasks( @@ -108,6 +50,10 @@ def __call__( student_query=student_query, result_limit=result_limit, course_language=course_language, + initial_prompt=faq_retriever_initial_prompt, + rewrite_prompt=rewrite_student_query_prompt, + hypothetical_answer_prompt=write_hypothetical_answer_prompt, + pipeline_enum=PipelineEnum.IRIS_FAQ_RETRIEVAL_PIPELINE, course_name=course_name, course_id=course_id, ) @@ -120,260 +66,4 @@ def __call__( {"id": obj.uuid.int, "properties": obj.properties} for obj in response_hyde.objects ] - merged_chunks = merge_retrieved_chunks( - basic_retrieved_faqs, hyde_retrieved_faqs - ) - - return merged_chunks - - @traceable(name="Basic Faq Retrieval") - def basic_faq_retrieval( - self, - chat_history: list[PyrisMessage], - student_query: str, - result_limit: int, - course_name: str = None, - course_id: int = None, - course_language: str = None, - ) -> list[dict[str, dict]]: - """ - Basic retrieval for pipelines that need performance and fast answers. 
- """ - if not self.assess_question(chat_history, student_query): - return [] - - rewritten_query = self.rewrite_student_query( - chat_history, student_query, course_language, course_name - ) - response = self.search_in_db( - query=rewritten_query, - hybrid_factor=0.9, - result_limit=result_limit, - course_id=course_id, - ) - - basic_retrieved_faq_chunks: list[dict[str, dict]] = [ - {"id": obj.uuid.int, "properties": obj.properties} - for obj in response.objects - ] - return basic_retrieved_faq_chunks - - @traceable(name="Retrieval: Question Assessment") - def assess_question( - self, chat_history: list[PyrisMessage], student_query: str - ) -> bool: - prompt = ChatPromptTemplate.from_messages( - [ - ("system", assessment_prompt), - ] - ) - prompt = _add_last_four_messages_to_prompt(prompt, chat_history) - prompt += ChatPromptTemplate.from_messages( - [ - ("user", student_query), - ] - ) - prompt += ChatPromptTemplate.from_messages( - [ - ("system", assessment_prompt_final), - ] - ) - - try: - response = (prompt | self.pipeline).invoke({}) - logger.info(f"Response from assessment pipeline: {response}") - return response == "YES" - except Exception as e: - raise e - - @traceable(name="Retrieval: Rewrite Student Query") - def rewrite_student_query( - self, - chat_history: list[PyrisMessage], - student_query: str, - course_language: str, - course_name: str, - ) -> str: - """ - Rewrite the student query. 
- """ - prompt = ChatPromptTemplate.from_messages( - [ - ("system", faq_retriever_initial_prompt), - ] - ) - prompt = _add_last_four_messages_to_prompt(prompt, chat_history) - prompt += SystemMessagePromptTemplate.from_template( - rewrite_student_query_prompt - ) - prompt_val = prompt.format_messages( - course_language=course_language, - course_name=course_name, - student_query=student_query, - ) - prompt = ChatPromptTemplate.from_messages(prompt_val) - try: - response = (prompt | self.pipeline).invoke({}) - token_usage = self.llm.tokens - token_usage.pipeline = PipelineEnum.IRIS_FAQ_RETRIEVAL_PIPELINE - self.tokens.append(self.llm.tokens) - return response - except Exception as e: - raise e - - @traceable(name="Retrieval: Rewrite Elaborated Query") - def rewrite_elaborated_query( - self, - chat_history: list[PyrisMessage], - student_query: str, - course_language: str, - course_name: str, - ) -> str: - """ - Rewrite the student query to generate fitting faq content and embed it. - To extract more relevant content from the vector database. - """ - prompt = ChatPromptTemplate.from_messages( - [ - ("system", write_hypothetical_answer_prompt), - ] - ) - prompt = _add_last_four_messages_to_prompt(prompt, chat_history) - prompt += ChatPromptTemplate.from_messages( - [ - ("user", student_query), - ] - ) - prompt_val = prompt.format_messages( - course_language=course_language, - course_name=course_name, - ) - prompt = ChatPromptTemplate.from_messages(prompt_val) - - try: - response = (prompt | self.pipeline).invoke({}) - token_usage = self.llm.tokens - token_usage.pipeline = PipelineEnum.IRIS_FAQ_RETRIEVAL_PIPELINE - self.tokens.append(self.llm.tokens) - return response - except Exception as e: - raise e - - @traceable(name="Retrieval: Search in DB") - def search_in_db( - self, - query: str, - hybrid_factor: float, - result_limit: int, - course_id: int = None, - ): - """ - Search the database for the given query. 
- """ - logger.info(f"Searching in the database for query: {query}") - # Initialize filter to None by default - filter_weaviate = None - - # Check if course_id is provided - if course_id: - # Create a filter for course_id - filter_weaviate = Filter.by_property(FaqSchema.COURSE_ID.value).equal( - course_id - ) - - vec = self.llm_embedding.embed(query) - return_value = self.collection.query.hybrid( - query=query, - alpha=hybrid_factor, - vector=vec, - return_properties=[ - FaqSchema.COURSE_ID.value, - FaqSchema.FAQ_ID.value, - FaqSchema.QUESTION_TITLE.value, - FaqSchema.QUESTION_ANSWER.value, - ], - limit=result_limit, - filters=filter_weaviate, - ) - - return return_value - - @traceable(name="Retrieval: Run Parallel Rewrite Tasks") - def run_parallel_rewrite_tasks( - self, - chat_history: list[PyrisMessage], - student_query: str, - result_limit: int, - course_language: str, - course_name: str = None, - course_id: int = None, - ): - - with concurrent.futures.ThreadPoolExecutor() as executor: - # Schedule the rewrite tasks to run in parallel - rewritten_query_future = executor.submit( - self.rewrite_student_query, - chat_history, - student_query, - course_language, - course_name, - ) - hypothetical_answer_query_future = executor.submit( - self.rewrite_elaborated_query, - chat_history, - student_query, - course_language, - course_name, - ) - - # Get the results once both tasks are complete - rewritten_query = rewritten_query_future.result() - hypothetical_answer_query = hypothetical_answer_query_future.result() - - # Execute the database search tasks - with concurrent.futures.ThreadPoolExecutor() as executor: - response_future = executor.submit( - self.search_in_db, - query=rewritten_query, - hybrid_factor=0.9, - result_limit=result_limit, - course_id=course_id, - ) - response_hyde_future = executor.submit( - self.search_in_db, - query=hypothetical_answer_query, - hybrid_factor=0.9, - result_limit=result_limit, - course_id=course_id, - ) - - # Get the results once both 
tasks are complete - response = response_future.result() - response_hyde = response_hyde_future.result() - - return response, response_hyde - - def fetch_course_language(self, course_id): - """ - Fetch the language of the course based on the course ID. - If no specific language is set, it defaults to English. - """ - course_language = "english" - - if course_id: - # Fetch the first object that matches the course ID with the language property - result = self.collection.query.fetch_objects( - filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal(course_id), - limit=1, # We only need one object to check and retrieve the language - return_properties=[FaqSchema.COURSE_LANGUAGE.value], - ) - - # Check if the result has objects and retrieve the language - if result.objects: - fetched_language = result.objects[0].properties.get( - FaqSchema.COURSE_LANGUAGE.value - ) - if fetched_language: - course_language = fetched_language - - return course_language + return merge_retrieved_chunks(basic_retrieved_faqs, hyde_retrieved_faqs) From 2fcdf4e4ca929cc26793d2726ce1dec9322c483f Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Thu, 30 Jan 2025 12:46:59 +0100 Subject: [PATCH 15/19] Refactored FAQ retrival pipeline to reduce code duplication --- app/pipeline/chat/course_chat_pipeline.py | 29 ++------------ .../chat/exercise_chat_agent_pipeline.py | 38 +++++++++++++++++- app/retrieval/faq_retrieval_utils.py | 40 +++++++++++++++++++ 3 files changed, 79 insertions(+), 28 deletions(-) create mode 100644 app/retrieval/faq_retrieval_utils.py diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index a3f227e9..46d558a9 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -43,6 +43,7 @@ from ...domain import CourseChatPipelineExecutionDTO from app.common.PipelineEnum import PipelineEnum from ...retrieval.faq_retrieval import FaqRetrieval +from ...retrieval.faq_retrieval_utils import 
should_allow_faq_tool, format_faqs from ...retrieval.lecture_retrieval import LectureRetrieval from ...vector_database.database import VectorDatabase from ...vector_database.faq_schema import FaqSchema @@ -328,14 +329,7 @@ def faq_content_retrieval() -> str: base_url=dto.settings.artemis_base_url, ) - result = "" - for faq in self.retrieved_faqs: - res = "[FAQ ID: {}, FAQ Question: {}, FAQ Answer: {}]".format( - faq.get(FaqSchema.FAQ_ID.value), - faq.get(FaqSchema.QUESTION_TITLE.value), - faq.get(FaqSchema.QUESTION_ANSWER.value), - ) - result += res + result = format_faqs(self.retrieved_faqs) return result if dto.user.id % 3 < 2: @@ -436,7 +430,7 @@ def faq_content_retrieval() -> str: if self.should_allow_lecture_tool(dto.course.id): tool_list.append(lecture_content_retrieval) - if self.should_allow_faq_tool(dto.course.id): + if should_allow_faq_tool(self.db, dto.course.id): tool_list.append(faq_content_retrieval) tools = generate_structured_tools_from_functions(tool_list) @@ -523,23 +517,6 @@ def should_allow_lecture_tool(self, course_id: int) -> bool: return len(result.objects) > 0 return False - def should_allow_faq_tool(self, course_id: int) -> bool: - """ - Checks if there are indexed faqs for the given course - - :param course_id: The course ID - :return: True if there are indexed faqs for the course, False otherwise - """ - if course_id: - # Fetch the first object that matches the course ID with the language property - result = self.db.faqs.query.fetch_objects( - filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal(course_id), - limit=1, - return_properties=[FaqSchema.COURSE_NAME.value], - ) - return len(result.objects) > 0 - return False - def datetime_to_string(dt: Optional[datetime]) -> str: if dt is None: diff --git a/app/pipeline/chat/exercise_chat_agent_pipeline.py b/app/pipeline/chat/exercise_chat_agent_pipeline.py index 676f96c6..65c05390 100644 --- a/app/pipeline/chat/exercise_chat_agent_pipeline.py +++ 
b/app/pipeline/chat/exercise_chat_agent_pipeline.py @@ -39,8 +39,11 @@ from ...llm import CapabilityRequestHandler, RequirementList from ...llm import CompletionArguments from ...llm.langchain import IrisLangchainChatModel +from ...retrieval.faq_retrieval import FaqRetrieval +from ...retrieval.faq_retrieval_utils import should_allow_faq_tool from ...retrieval.lecture_retrieval import LectureRetrieval from ...vector_database.database import VectorDatabase +from ...vector_database.faq_schema import FaqSchema from ...vector_database.lecture_schema import LectureSchema from weaviate.collections.classes.filters import Filter from ...web.status.status_update import ExerciseChatStatusCallback @@ -103,6 +106,7 @@ class ExerciseChatAgentPipeline(Pipeline): prompt: ChatPromptTemplate variant: str event: str | None + retrieved_faqs: List[dict] = None def __init__( self, @@ -136,7 +140,8 @@ def __init__( # Create the pipelines self.db = VectorDatabase() self.suggestion_pipeline = InteractionSuggestionPipeline(variant="exercise") - self.retriever = LectureRetrieval(self.db.client) + self.lecture_retriever = LectureRetrieval(self.db.client) + self.faq_retriever = FaqRetrieval(self.db.client) self.reranker_pipeline = RerankerPipeline() self.code_feedback_pipeline = CodeFeedbackPipeline() self.pipeline = self.llm_big | JsonOutputParser() @@ -373,7 +378,7 @@ def lecture_content_retrieval() -> str: Only use this once. """ self.callback.in_progress("Retrieving lecture content ...") - self.retrieved_paragraphs = self.retriever( + self.retrieved_paragraphs = self.lectureRetriever( chat_history=chat_history, student_query=query.contents[0].text_content, result_limit=5, @@ -393,6 +398,31 @@ def lecture_content_retrieval() -> str: result += lct return result + def faq_content_retrieval() -> str: + """ + Use this tool to retrieve information from indexed FAQs. 
+ It is suitable when no other tool fits, it is a common question or the question is frequently asked, + or the question could be effectively answered by an FAQ. Also use this if the question is explicitly + organizational and course-related. An organizational question about the course might be + "What is the course structure?" or "How do I enroll?" or exam related content like "When is the exam". + The tool performs a RAG retrieval based on the chat history to find the most relevant FAQs. + Each FAQ follows this format: FAQ ID, FAQ Question, FAQ Answer. + Respond to the query concisely and solely using the answer from the relevant FAQs. + This tool should only be used once per query. + """ + self.callback.in_progress("Retrieving faq content ...") + self.retrieved_faqs = self.faq_retriever( + chat_history=history, + student_query=query.contents[0].text_content, + result_limit=10, + course_name=dto.course.name, + course_id=dto.course.id, + base_url=dto.settings.artemis_base_url, + ) + + result = format_faqs(self.retrieved_faqs) + return result + iris_initial_system_prompt = tell_iris_initial_system_prompt chat_history_exists_prompt = tell_chat_history_exists_prompt no_chat_history_prompt = tell_no_chat_history_prompt @@ -511,6 +541,10 @@ def lecture_content_retrieval() -> str: ] if self.should_allow_lecture_tool(dto.course.id): tool_list.append(lecture_content_retrieval) + + if should_allow_faq_tool(self.db, dto.course.id): + tool_list.append(faq_content_retrieval) + tools = generate_structured_tools_from_functions(tool_list) agent = create_tool_calling_agent( llm=self.llm_big, tools=tools, prompt=self.prompt diff --git a/app/retrieval/faq_retrieval_utils.py b/app/retrieval/faq_retrieval_utils.py new file mode 100644 index 00000000..2a65873b --- /dev/null +++ b/app/retrieval/faq_retrieval_utils.py @@ -0,0 +1,40 @@ +from weaviate.collections.classes.filters import Filter +from app.vector_database.database import VectorDatabase +from app.vector_database.faq_schema 
import FaqSchema
+
+
+def should_allow_faq_tool(db: VectorDatabase, course_id: int) -> bool:
+    """
+    Checks if there are indexed faqs for the given course
+
+    :param db: The vector database on which the faqs are indexed
+    :param course_id: The course ID
+    :return: True if there are indexed faqs for the course, False otherwise
+    """
+    if course_id:
+        # Fetch the first object that matches the course ID with the language property
+        result = db.faqs.query.fetch_objects(
+            filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal(course_id),
+            limit=1,
+            return_properties=[FaqSchema.COURSE_NAME.value],
+        )
+        return len(result.objects) > 0
+    return False
+
+
+def format_faqs(retrieved_faqs):
+    """
+    Formats the retrieved FAQs into a single string.
+
+    :param retrieved_faqs: List of the retrieved FAQs
+    :return: Formatted string containing the FAQ data
+    """
+    result = ""
+    for faq in retrieved_faqs:
+        res = "[FAQ ID: {}, FAQ Question: {}, FAQ Answer: {}]".format(
+            faq.get(FaqSchema.FAQ_ID.value),
+            faq.get(FaqSchema.QUESTION_TITLE.value),
+            faq.get(FaqSchema.QUESTION_ANSWER.value),
+        )
+        result += res
+    return result

From 430b777e5f60846ae6efd2386b3a83b9434934df Mon Sep 17 00:00:00 2001
From: Tim Cremer
Date: Thu, 30 Jan 2025 12:47:12 +0100
Subject: [PATCH 16/19] Refactored FAQ retrieval pipeline to reduce code
 duplication

---
 app/pipeline/chat/exercise_chat_agent_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/app/pipeline/chat/exercise_chat_agent_pipeline.py b/app/pipeline/chat/exercise_chat_agent_pipeline.py
index 65c05390..b68c73d8 100644
--- a/app/pipeline/chat/exercise_chat_agent_pipeline.py
+++ b/app/pipeline/chat/exercise_chat_agent_pipeline.py
@@ -40,7 +40,7 @@
 from ...llm import CompletionArguments
 from ...llm.langchain import IrisLangchainChatModel
 from ...retrieval.faq_retrieval import FaqRetrieval
-from ...retrieval.faq_retrieval_utils import should_allow_faq_tool
+from ...retrieval.faq_retrieval_utils import 
should_allow_faq_tool, format_faqs from ...retrieval.lecture_retrieval import LectureRetrieval from ...vector_database.database import VectorDatabase from ...vector_database.faq_schema import FaqSchema @@ -412,7 +412,7 @@ def faq_content_retrieval() -> str: """ self.callback.in_progress("Retrieving faq content ...") self.retrieved_faqs = self.faq_retriever( - chat_history=history, + chat_history=chat_history, student_query=query.contents[0].text_content, result_limit=10, course_name=dto.course.name, From cf2feac8bd3a2d684cfdb634ded6e6edabada42d Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Thu, 30 Jan 2025 14:02:13 +0100 Subject: [PATCH 17/19] Remove unused import --- app/pipeline/chat/exercise_chat_agent_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app/pipeline/chat/exercise_chat_agent_pipeline.py b/app/pipeline/chat/exercise_chat_agent_pipeline.py index f3df45c0..5992cc62 100644 --- a/app/pipeline/chat/exercise_chat_agent_pipeline.py +++ b/app/pipeline/chat/exercise_chat_agent_pipeline.py @@ -43,7 +43,6 @@ from ...retrieval.faq_retrieval_utils import should_allow_faq_tool, format_faqs from ...retrieval.lecture_retrieval import LectureRetrieval from ...vector_database.database import VectorDatabase -from ...vector_database.faq_schema import FaqSchema from ...vector_database.lecture_schema import LectureSchema from weaviate.collections.classes.filters import Filter from ...web.status.status_update import ExerciseChatStatusCallback From dd49dc0076858f03f3f9c79202b85a0065a72559 Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Thu, 30 Jan 2025 14:07:35 +0100 Subject: [PATCH 18/19] fix typo --- app/pipeline/chat/exercise_chat_agent_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/pipeline/chat/exercise_chat_agent_pipeline.py b/app/pipeline/chat/exercise_chat_agent_pipeline.py index 5992cc62..0a46f10d 100644 --- a/app/pipeline/chat/exercise_chat_agent_pipeline.py +++ b/app/pipeline/chat/exercise_chat_agent_pipeline.py @@ 
-377,7 +377,7 @@ def lecture_content_retrieval() -> str: Only use this once. """ self.callback.in_progress("Retrieving lecture content ...") - self.retrieved_paragraphs = self.lectureRetriever( + self.retrieved_paragraphs = self.lecture_retriever( chat_history=chat_history, student_query=query.contents[0].text_content, result_limit=5, From 0921e1aea488413e32034c96645f76a5c489387a Mon Sep 17 00:00:00 2001 From: Tim Cremer Date: Thu, 30 Jan 2025 14:33:24 +0100 Subject: [PATCH 19/19] linter --- app/pipeline/chat/course_chat_pipeline.py | 1 - app/pipeline/faq_ingestion_pipeline.py | 1 - 2 files changed, 2 deletions(-) diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index 46d558a9..1d877d68 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -46,7 +46,6 @@ from ...retrieval.faq_retrieval_utils import should_allow_faq_tool, format_faqs from ...retrieval.lecture_retrieval import LectureRetrieval from ...vector_database.database import VectorDatabase -from ...vector_database.faq_schema import FaqSchema from ...vector_database.lecture_schema import LectureSchema from ...web.status.status_update import ( CourseChatStatusCallback, diff --git a/app/pipeline/faq_ingestion_pipeline.py b/app/pipeline/faq_ingestion_pipeline.py index ecfb9d1a..b9d43c86 100644 --- a/app/pipeline/faq_ingestion_pipeline.py +++ b/app/pipeline/faq_ingestion_pipeline.py @@ -1,4 +1,3 @@ -import threading from asyncio.log import logger from typing import Optional, List, Dict from langchain_core.output_parsers import StrOutputParser