diff --git a/main.py b/main.py index a2d903c..3c24602 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import os +import re from dotenv import load_dotenv from slack_bolt import App from slack_bolt.adapter.fastapi import SlackRequestHandler @@ -7,6 +8,12 @@ import requests import sqlite3 from pydantic import BaseModel +from fastapi import Depends +from slack_bolt.context.ack import Ack +from slack_bolt.request import BoltRequest +from slack_bolt.response import BoltResponse +import urllib.parse +import json # Load environment variables load_dotenv() @@ -47,6 +54,74 @@ def init_db(self): ) """ ) + c.execute( + """ + CREATE TABLE IF NOT EXISTS thread_brain_mapping ( + thread_ts TEXT PRIMARY KEY, + brain_id TEXT + ) + """ + ) + c.execute( + """ + CREATE TABLE IF NOT EXISTS thread_question_mapping ( + thread_ts TEXT PRIMARY KEY, + question TEXT + ) + """ + ) + ### Create a table for saving interactive message id with thread_ts as primary key + c.execute( + """ + CREATE TABLE IF NOT EXISTS thread_interactive_mapping ( + interactive_message_id TEXT PRIMARY KEY, + thread_ts TEXT + + ) + """ + ) + conn.commit() + conn.close() + + def get_all_interactive_message_ids_for_thread(self, thread_ts): + conn = sqlite3.connect(self.config.db_name) + c = conn.cursor() + c.execute( + "SELECT interactive_message_id FROM thread_interactive_mapping WHERE thread_ts = ?", + (thread_ts,), + ) + result = c.fetchall() + conn.close() + return result + + def set_interactive_message_id(self, thread_ts, interactive_message_id): + conn = sqlite3.connect(self.config.db_name) + c = conn.cursor() + c.execute( + "INSERT OR REPLACE INTO thread_interactive_mapping VALUES (?, ?)", + (thread_ts, interactive_message_id), + ) + conn.commit() + conn.close() + + def get_question(self, thread_ts): + conn = sqlite3.connect(self.config.db_name) + c = conn.cursor() + c.execute( + "SELECT question FROM thread_question_mapping WHERE thread_ts = ?", + (thread_ts,), + ) + result = c.fetchone() + conn.close() + return result[0] if result else None + + def set_question(self, thread_ts, question): + conn = sqlite3.connect(self.config.db_name) + c = conn.cursor() + c.execute( + "INSERT OR REPLACE INTO thread_question_mapping VALUES (?, ?)", + (thread_ts, question), + ) conn.commit() conn.close() @@ -74,17 +149,44 @@ def set_chat_id(self, thread_ts, chat_id): conn.commit() conn.close() - def make_quivr_api_request(self, method, endpoint, data=None): + def get_brain_id(self, thread_ts): + conn = sqlite3.connect(self.config.db_name) + c = conn.cursor() + c.execute( + "SELECT brain_id FROM thread_brain_mapping WHERE thread_ts = ?", + (thread_ts,), + ) + result = c.fetchone() + conn.close() + return result[0] if result else None + + def set_brain_id(self, thread_ts, brain_id): + conn = sqlite3.connect(self.config.db_name) + logger.info(f"Setting brain ID: {brain_id} for thread: {thread_ts}") + c = conn.cursor() + c.execute( + "INSERT OR REPLACE INTO thread_brain_mapping VALUES (?, ?)", + (thread_ts, brain_id), + ) + conn.commit() + conn.close() + + def make_quivr_api_request(self, method, endpoint, data=None, params=None): + logger.info(f"Making Quivr API request: {method} {endpoint}") + logger.info(f"Data: {data}") + logger.info(f"Params: {params}") headers = { "accept": "application/json", "Authorization": f"Bearer {self.config.quivr_api_key}", "Content-Type": "application/json", } url = f"{self.config.quivr_api_base_url}{endpoint}" - response = requests.request(method, url, headers=headers, json=data) + response = requests.request( + method, url, headers=headers, json=data, params=params + ) return response.json() - def update_home_tab(self, client, event, logger): + def update_home_tab(self, client, event): try: client.views_publish( user_id=event["user"], @@ -122,49 +224,211 @@ def update_home_tab(self, client, event, logger): except Exception as e: logger.error(f"Error publishing home tab: {e}") - def handle_app_mentions(self, body, say, logger, client): - logger.info(body) + def handle_app_mentions(self, body, say, client): + print("Coming here") + self.app.client.reactions_add( + channel=body["event"]["channel"], + name="brain", + timestamp=body["event"]["ts"], + ) + logger.info(f"Hanlding app mention") + logger.info(f"Question is {body['event']['text']}") + logger.info(f"Thread is {body['event']['ts']}") + self.set_question(body["event"]["ts"], body["event"]["text"]) + + # Check if brain is already set for this thread + # Check if thread_ts is in the event payload, if not, use the ts from the event + thread_ts = body["event"].get("thread_ts", body["event"]["ts"]) + brain_id = self.get_brain_id(thread_ts) + logger.info(f"Brain ID: {brain_id}") + logger.info(f"2 - Thread TS : {thread_ts}, Brain ID: {brain_id}") + + if not brain_id: + + brains_response = self.make_quivr_api_request("GET", "/brains/") + brains = brains_response.get("brains", []) + # limit to 24 brains + brains = brains[:24] - client.reactions_add( + if not brains: + say("No brains found. Please create a brain first.") + return + + # Create a button for each brain and an 'Any brain' button + brain_buttons = [ + { + "type": "button", + "text": {"type": "plain_text", "text": brain["name"]}, + "action_id": f"brain_{brain['id']}", + } + for brain in brains + ] + ## Action ID should be a null UUID with only zeros + any_brain = [ + { + "type": "button", + "text": {"type": "plain_text", "text": "Any brain"}, + "action_id": "00000000-0000-0000-0000-000000000000", + } + ] + + # Send a message with the buttons + response = self.app.client.chat_postMessage( + channel=body["event"]["channel"], + text="Please select a brain:", + blocks=[ + { + "type": "section", + "text": {"type": "mrkdwn", "text": "Available brains:"}, + }, + { + "type": "actions", + "elements": brain_buttons, + }, + ], + thread_ts=body["event"]["ts"], + ) + logger.info(f"Adding interactive message to thread {thread_ts}") + self.set_interactive_message_id(response["ts"], thread_ts) + + response2 = self.app.client.chat_postMessage( + channel=body["event"]["channel"], + text="Or let me choose for you:", + blocks=[ + { + "type": "actions", + "elements": any_brain, + }, + ], + thread_ts=body["event"]["ts"], + ) + + logger.info(f"Adding interactive message to thread {thread_ts}") + self.set_interactive_message_id(response2["ts"], thread_ts) + + else: + self.ask_question( + body, brain_id, body["event"]["ts"], question=body["event"]["text"] + ) + self.app.client.reactions_remove( channel=body["event"]["channel"], name="brain", timestamp=body["event"]["ts"], ) + self.app.client.reactions_add( + channel=body["event"]["channel"], + name="white_check_mark", + timestamp=body["event"]["ts"], + ) - brains_response = self.make_quivr_api_request("GET", "/brains/") - brains = brains_response.get("brains", []) + def ask_question(self, body, brain_id, thread_ts, question=None): + logger.info(body) - if not brains: - say("No brains found. Please create a brain first.") - return + # Extract the question from the body + text = body["event"]["text"] + bot_id = f"<@{self.app.client.auth_test()['user_id']}>" + question = text.replace(bot_id, "").strip() - thread_ts = body["event"].get("thread_ts") chat_id = self.get_chat_id(thread_ts) if not chat_id: chat_data = {"name": "Slack Chat"} chat_response = self.make_quivr_api_request("POST", "/chat", data=chat_data) chat_id = chat_response["chat_id"] - self.set_chat_id(body["event"]["ts"], chat_id) + self.set_chat_id(thread_ts, chat_id) - logger.info(body["event"]["text"]) - mention = f'<@{body["event"]["user"]}>' - question = body["event"]["text"].replace(mention, "").strip() - question_data = {"question": question} + logger.debug(question) + # If question is not provided, get it from the database + params = {} + if brain_id == "00000000-0000-0000-0000-000000000000": + brain_id = None + params = {"brain_id": brain_id} + question_data = { + "question": question, + } question_response = self.make_quivr_api_request( - "POST", f"/chat/{chat_id}/question", data=question_data + "POST", f"/chat/{chat_id}/question", data=question_data, params=params ) - client.reactions_add( - channel=body["event"]["channel"], - name="white_check_mark", - timestamp=body["event"]["ts"], + logger.debug(question_response) + if "assistant" in question_response: + self.app.client.chat_postMessage( + channel=body["event"]["channel"], + text=question_response["assistant"], + thread_ts=thread_ts, + ) + else: + self.app.client.chat_postMessage( + channel=body["event"]["channel"], + text="Sorry, I couldn't find an answer.", + thread_ts=thread_ts, + ) + + def handle_iteractive_request(self, payload): + + brain_id = None + action_id = payload["actions"][0]["action_id"] + logger.info(f"Payload: {payload}") + if action_id.startswith("brain_"): + brain_id = action_id.split("_")[1] + + thread_ts = payload["container"]["thread_ts"] + logger.info(f"Thread TS: {thread_ts}") + logger.info(f"Interactive Brain ID: {brain_id}") + self.set_brain_id(thread_ts, brain_id) + + self.app.client.chat_postMessage( + channel=payload["channel"]["id"], + text=f"Asking Question to {payload['actions'][0]['text']['text']}", + thread_ts=thread_ts, ) - logger.debug(question_response) + chat_id = self.get_chat_id(thread_ts) + if not chat_id: + logger.info("Creating chat") + chat_data = {"name": "Slack Chat"} + chat_response = self.make_quivr_api_request("POST", "/chat", data=chat_data) + chat_id = chat_response["chat_id"] + self.set_chat_id(thread_ts, chat_id) + + question = self.get_question(thread_ts) + params = {} + if brain_id == "00000000-0000-0000-0000-000000000000": + brain_id = None + params = {"brain_id": brain_id} + question_data = { + "question": question, + } + question_response = self.make_quivr_api_request( + "POST", f"/chat/{chat_id}/question", data=question_data, params=params + ) + + logger.debug(question_response["assistant"]) + logger.debug(f"Brain ID: {question_response.get('brain_id')}") if "assistant" in question_response: - say(question_response["assistant"], thread_ts=body["event"]["ts"]) + interactive_messages = self.get_all_interactive_message_ids_for_thread( + thread_ts + ) + logger.info(f"Interactive messages: {interactive_messages}") + for message in interactive_messages: + logger.info(f"Deleting message: {message}") + self.app.client.chat_delete( + channel=payload["channel"]["id"], + ts=message[0], + ) + self.app.client.chat_postMessage( + channel=payload["channel"]["id"], + text=question_response["assistant"], + thread_ts=thread_ts, + ) + self.set_brain_id(thread_ts, question_response.get("brain_id")) else: - say("Sorry, I couldn't find an answer.", thread_ts=body["event"]["ts"]) + self.app.client.chat_postMessage( + channel=payload["channel"]["id"], + text="Sorry, I couldn't find an answer.", + thread_ts=thread_ts, + ) + + return BoltResponse(status=200) # FastAPI app setup @@ -180,7 +444,19 @@ async def endpoint(req: Request): return await app_handler.handle(req) +@api.post("/slack/interactive") +async def interactive(req: Request, ack: Ack = Depends(Ack)): + logger.info("Received interactive request") + ## URL decoding + body = await req.body() + body_decoded = urllib.parse.unquote(body.decode("utf-8")) + payload = json.loads(body_decoded.split("payload=")[1]) + + ack() # Acknowledge the request + return slack_chat_app.handle_iteractive_request(payload) + + if __name__ == "__main__": import uvicorn - uvicorn.run("main:api", host="0.0.0.0", port=1234, log_level="debug", reload=True) + uvicorn.run("main:api", host="0.0.0.0", port=1234, log_level="warning", reload=True)