-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathendpoints.py
69 lines (56 loc) · 1.77 KB
/
endpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import logging
from fastapi import FastAPI
from pydantic import BaseModel
from postgres_chat import PostgresChat
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
import os
# Configure the root logger so INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)
# Application object that the route decorators and middleware in this module attach to.
app = FastAPI()
class AskQuestion(BaseModel):
    """Request body accepted by the ``/ask-question`` endpoint."""

    # Text of the question sent by the client.
    question: str
CONNECTION_STRING = "postgresql://postgres@localhost:5432/imdb"  # replace with your connection string

# Load the OpenAI API key from api_key.txt into the process environment.
# If the file is absent, keep an OPENAI_API_KEY that is already exported
# rather than failing at import time; when neither source is available the
# original FileNotFoundError still propagates.
try:
    # Explicit encoding so the read does not depend on the platform locale.
    with open("api_key.txt", "r", encoding="utf-8") as f:
        os.environ["OPENAI_API_KEY"] = f.read().strip()
except FileNotFoundError:
    if not os.environ.get("OPENAI_API_KEY"):
        raise

# Single module-level handler instance used by every endpoint in this file.
rag_handler = PostgresChat(
    connection_string=CONNECTION_STRING,
    table_name='imdb',  # replace with your table you want to query
    schema='public',
    system_prompt_path='dd.txt'
)
@app.post("/ask-question")
def ask_question(question: AskQuestion):
    """
    Forward a user question to the module-level chat handler.

    The question text is appended to the handler's messages, the
    conversation is run, and the "response" and "executed_queries"
    entries of the handler's result are returned to the caller.
    """
    rag_handler.add_user_message(question.question)
    outcome = rag_handler.run_conversation()
    # Expose only these two entries of the handler's result, in this order.
    return {key: outcome[key] for key in ("response", "executed_queries")}
# Register CORS handling on the app: every origin, method and header is
# allowed, and credentials are enabled.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard values with allow_credentials=True — confirm whether credentials
# are actually required and, if so, list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)
@app.get("/reinitialize")
def reinitialize():
    """
    Ask the module-level chat handler to reinitialize its messages.

    Returns a small confirmation payload once the handler call completes.
    """
    rag_handler.reinitialize_messages()
    confirmation = {"message": "RAGHandler reinitialized."}
    return confirmation
@app.get("/show-system-prompt", response_class=PlainTextResponse)
def show_system_prompt():
    """
    Return the chat handler's system prompt as a plain-text response.
    """
    prompt_text = rag_handler.system_prompt
    return prompt_text
if __name__ == '__main__':
    # When executed as a script, serve the app from this module on
    # localhost:4000 with uvicorn's reload option turned on.
    uvicorn.run(
        "endpoints:app",
        host='localhost',
        port=4000,
        reload=True,
    )