diff --git a/app/common/PipelineEnum.py b/app/common/PipelineEnum.py
index 487f3fac..a3283705 100644
--- a/app/common/PipelineEnum.py
+++ b/app/common/PipelineEnum.py
@@ -14,6 +14,8 @@ class PipelineEnum(str, Enum):
     IRIS_SUMMARY_PIPELINE = "IRIS_SUMMARY_PIPELINE"
     IRIS_LECTURE_RETRIEVAL_PIPELINE = "IRIS_LECTURE_RETRIEVAL_PIPELINE"
     IRIS_LECTURE_INGESTION = "IRIS_LECTURE_INGESTION"
+    IRIS_FAQ_INGESTION = "IRIS_FAQ_INGESTION"
+    IRIS_FAQ_RETRIEVAL_PIPELINE = "IRIS_FAQ_RETRIEVAL_PIPELINE"
     IRIS_INCONSISTENCY_CHECK = "IRIS_INCONSISTENCY_CHECK"
     IRIS_REWRITING_PIPELINE = "IRIS_REWRITING_PIPELINE"
     NOT_SET = "NOT_SET"
diff --git a/app/domain/data/faq_dto.py b/app/domain/data/faq_dto.py
new file mode 100644
index 00000000..12d3fd7a
--- /dev/null
+++ b/app/domain/data/faq_dto.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel, Field
+
+
+class FaqDTO(BaseModel):
+    faq_id: int = Field(alias="faqId")
+    course_id: int = Field(alias="courseId")
+    question_title: str = Field(alias="questionTitle")
+    question_answer: str = Field(alias="questionAnswer")
+    course_name: str = Field(default="", alias="courseName")
+    course_description: str = Field(default="", alias="courseDescription")
diff --git a/app/domain/ingestion/deletionPipelineExecutionDto.py b/app/domain/ingestion/deletionPipelineExecutionDto.py
index 1cec7cdd..0824f1d6 100644
--- a/app/domain/ingestion/deletionPipelineExecutionDto.py
+++ b/app/domain/ingestion/deletionPipelineExecutionDto.py
@@ -3,6 +3,7 @@
 from pydantic import Field
 
 from app.domain import PipelineExecutionDTO, PipelineExecutionSettingsDTO
+from app.domain.data.faq_dto import FaqDTO
 from app.domain.data.lecture_unit_dto import LectureUnitDTO
 from app.domain.status.stage_dto import StageDTO
 
@@ -13,3 +14,11 @@ class LecturesDeletionExecutionDto(PipelineExecutionDTO):
     initial_stages: Optional[List[StageDTO]] = Field(
         default=None, alias="initialStages"
     )
+
+
+class FaqDeletionExecutionDto(PipelineExecutionDTO):
+    faq: FaqDTO = Field(..., alias="pyrisFaqWebhookDTO")
+    settings: Optional[PipelineExecutionSettingsDTO]
+    initial_stages: Optional[List[StageDTO]] = Field(
+        default=None, alias="initialStages"
+    )
diff --git a/app/domain/ingestion/ingestion_pipeline_execution_dto.py b/app/domain/ingestion/ingestion_pipeline_execution_dto.py
index 12f3205f..d97fda45 100644
--- a/app/domain/ingestion/ingestion_pipeline_execution_dto.py
+++ b/app/domain/ingestion/ingestion_pipeline_execution_dto.py
@@ -3,6 +3,7 @@
 from pydantic import Field
 
 from app.domain import PipelineExecutionDTO, PipelineExecutionSettingsDTO
+from app.domain.data.faq_dto import FaqDTO
 from app.domain.data.lecture_unit_dto import LectureUnitDTO
 from app.domain.status.stage_dto import StageDTO
 
@@ -13,3 +14,11 @@ class IngestionPipelineExecutionDto(PipelineExecutionDTO):
     initial_stages: Optional[List[StageDTO]] = Field(
         default=None, alias="initialStages"
     )
+
+
+class FaqIngestionPipelineExecutionDto(PipelineExecutionDTO):
+    faq: FaqDTO = Field(..., alias="pyrisFaqWebhookDTO")
+    settings: Optional[PipelineExecutionSettingsDTO]
+    initial_stages: Optional[List[StageDTO]] = Field(
+        default=None, alias="initialStages"
+    )
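For orientation, a minimal sketch of how the camelCase aliases on `FaqDTO` resolve an incoming Artemis webhook body; the payload values here are invented for illustration:

```python
from app.domain.data.faq_dto import FaqDTO

# Hypothetical webhook body; keys use the camelCase aliases declared on FaqDTO.
payload = {
    "faqId": 42,
    "courseId": 7,
    "questionTitle": "When is the exam?",
    "questionAnswer": "The exam takes place in the last week of the lecture period.",
}

faq = FaqDTO.model_validate(payload)  # pydantic resolves fields via their aliases
assert faq.faq_id == 42
assert faq.course_name == ""  # courseName was omitted, so the default applies
```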
diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py
index f0904116..1d877d68 100644
--- a/app/pipeline/chat/course_chat_pipeline.py
+++ b/app/pipeline/chat/course_chat_pipeline.py
@@ -20,7 +20,7 @@
     InteractionSuggestionPipeline,
 )
 from .lecture_chat_pipeline import LectureChatPipeline
-from ..shared.citation_pipeline import CitationPipeline
+from ..shared.citation_pipeline import CitationPipeline, InformationType
 from ..shared.utils import generate_structured_tools_from_functions
 from ...common.message_converters import convert_iris_message_to_langchain_message
 from ...common.pyris_message import PyrisMessage
@@ -42,6 +42,8 @@
 )
 from ...domain import CourseChatPipelineExecutionDTO
 from app.common.PipelineEnum import PipelineEnum
+from ...retrieval.faq_retrieval import FaqRetrieval
+from ...retrieval.faq_retrieval_utils import should_allow_faq_tool, format_faqs
 from ...retrieval.lecture_retrieval import LectureRetrieval
 from ...vector_database.database import VectorDatabase
 from ...vector_database.lecture_schema import LectureSchema
@@ -81,6 +83,7 @@ class CourseChatPipeline(Pipeline):
     variant: str
     event: str | None
     retrieved_paragraphs: List[dict] = None
+    retrieved_faqs: List[dict] = None
 
     def __init__(
         self,
@@ -114,7 +117,8 @@ def __init__(
         self.callback = callback
 
         self.db = VectorDatabase()
-        self.retriever = LectureRetrieval(self.db.client)
+        self.lecture_retriever = LectureRetrieval(self.db.client)
+        self.faq_retriever = FaqRetrieval(self.db.client)
         self.suggestion_pipeline = InteractionSuggestionPipeline(variant="course")
         self.citation_pipeline = CitationPipeline()
 
@@ -282,7 +286,7 @@ def lecture_content_retrieval() -> str:
             Only use this once.
             """
             self.callback.in_progress("Retrieving lecture content ...")
-            self.retrieved_paragraphs = self.retriever(
+            self.retrieved_paragraphs = self.lecture_retriever(
                 chat_history=history,
                 student_query=query.contents[0].text_content,
                 result_limit=5,
@@ -302,6 +306,31 @@ def lecture_content_retrieval() -> str:
                 result += lct
             return result
 
+        def faq_content_retrieval() -> str:
+            """
+            Use this tool to retrieve information from indexed FAQs.
+            It is suitable when no other tool fits, the question is a common or frequently asked one,
+            or the question could be effectively answered by an FAQ. Also use this if the question is explicitly
+            organizational and course-related. An organizational question about the course might be
+            "What is the course structure?" or "How do I enroll?", or exam-related content like "When is the exam?".
+            The tool performs a RAG retrieval based on the chat history to find the most relevant FAQs.
+            Each FAQ follows this format: FAQ ID, FAQ Question, FAQ Answer.
+            Respond to the query concisely and solely using the answers from the relevant FAQs.
+            This tool should only be used once per query.
+            """
+            self.callback.in_progress("Retrieving faq content ...")
+            self.retrieved_faqs = self.faq_retriever(
+                chat_history=history,
+                student_query=query.contents[0].text_content,
+                result_limit=10,
+                course_name=dto.course.name,
+                course_id=dto.course.id,
+                base_url=dto.settings.artemis_base_url,
+            )
+
+            result = format_faqs(self.retrieved_faqs)
+            return result
+
         if dto.user.id % 3 < 2:
             iris_initial_system_prompt = tell_iris_initial_system_prompt
             begin_agent_prompt = tell_begin_agent_prompt
@@ -400,6 +429,9 @@ def lecture_content_retrieval() -> str:
             if self.should_allow_lecture_tool(dto.course.id):
                 tool_list.append(lecture_content_retrieval)
 
+            if should_allow_faq_tool(self.db, dto.course.id):
+                tool_list.append(faq_content_retrieval)
+
             tools = generate_structured_tools_from_functions(tool_list)
             # No idea why we need this extra contrary to exercise chat agent in this case, but solves the issue.
             params.update({"tools": tools})
@@ -420,9 +452,20 @@ def lecture_content_retrieval() -> str:
 
         if self.retrieved_paragraphs:
             self.callback.in_progress("Augmenting response ...")
-            out = self.citation_pipeline(self.retrieved_paragraphs, out)
+            out = self.citation_pipeline(
+                self.retrieved_paragraphs, out, InformationType.PARAGRAPHS
+            )
             self.tokens.extend(self.citation_pipeline.tokens)
 
+        if self.retrieved_faqs:
+            self.callback.in_progress("Augmenting response ...")
+            out = self.citation_pipeline(
+                self.retrieved_faqs,
+                out,
+                InformationType.FAQS,
+                base_url=dto.settings.artemis_base_url,
+            )
+            self.tokens.extend(self.citation_pipeline.tokens)
+
         self.callback.done("Response created", final_result=out, tokens=self.tokens)
 
         # try:
diff --git a/app/pipeline/chat/exercise_chat_agent_pipeline.py b/app/pipeline/chat/exercise_chat_agent_pipeline.py
index 920d7c64..0a46f10d 100644
--- a/app/pipeline/chat/exercise_chat_agent_pipeline.py
+++ b/app/pipeline/chat/exercise_chat_agent_pipeline.py
@@ -39,6 +39,8 @@
 from ...llm import CapabilityRequestHandler, RequirementList
 from ...llm import CompletionArguments
 from ...llm.langchain import IrisLangchainChatModel
+from ...retrieval.faq_retrieval import FaqRetrieval
+from ...retrieval.faq_retrieval_utils import should_allow_faq_tool, format_faqs
 from ...retrieval.lecture_retrieval import LectureRetrieval
 from ...vector_database.database import VectorDatabase
 from ...vector_database.lecture_schema import LectureSchema
@@ -103,6 +105,7 @@ class ExerciseChatAgentPipeline(Pipeline):
     prompt: ChatPromptTemplate
     variant: str
     event: str | None
+    retrieved_faqs: List[dict] = None
 
     def __init__(
         self,
@@ -136,7 +139,8 @@ def __init__(
         # Create the pipelines
         self.db = VectorDatabase()
         self.suggestion_pipeline = InteractionSuggestionPipeline(variant="exercise")
-        self.retriever = LectureRetrieval(self.db.client)
+        self.lecture_retriever = LectureRetrieval(self.db.client)
+        self.faq_retriever = FaqRetrieval(self.db.client)
         self.reranker_pipeline = RerankerPipeline()
         self.code_feedback_pipeline = CodeFeedbackPipeline()
         self.pipeline = self.llm_big | JsonOutputParser()
@@ -373,7 +377,7 @@ def lecture_content_retrieval() -> str:
             Only use this once.
             """
             self.callback.in_progress("Retrieving lecture content ...")
-            self.retrieved_paragraphs = self.retriever(
+            self.retrieved_paragraphs = self.lecture_retriever(
                 chat_history=chat_history,
                 student_query=query.contents[0].text_content,
                 result_limit=5,
@@ -393,6 +397,31 @@ def lecture_content_retrieval() -> str:
                 result += lct
             return result
 
+        def faq_content_retrieval() -> str:
+            """
+            Use this tool to retrieve information from indexed FAQs.
+            It is suitable when no other tool fits, the question is a common or frequently asked one,
+            or the question could be effectively answered by an FAQ. Also use this if the question is explicitly
+            organizational and course-related. An organizational question about the course might be
+            "What is the course structure?" or "How do I enroll?", or exam-related content like "When is the exam?".
+            The tool performs a RAG retrieval based on the chat history to find the most relevant FAQs.
+            Each FAQ follows this format: FAQ ID, FAQ Question, FAQ Answer.
+            Respond to the query concisely and solely using the answers from the relevant FAQs.
+            This tool should only be used once per query.
+            """
+            self.callback.in_progress("Retrieving faq content ...")
+            self.retrieved_faqs = self.faq_retriever(
+                chat_history=chat_history,
+                student_query=query.contents[0].text_content,
+                result_limit=10,
+                course_name=dto.course.name,
+                course_id=dto.course.id,
+                base_url=dto.settings.artemis_base_url,
+            )
+
+            result = format_faqs(self.retrieved_faqs)
+            return result
+
         iris_initial_system_prompt = tell_iris_initial_system_prompt
         chat_history_exists_prompt = tell_chat_history_exists_prompt
         no_chat_history_prompt = tell_no_chat_history_prompt
@@ -511,6 +540,10 @@ def lecture_content_retrieval() -> str:
             ]
             if self.should_allow_lecture_tool(dto.course.id):
                 tool_list.append(lecture_content_retrieval)
+
+            if should_allow_faq_tool(self.db, dto.course.id):
+                tool_list.append(faq_content_retrieval)
+
             tools = generate_structured_tools_from_functions(tool_list)
             agent = create_tool_calling_agent(
                 llm=self.llm_big, tools=tools, prompt=self.prompt
+ """ + self.callback.in_progress("Retrieving faq content ...") + self.retrieved_faqs = self.faq_retriever( + chat_history=chat_history, + student_query=query.contents[0].text_content, + result_limit=10, + course_name=dto.course.name, + course_id=dto.course.id, + base_url=dto.settings.artemis_base_url, + ) + + result = format_faqs(self.retrieved_faqs) + return result + iris_initial_system_prompt = tell_iris_initial_system_prompt chat_history_exists_prompt = tell_chat_history_exists_prompt no_chat_history_prompt = tell_no_chat_history_prompt @@ -511,6 +540,10 @@ def lecture_content_retrieval() -> str: ] if self.should_allow_lecture_tool(dto.course.id): tool_list.append(lecture_content_retrieval) + + if should_allow_faq_tool(self.db, dto.course.id): + tool_list.append(faq_content_retrieval) + tools = generate_structured_tools_from_functions(tool_list) agent = create_tool_calling_agent( llm=self.llm_big, tools=tools, prompt=self.prompt diff --git a/app/pipeline/faq_ingestion_pipeline.py b/app/pipeline/faq_ingestion_pipeline.py new file mode 100644 index 00000000..b9d43c86 --- /dev/null +++ b/app/pipeline/faq_ingestion_pipeline.py @@ -0,0 +1,140 @@ +from asyncio.log import logger +from typing import Optional, List, Dict +from langchain_core.output_parsers import StrOutputParser +from weaviate import WeaviateClient +from weaviate.classes.query import Filter +from . import Pipeline +from .lecture_ingestion_pipeline import batch_update_lock +from ..domain.data.faq_dto import FaqDTO + +from app.domain.ingestion.ingestion_pipeline_execution_dto import ( + FaqIngestionPipelineExecutionDto, +) +from ..llm.langchain import IrisLangchainChatModel +from ..vector_database.faq_schema import FaqSchema, init_faq_schema +from ..ingestion.abstract_ingestion import AbstractIngestion +from ..llm import ( + BasicRequestHandler, + CompletionArguments, + CapabilityRequestHandler, + RequirementList, +) +from ..web.status.faq_ingestion_status_callback import FaqIngestionStatus + +# we use the same lock as the lecture ingestion pipeline +batch_update_lock = batch_update_lock + + +class FaqIngestionPipeline(AbstractIngestion, Pipeline): + + def __init__( + self, + client: WeaviateClient, + dto: Optional[FaqIngestionPipelineExecutionDto], + callback: FaqIngestionStatus, + ): + super().__init__() + self.client = client + self.collection = init_faq_schema(client) + self.dto = dto + self.llm_embedding = BasicRequestHandler("embedding-small") + self.callback = callback + request_handler = CapabilityRequestHandler( + requirements=RequirementList( + gpt_version_equivalent=4.25, + context_length=16385, + privacy_compliance=True, + ) + ) + completion_args = CompletionArguments(temperature=0.2, max_tokens=2000) + self.llm = IrisLangchainChatModel( + request_handler=request_handler, completion_args=completion_args + ) + self.pipeline = self.llm | StrOutputParser() + self.tokens = [] + + def __call__(self) -> bool: + try: + self.callback.in_progress("Deleting old faq from database...") + self.delete_faq( + self.dto.faq.faq_id, + self.dto.faq.course_id, + ) + self.callback.done("Old faq removed") + self.callback.in_progress("Ingesting faqs into database...") + self.batch_update(self.dto.faq) + self.callback.done("Faq Ingestion Finished", tokens=self.tokens) + logger.info( + f"Faq ingestion pipeline finished Successfully for faq: {self.dto.faq.faq_id}" + ) + return True + except Exception as e: + logger.error(f"Error updating faq: {e}") + self.callback.error( + f"Failed to faq into the database: {e}", + exception=e, + 
diff --git a/app/pipeline/prompts/faq_citation_prompt.txt b/app/pipeline/prompts/faq_citation_prompt.txt
new file mode 100644
index 00000000..681bb05c
--- /dev/null
+++ b/app/pipeline/prompts/faq_citation_prompt.txt
@@ -0,0 +1,32 @@
+In the paragraphs below you are provided with an answer to a question. Underneath the answer you will find the faqs that the answer was based on.
+Add citations of the faqs to the answer. Cite the faqs in brackets after the sentence where the information is used in the answer.
+At the end of the answer, list each source with its corresponding number and provide the FAQ Question title, and a clickable link in this format: [1] "FAQ Question title".
+Do not include the actual faqs, only the citations at the end.
+Please do not use the FAQ ID as the citation number; instead, use the order of the citations in the answer.
+Only include the citations of the faqs that are relevant to the answer.
+If the answer actually does not contain any information from the faqs, please do not include any citations and return '!NONE!'.
+But if the answer contains information from the faqs, ALWAYS include citations.
+
+Here is an example of how to rewrite the answer with citations (ONLY ADD CITATIONS IF THE PROVIDED FAQS ARE RELEVANT TO THE ANSWER):
+"
+Lorem ipsum dolor sit amet, consectetur adipiscing elit [1]. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua [2].
+
+[1] FAQ question title 1.
+[2] FAQ question title 2.
+"
+
+Note: If there is no link available, please do not include the link in the citation.
+For example, if citation 1 does not have a link, it should look like this:
+[1] "FAQ question title"
+but if citation 2 has a link, it should look like this:
+[2] "FAQ question title" (link to the FAQ)
+
+Here are the answer and the faqs:
+
+Answer without citations:
+{Answer}
+
+Faqs with their FAQ ID, CourseId, FAQ Question title and FAQ Question Answer and the Link to the FAQ:
+{Paragraphs}
+
+Answer with citations (ensure an empty line between the message and the citations):
+If the answer actually does not contain any information from the faqs, please do not include any citations and return '!NONE!'.
diff --git a/app/pipeline/prompts/faq_retrieval_prompts.py b/app/pipeline/prompts/faq_retrieval_prompts.py
new file mode 100644
index 00000000..0b272f0a
--- /dev/null
+++ b/app/pipeline/prompts/faq_retrieval_prompts.py
@@ -0,0 +1,19 @@
+faq_retriever_initial_prompt = """
+You write good and performant vector database queries, in particular for Weaviate,
+from chat histories between an AI tutor and a student.
+The query should be designed to retrieve context information from indexed faqs so the AI tutor
+can use the context information to give a better answer. Apply accepted norms when querying vector databases.
+Query the database so it returns answers for the latest student query.
+A good vector database query is formulated in natural language, just like a student would ask a question.
+It is not an instruction to the database, but a question to the database.
+The chat history between the AI tutor and the student is provided to you in the next messages.
+"""
+
+write_hypothetical_answer_prompt = """
+A student has sent a query in the context of the course {course_name}.
+The chat history between the AI tutor and the student is provided to you in the next messages.
+Please provide a response in {course_language}.
+You should create a response that looks like an FAQ answer.
+Craft your response to closely reflect the style and content of typical university lecture materials.
+Do not exceed 350 words. Add keywords and phrases that are relevant to student intent.
+"""
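These prompt strings are consumed by `BaseRetrieval.rewrite_student_query` further down in this diff; as a hedged sketch of that wiring, with placeholder course values:

```python
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from app.pipeline.prompts.faq_retrieval_prompts import (
    faq_retriever_initial_prompt,
    write_hypothetical_answer_prompt,
)

# The initial prompt becomes the system message; the hypothetical-answer prompt is
# appended as a template whose {course_name}/{course_language} placeholders are
# filled at call time, mirroring rewrite_student_query below.
prompt = ChatPromptTemplate.from_messages([("system", faq_retriever_initial_prompt)])
prompt += SystemMessagePromptTemplate.from_template(write_hypothetical_answer_prompt)
messages = prompt.format_messages(
    course_name="iPraktikum",      # placeholder values for illustration
    course_language="english",
)
```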
+""" diff --git a/app/pipeline/shared/citation_pipeline.py b/app/pipeline/shared/citation_pipeline.py index fc71016b..d89efbd5 100644 --- a/app/pipeline/shared/citation_pipeline.py +++ b/app/pipeline/shared/citation_pipeline.py @@ -1,5 +1,6 @@ import os from asyncio.log import logger +from enum import Enum from typing import List, Union from langchain_core.output_parsers import StrOutputParser @@ -10,10 +11,16 @@ from app.common.PipelineEnum import PipelineEnum from app.llm.langchain import IrisLangchainChatModel from app.pipeline import Pipeline +from app.vector_database.faq_schema import FaqSchema from app.vector_database.lecture_schema import LectureSchema +class InformationType(str, Enum): + PARAGRAPHS = "PARAGRAPHS" + FAQS = "FAQS" + + class CitationPipeline(Pipeline): """A generic reranker pipeline that can be used to rerank a list of documents based on a question""" @@ -37,7 +44,12 @@ def __init__(self): dirname = os.path.dirname(__file__) prompt_file_path = os.path.join(dirname, "..", "prompts", "citation_prompt.txt") with open(prompt_file_path, "r") as file: - self.prompt_str = file.read() + self.lecture_prompt_str = file.read() + prompt_file_path = os.path.join( + dirname, "..", "prompts", "faq_citation_prompt.txt" + ) + with open(prompt_file_path, "r") as file: + self.faq_prompt_str = file.read() self.pipeline = self.llm | StrOutputParser() self.tokens = [] @@ -47,7 +59,7 @@ def __repr__(self): def __str__(self): return f"{self.__class__.__name__}(llm={self.llm})" - def create_formatted_string(self, paragraphs): + def create_formatted_lecture_string(self, paragraphs): """ Create a formatted string from the data """ @@ -65,19 +77,47 @@ def create_formatted_string(self, paragraphs): return formatted_string.replace("{", "{{").replace("}", "}}") + def create_formatted_faq_string(self, faqs, base_url): + """ + Create a formatted string from the data + """ + formatted_string = "" + for i, faq in enumerate(faqs): + faq = "FAQ ID {}, CourseId {} , FAQ Question title {} and FAQ Question Answer {} and FAQ link {}".format( + faq.get(FaqSchema.FAQ_ID.value), + faq.get(FaqSchema.COURSE_ID.value), + faq.get(FaqSchema.QUESTION_TITLE.value), + faq.get(FaqSchema.QUESTION_ANSWER.value), + f"{base_url}/courses/{faq.get(FaqSchema.COURSE_ID.value)}/faq/?faqId={faq.get(FaqSchema.FAQ_ID.value)}", + ) + formatted_string += faq + + return formatted_string.replace("{", "{{").replace("}", "}}") + def __call__( self, - paragraphs: Union[List[dict], List[str]], + information: Union[List[dict], List[str]], answer: str, + information_type: InformationType = InformationType.PARAGRAPHS, **kwargs, ) -> str: """ Runs the pipeline - :param paragraphs: List of paragraphs which can be list of dicts or list of strings + :param information: List of info as list of dicts or strings to augment response :param query: The query + :param information_type: The type of information provided. 
diff --git a/app/retrieval/basic_retrieval.py b/app/retrieval/basic_retrieval.py
new file mode 100644
index 00000000..132e7a40
--- /dev/null
+++ b/app/retrieval/basic_retrieval.py
@@ -0,0 +1,298 @@
+from abc import abstractmethod, ABC
+from typing import List, Optional
+from langsmith import traceable
+from weaviate import WeaviateClient
+from weaviate.classes.query import Filter
+from app.common.token_usage_dto import TokenUsageDTO
+from app.common.PipelineEnum import PipelineEnum
+from ..common.message_converters import convert_iris_message_to_langchain_message
+from ..common.pyris_message import PyrisMessage
+from ..llm.langchain import IrisLangchainChatModel
+from ..pipeline import Pipeline
+from app.llm import (
+    BasicRequestHandler,
+    CompletionArguments,
+    CapabilityRequestHandler,
+    RequirementList,
+)
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
+import concurrent.futures
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def merge_retrieved_chunks(
+    basic_retrieved_faq_chunks, hyde_retrieved_faq_chunks
+) -> List[dict]:
+    """
+    Merge the retrieved chunks from the basic and hyde retrieval methods. This function ensures that for any
+    duplicate IDs, the properties from hyde_retrieved_faq_chunks will overwrite those from
+    basic_retrieved_faq_chunks.
+    """
+    merged_chunks = {}
+    for chunk in basic_retrieved_faq_chunks:
+        merged_chunks[chunk["id"]] = chunk["properties"]
+
+    for chunk in hyde_retrieved_faq_chunks:
+        merged_chunks[chunk["id"]] = chunk["properties"]
+
+    return [properties for uuid, properties in merged_chunks.items()]
+
+
+def _add_last_four_messages_to_prompt(
+    prompt,
+    chat_history: List[PyrisMessage],
+):
+    """
+    Adds the last four messages of the chat history to the prompt
+    :param prompt: The prompt to extend
+    :param chat_history: The chat history
+    :return: The prompt with the chat history appended
+    """
+    if chat_history is not None and len(chat_history) > 0:
+        num_messages_to_take = min(len(chat_history), 4)
+        last_messages = chat_history[-num_messages_to_take:]
+        chat_history_messages = [
+            convert_iris_message_to_langchain_message(message)
+            for message in last_messages
+        ]
+        prompt += chat_history_messages
+    return prompt
+
+
+class BaseRetrieval(Pipeline, ABC):
+    """
+    Base class for retrieval pipelines.
+    """
+
+    tokens: List[TokenUsageDTO]
+
+    @abstractmethod
+    def __call__(self, *args, **kwargs):
+        """Must be overridden by the concrete implementation"""
+        pass
+
+    def __init__(self, client: WeaviateClient, schema_init_func, **kwargs):
+        super().__init__(
+            implementation_id=kwargs.get("implementation_id", "base_retrieval_pipeline")
+        )
+        request_handler = CapabilityRequestHandler(
+            requirements=RequirementList(
+                gpt_version_equivalent=4.25,
+                context_length=16385,
+                privacy_compliance=True,
+            )
+        )
+        completion_args = CompletionArguments(temperature=0, max_tokens=2000)
+        self.llm = IrisLangchainChatModel(
+            request_handler=request_handler, completion_args=completion_args
+        )
+        self.llm_embedding = BasicRequestHandler("embedding-small")
+        self.pipeline = self.llm | StrOutputParser()
+        self.collection = schema_init_func(client)
+        self.tokens = []
+
+    @traceable(name="Retrieval: Question Assessment")
+    def assess_question(
+        self,
+        chat_history: list[PyrisMessage],
+        student_query: str,
+        assessment_prompt: str,
+        assessment_prompt_final: str,
+    ) -> bool:
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", assessment_prompt),
+            ]
+        )
+        prompt = _add_last_four_messages_to_prompt(prompt, chat_history)
+        prompt += ChatPromptTemplate.from_messages(
+            [
+                ("user", student_query),
+            ]
+        )
+        prompt += ChatPromptTemplate.from_messages(
+            [
+                ("system", assessment_prompt_final),
+            ]
+        )
+
+        try:
+            response = (prompt | self.pipeline).invoke({})
+            logger.info(f"Response from assessment pipeline: {response}")
+            return response == "YES"
+        except Exception as e:
+            raise e
+
+    @traceable(name="Retrieval: Rewrite Student Query")
+    def rewrite_student_query(
+        self,
+        chat_history: list[PyrisMessage],
+        student_query: str,
+        course_language: str,
+        course_name: str,
+        initial_prompt: str,
+        rewrite_prompt: str,
+        pipeline_enum: PipelineEnum,
+    ) -> str:
+        """
+        Rewrite the student query.
+        """
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", initial_prompt),
+            ]
+        )
+        prompt = _add_last_four_messages_to_prompt(prompt, chat_history)
+        prompt += SystemMessagePromptTemplate.from_template(rewrite_prompt)
+        prompt_val = prompt.format_messages(
+            course_language=course_language,
+            course_name=course_name,
+            student_query=student_query,
+        )
+        prompt = ChatPromptTemplate.from_messages(prompt_val)
+        try:
+            response = (prompt | self.pipeline).invoke({})
+            token_usage = self.llm.tokens
+            token_usage.pipeline = pipeline_enum
+            self.tokens.append(self.llm.tokens)
+            return response
+        except Exception as e:
+            raise e
+
+    @traceable(name="Retrieval: Search in DB")
+    def search_in_db(
+        self,
+        query: str,
+        hybrid_factor: float,
+        result_limit: int,
+        schema_properties: List[str],
+        course_id: Optional[int] = None,
+        base_url: Optional[str] = None,
+        course_id_property: str = "course_id",
+        base_url_property: str = "base_url",
+    ):
+        """
+        Search the database for the given query.
+        """
+        logger.info(f"Searching in the database for query: {query}")
+        filter_weaviate = None
+
+        if course_id:
+            filter_weaviate = Filter.by_property(course_id_property).equal(course_id)
+            if base_url:
+                filter_weaviate &= Filter.by_property(base_url_property).equal(base_url)
+
+        vec = self.llm_embedding.embed(query)
+        return self.collection.query.hybrid(
+            query=query,
+            alpha=hybrid_factor,
+            vector=vec,
+            return_properties=schema_properties,
+            limit=result_limit,
+            filters=filter_weaviate,
+        )
+
+    @traceable(name="Retrieval: Run Parallel Rewrite Tasks")
+    def run_parallel_rewrite_tasks(
+        self,
+        chat_history: list[PyrisMessage],
+        student_query: str,
+        result_limit: int,
+        course_language: str,
+        initial_prompt: str,
+        rewrite_prompt: str,
+        hypothetical_answer_prompt: str,
+        pipeline_enum: PipelineEnum,
+        course_name: Optional[str] = None,
+        course_id: Optional[int] = None,
+        base_url: Optional[str] = None,
+        problem_statement: Optional[str] = None,
+        exercise_title: Optional[str] = None,
+    ):
+        """
+        Run the rewrite tasks in parallel.
+        """
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            rewritten_query_future = executor.submit(
+                self.rewrite_student_query,
+                chat_history,
+                student_query,
+                course_language,
+                course_name,
+                initial_prompt,
+                rewrite_prompt,
+                pipeline_enum,
+            )
+            hypothetical_answer_query_future = executor.submit(
+                self.rewrite_student_query,
+                chat_history,
+                student_query,
+                course_language,
+                course_name,
+                initial_prompt,
+                hypothetical_answer_prompt,
+                pipeline_enum,
+            )
+
+            rewritten_query = rewritten_query_future.result()
+            hypothetical_answer_query = hypothetical_answer_query_future.result()
+
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            response_future = executor.submit(
+                self.search_in_db,
+                query=rewritten_query,
+                hybrid_factor=0.9,
+                result_limit=result_limit,
+                schema_properties=self.get_schema_properties(),
+                course_id=course_id,
+                base_url=base_url,
+            )
+            response_hyde_future = executor.submit(
+                self.search_in_db,
+                query=hypothetical_answer_query,
+                hybrid_factor=0.9,
+                result_limit=result_limit,
+                schema_properties=self.get_schema_properties(),
+                course_id=course_id,
+                base_url=base_url,
+            )
+
+            response = response_future.result()
+            response_hyde = response_hyde_future.result()
+
+        return response, response_hyde
+
+    @abstractmethod
+    def get_schema_properties(self) -> List[str]:
+        """
+        Abstract method to be implemented by subclasses to return the schema properties.
+        """
+        raise NotImplementedError
+
+    def fetch_course_language(
+        self, course_id: int, course_language_property: str = "course_language"
+    ) -> str:
+        """
+        Fetch the language of the course based on the course ID.
+        If no specific language is set, it defaults to English.
+        """
+        course_language = "english"
+
+        if course_id:
+            result = self.collection.query.fetch_objects(
+                filters=Filter.by_property("course_id").equal(course_id),
+                limit=1,
+                return_properties=[course_language_property],
+            )
+
+            if result.objects:
+                fetched_language = result.objects[0].properties.get(
+                    course_language_property
+                )
+                if fetched_language:
+                    course_language = fetched_language
+
+        return course_language
+ """ + course_language = "english" + + if course_id: + result = self.collection.query.fetch_objects( + filters=Filter.by_property("course_id").equal(course_id), + limit=1, + return_properties=[course_language_property], + ) + + if result.objects: + fetched_language = result.objects[0].properties.get( + course_language_property + ) + if fetched_language: + course_language = fetched_language + + return course_language diff --git a/app/retrieval/faq_retrieval.py b/app/retrieval/faq_retrieval.py new file mode 100644 index 00000000..c2f9a455 --- /dev/null +++ b/app/retrieval/faq_retrieval.py @@ -0,0 +1,69 @@ +import logging +from typing import List +from langsmith import traceable +from weaviate import WeaviateClient +from app.common.PipelineEnum import PipelineEnum +from .basic_retrieval import BaseRetrieval, merge_retrieved_chunks +from ..common.pyris_message import PyrisMessage +from ..pipeline.prompts.faq_retrieval_prompts import ( + faq_retriever_initial_prompt, + write_hypothetical_answer_prompt, +) +from ..pipeline.prompts.lecture_retrieval_prompts import ( + rewrite_student_query_prompt, +) +from ..vector_database.faq_schema import FaqSchema, init_faq_schema + +logger = logging.getLogger(__name__) + + +class FaqRetrieval(BaseRetrieval): + def __init__(self, client: WeaviateClient, **kwargs): + super().__init__( + client, init_faq_schema, implementation_id="faq_retrieval_pipeline" + ) + + def get_schema_properties(self) -> List[str]: + return [ + FaqSchema.COURSE_ID.value, + FaqSchema.FAQ_ID.value, + FaqSchema.QUESTION_TITLE.value, + FaqSchema.QUESTION_ANSWER.value, + ] + + @traceable(name="Full Faq Retrieval") + def __call__( + self, + chat_history: list[PyrisMessage], + student_query: str, + result_limit: int, + course_name: str = None, + course_id: int = None, + problem_statement: str = None, + exercise_title: str = None, + base_url: str = None, + ) -> List[dict]: + course_language = self.fetch_course_language(course_id) + + response, response_hyde = self.run_parallel_rewrite_tasks( + chat_history=chat_history, + student_query=student_query, + result_limit=result_limit, + course_language=course_language, + initial_prompt=faq_retriever_initial_prompt, + rewrite_prompt=rewrite_student_query_prompt, + hypothetical_answer_prompt=write_hypothetical_answer_prompt, + pipeline_enum=PipelineEnum.IRIS_FAQ_RETRIEVAL_PIPELINE, + course_name=course_name, + course_id=course_id, + ) + + basic_retrieved_faqs: list[dict[str, dict]] = [ + {"id": obj.uuid.int, "properties": obj.properties} + for obj in response.objects + ] + hyde_retrieved_faqs: list[dict[str, dict]] = [ + {"id": obj.uuid.int, "properties": obj.properties} + for obj in response_hyde.objects + ] + return merge_retrieved_chunks(basic_retrieved_faqs, hyde_retrieved_faqs) diff --git a/app/retrieval/faq_retrieval_utils.py b/app/retrieval/faq_retrieval_utils.py new file mode 100644 index 00000000..2a65873b --- /dev/null +++ b/app/retrieval/faq_retrieval_utils.py @@ -0,0 +1,40 @@ +from weaviate.collections.classes.filters import Filter +from app.vector_database.database import VectorDatabase +from app.vector_database.faq_schema import FaqSchema + + +def should_allow_faq_tool(db: VectorDatabase, course_id: int) -> bool: + """ + Checks if there are indexed faqs for the given course + + :param db: The vector database on which the faqs are indexed + :param course_id: The course ID + :return: True if there are indexed faqs for the course, False otherwise + """ + if course_id: + # Fetch the first object that matches the course ID with the language 
diff --git a/app/vector_database/database.py b/app/vector_database/database.py
index e4dedcd0..5c3ed7cf 100644
--- a/app/vector_database/database.py
+++ b/app/vector_database/database.py
@@ -1,5 +1,7 @@
 import logging
 import weaviate
+
+from .faq_schema import init_faq_schema
 from .lecture_schema import init_lecture_schema
 from weaviate.classes.query import Filter
 from app.config import settings
@@ -27,6 +29,7 @@ def __init__(self):
             logger.info("Weaviate client initialized")
         self.client = VectorDatabase._client_instance
         self.lectures = init_lecture_schema(self.client)
+        self.faqs = init_faq_schema(self.client)
 
     def delete_collection(self, collection_name):
         """
diff --git a/app/vector_database/faq_schema.py b/app/vector_database/faq_schema.py
new file mode 100644
index 00000000..abf97023
--- /dev/null
+++ b/app/vector_database/faq_schema.py
@@ -0,0 +1,93 @@
+from enum import Enum
+
+from weaviate.classes.config import Property
+from weaviate import WeaviateClient
+from weaviate.collections import Collection
+from weaviate.collections.classes.config import Configure, VectorDistances, DataType
+
+
+class FaqSchema(Enum):
+    """
+    Schema for the faqs
+    """
+
+    COLLECTION_NAME = "Faqs"
+    COURSE_NAME = "course_name"
+    COURSE_DESCRIPTION = "course_description"
+    COURSE_LANGUAGE = "course_language"
+    COURSE_ID = "course_id"
+    FAQ_ID = "faq_id"
+    QUESTION_TITLE = "question_title"
+    QUESTION_ANSWER = "question_answer"
+
+
+def init_faq_schema(client: WeaviateClient) -> Collection:
+    """
+    Initialize the schema for the faqs
+    """
+    if client.collections.exists(FaqSchema.COLLECTION_NAME.value):
+        collection = client.collections.get(FaqSchema.COLLECTION_NAME.value)
+        # Check for the 'course_language' property and add it if it is missing
+        if not any(
+            property.name == FaqSchema.COURSE_LANGUAGE.value
+            for property in collection.config.get(simple=False).properties
+        ):
+            collection.config.add_property(
+                Property(
+                    name=FaqSchema.COURSE_LANGUAGE.value,
+                    description="The language of the course",
+                    data_type=DataType.TEXT,
+                    index_searchable=False,
+                )
+            )
+        return collection
+
+    return client.collections.create(
+        name=FaqSchema.COLLECTION_NAME.value,
+        vectorizer_config=Configure.Vectorizer.none(),
+        vector_index_config=Configure.VectorIndex.hnsw(
+            distance_metric=VectorDistances.COSINE
+        ),
+        properties=[
+            Property(
+                name=FaqSchema.COURSE_ID.value,
+                description="The ID of the course",
+                data_type=DataType.INT,
+                index_searchable=False,
+            ),
+            Property(
+                name=FaqSchema.COURSE_NAME.value,
+                description="The name of the course",
+                data_type=DataType.TEXT,
+                index_searchable=False,
+            ),
+            Property(
+                name=FaqSchema.COURSE_DESCRIPTION.value,
+                description="The description of the course",
+                data_type=DataType.TEXT,
+                index_searchable=False,
+            ),
+            Property(
+                name=FaqSchema.COURSE_LANGUAGE.value,
+                description="The language of the course",
+                data_type=DataType.TEXT,
+                index_searchable=False,
+            ),
+            Property(
+                name=FaqSchema.FAQ_ID.value,
+                description="The ID of the faq",
+                data_type=DataType.INT,
+                index_searchable=False,
+            ),
+            Property(
+                name=FaqSchema.QUESTION_TITLE.value,
+                description="The title of the faq",
+                data_type=DataType.TEXT,
+            ),
+            Property(
+                name=FaqSchema.QUESTION_ANSWER.value,
+                description="The answer of the faq",
+                data_type=DataType.TEXT,
+            ),
+        ],
+    )
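For illustration, a hedged sketch of how this collection ends up being queried (mirroring `BaseRetrieval.search_in_db` above); the query text and course ID are placeholders, and because the collection uses `Vectorizer.none()`, the vector must be computed client-side:

```python
from weaviate.classes.query import Filter
from app.llm import BasicRequestHandler
from app.vector_database.database import VectorDatabase
from app.vector_database.faq_schema import FaqSchema

db = VectorDatabase()
vec = BasicRequestHandler("embedding-small").embed("When is the exam?")

results = db.faqs.query.hybrid(
    query="When is the exam?",  # natural-language query, as the retrieval prompts require
    alpha=0.9,                  # hybrid factor used by run_parallel_rewrite_tasks
    vector=vec,                 # client-side embedding; the collection has no vectorizer
    limit=10,
    filters=Filter.by_property(FaqSchema.COURSE_ID.value).equal(7),  # placeholder course
    return_properties=[
        FaqSchema.QUESTION_TITLE.value,
        FaqSchema.QUESTION_ANSWER.value,
    ],
)
for obj in results.objects:
    print(obj.properties[FaqSchema.QUESTION_TITLE.value])
```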
diff --git a/app/web/routers/pipelines.py b/app/web/routers/pipelines.py
index d3fcc216..b7c87fac 100644
--- a/app/web/routers/pipelines.py
+++ b/app/web/routers/pipelines.py
@@ -430,5 +430,15 @@ def get_pipeline(feature: str):
                     description="Default ChatGPT wrapper variant.",
                 )
             ]
+
+        case "FAQ_INGESTION":
+            return [
+                FeatureDTO(
+                    id="default",
+                    name="Default Variant",
+                    description="Default faq ingestion variant.",
+                )
+            ]
+
         case _:
             return Response(status_code=status.HTTP_400_BAD_REQUEST)
diff --git a/app/web/routers/webhooks.py b/app/web/routers/webhooks.py
index 739a9bbb..3b49f3d1 100644
--- a/app/web/routers/webhooks.py
+++ b/app/web/routers/webhooks.py
@@ -8,12 +8,16 @@
 from app.dependencies import TokenValidator
 from app.domain.ingestion.ingestion_pipeline_execution_dto import (
     IngestionPipelineExecutionDto,
+    FaqIngestionPipelineExecutionDto,
 )
+from ..status.faq_ingestion_status_callback import FaqIngestionStatus
 from ..status.ingestion_status_callback import IngestionStatusCallback
 from ..status.lecture_deletion_status_callback import LecturesDeletionStatusCallback
 from ...domain.ingestion.deletionPipelineExecutionDto import (
     LecturesDeletionExecutionDto,
+    FaqDeletionExecutionDto,
 )
+from ...pipeline.faq_ingestion_pipeline import FaqIngestionPipeline
 from ...pipeline.lecture_ingestion_pipeline import LectureIngestionPipeline
 from ...vector_database.database import VectorDatabase
 
@@ -40,6 +44,7 @@ def run_lecture_update_pipeline_worker(dto: IngestionPipelineExecutionDto):
                 client=client, dto=dto, callback=callback
             )
             pipeline()
+
         except Exception as e:
             logger.error(f"Error Ingestion pipeline: {e}")
             logger.error(traceback.format_exc())
@@ -67,6 +72,57 @@ def run_lecture_deletion_pipeline_worker(dto: LecturesDeletionExecutionDto):
             logger.error(traceback.format_exc())
 
 
+def run_faq_update_pipeline_worker(dto: FaqIngestionPipelineExecutionDto):
+    """
+    Run the faq ingestion pipeline in a separate thread
+    """
+    with semaphore:
+        try:
+            callback = FaqIngestionStatus(
+                run_id=dto.settings.authentication_token,
+                base_url=dto.settings.artemis_base_url,
+                initial_stages=dto.initial_stages,
+                faq_id=dto.faq.faq_id,
+            )
+            db = VectorDatabase()
+            client = db.get_client()
+            pipeline = FaqIngestionPipeline(client=client, dto=dto, callback=callback)
+            pipeline()
+
+        except Exception as e:
+            logger.error(f"Error Faq Ingestion pipeline: {e}")
+            logger.error(traceback.format_exc())
+            capture_exception(e)
+        finally:
+            semaphore.release()
+
+
+def run_faq_delete_pipeline_worker(dto: FaqDeletionExecutionDto):
+    """
+    Run the faq deletion in a separate thread
+    """
+    with semaphore:
+        try:
+            callback = FaqIngestionStatus(
+                run_id=dto.settings.authentication_token,
+                base_url=dto.settings.artemis_base_url,
+                initial_stages=dto.initial_stages,
+                faq_id=dto.faq.faq_id,
+            )
+            db = VectorDatabase()
+            client = db.get_client()
+            # This is where the faq is removed from the database
+            pipeline = FaqIngestionPipeline(client=client, dto=None, callback=callback)
+            pipeline.delete_faq(dto.faq.faq_id, dto.faq.course_id)
+
+        except Exception as e:
+            logger.error(f"Error Faq deletion pipeline: {e}")
+            logger.error(traceback.format_exc())
+            capture_exception(e)
+        finally:
+            semaphore.release()
+
+
 @router.post(
     "/lectures/fullIngestion",
     status_code=status.HTTP_202_ACCEPTED,
@@ -91,3 +147,31 @@ def lecture_deletion_webhook(dto: LecturesDeletionExecutionDto):
     """
     thread = Thread(target=run_lecture_deletion_pipeline_worker, args=(dto,))
     thread.start()
+
+
+@router.post(
+    "/faqs",
+    status_code=status.HTTP_202_ACCEPTED,
+    dependencies=[Depends(TokenValidator())],
+)
+def faq_ingestion_webhook(dto: FaqIngestionPipelineExecutionDto):
+    """
+    Webhook endpoint to trigger the faq ingestion pipeline
+    """
+    thread = Thread(target=run_faq_update_pipeline_worker, args=(dto,))
+    thread.start()
+    return
+
+
+@router.post(
+    "/faqs/delete",
+    status_code=status.HTTP_202_ACCEPTED,
+    dependencies=[Depends(TokenValidator())],
+)
+def faq_deletion_webhook(dto: FaqDeletionExecutionDto):
+    """
+    Webhook endpoint to trigger the faq deletion pipeline
+    """
+    thread = Thread(target=run_faq_delete_pipeline_worker, args=(dto,))
+    thread.start()
+    return
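A hedged sketch of triggering the new ingestion webhook from the Artemis side. The router prefix is not shown in this diff and the `settings` field names are assumed aliases of `PipelineExecutionSettingsDTO` (inferred from `dto.settings.authentication_token` / `artemis_base_url` above), so adjust both to the actual deployment:

```python
import requests

PYRIS_BASE = "https://pyris.example.org/api/v1/webhooks"  # assumed mount point

payload = {
    "pyrisFaqWebhookDTO": {
        "faqId": 42,
        "courseId": 7,
        "questionTitle": "When is the exam?",
        "questionAnswer": "In the last week of the lecture period.",
    },
    "settings": {  # field aliases assumed, see note above
        "authenticationToken": "RUN_TOKEN",
        "artemisBaseUrl": "https://artemis.example.org",
    },
    "initialStages": [],
}

resp = requests.post(
    f"{PYRIS_BASE}/faqs",
    json=payload,
    headers={"Authorization": "SECRET_TOKEN"},  # placeholder; TokenValidator's scheme is not shown
)
assert resp.status_code == 202  # accepted; the work continues on a background thread
```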
diff --git a/app/web/status/faq_ingestion_status_callback.py b/app/web/status/faq_ingestion_status_callback.py
new file mode 100644
index 00000000..8642fb5e
--- /dev/null
+++ b/app/web/status/faq_ingestion_status_callback.py
@@ -0,0 +1,47 @@
+from typing import List
+
+from .status_update import StatusCallback
+from ...domain.ingestion.ingestion_status_update_dto import IngestionStatusUpdateDTO
+from ...domain.status.stage_state_dto import StageStateEnum
+from ...domain.status.stage_dto import StageDTO
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class FaqIngestionStatus(StatusCallback):
+    """
+    Callback class for updating the status of a Faq ingestion Pipeline run.
+    """
+
+    def __init__(
+        self,
+        run_id: str,
+        base_url: str,
+        initial_stages: List[StageDTO] = None,
+        faq_id: int = None,
+    ):
+        url = (
+            f"{base_url}/api/public/pyris/webhooks/ingestion/faqs/runs/{run_id}/status"
+        )
+
+        current_stage_index = len(initial_stages) if initial_stages else 0
+        stages = initial_stages or []
+        stages += [
+            StageDTO(
+                weight=10, state=StageStateEnum.NOT_STARTED, name="Old faq removal"
+            ),
+            StageDTO(
+                weight=30,
+                state=StageStateEnum.NOT_STARTED,
+                name="Faq Interpretation",
+            ),
+            StageDTO(
+                weight=60,
+                state=StageStateEnum.NOT_STARTED,
+                name="Faq ingestion",
+            ),
+        ]
+        status = IngestionStatusUpdateDTO(stages=stages, id=faq_id)
+        stage = stages[current_stage_index]
+        super().__init__(url, run_id, status, stage, current_stage_index)