# chat.py: chat with uploaded PDFs using Streamlit, LangChain, and a local LLM.
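# A minimal Streamlit app; assuming Streamlit is installed, launch it with:
#   streamlit run chat.py
# Note: the imports below use the pre-0.1 LangChain paths (langchain.llms,
# langchain.document_loaders, ...); newer releases moved these modules to
# langchain_community.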
import os
import tempfile
import streamlit as st
from langchain.llms import OpenAI
from langchain.llms import LlamaCpp
from langchain.document_loaders import PyPDFLoader
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
st.set_page_config(page_title="Chat with PDF", page_icon="🦜")
st.title("🦜 Chat with PDF")
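# Build the retriever once per upload set and cache it for an hour: write each
# uploaded PDF to a temp file so PyPDFLoader can read it, split the pages into
# chunks, embed them with a local sentence-transformers model, and index the
# vectors in FAISS.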
@st.cache_resource(ttl="1h")
def configure_retriever(uploaded_files):
    docs = []
    temp_dir = tempfile.TemporaryDirectory()
    for file in uploaded_files:
        temp_filepath = os.path.join(temp_dir.name, file.name)
        with open(temp_filepath, "wb") as f:
            f.write(file.getvalue())
        loader = PyPDFLoader(temp_filepath)
        docs.extend(loader.load())

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=20)
    splits = text_splitter.split_documents(docs)

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectordb = FAISS.from_documents(splits, embedding=embeddings)

    retriever = vectordb.as_retriever(search_kwargs={"k": 2})
    return retriever
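# Callback that streams tokens into a Streamlit placeholder as the LLM generates
# them. ConversationalRetrievalChain first calls the LLM to rephrase the question
# using chat history; on_llm_start tags that run so its tokens are not echoed into
# the answer box.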
class StreamHandler(BaseCallbackHandler):
    def __init__(self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""):
        self.container = container
        self.text = initial_text
        self.run_id_ignore_token = None

    def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
        # Workaround to prevent showing the rephrased question as output: the
        # question-condensing call starts its prompt with "Human", so remember
        # its run_id and drop its tokens.
        if prompts[0].startswith("Human"):
            self.run_id_ignore_token = kwargs.get("run_id")

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # The False default keeps this from matching when both values are None.
        if self.run_id_ignore_token == kwargs.get("run_id", False):
            return
        self.text += token
        self.container.markdown(self.text)
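# Callback that surfaces the retrieval step in an expandable st.status box: the
# (possibly rephrased) query, then each retrieved chunk with its source file.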
class PrintRetrievalHandler(BaseCallbackHandler):
    def __init__(self, container):
        self.status = container.status("**Context Retrieval**")

    def on_retriever_start(self, serialized: dict, query: str, **kwargs):
        self.status.write(f"**Question:** {query}")
        self.status.update(label=f"**Context Retrieval:** {query}")

    def on_retriever_end(self, documents, **kwargs):
        for idx, doc in enumerate(documents):
            source = os.path.basename(doc.metadata["source"])
            self.status.write(f"**Document {idx} from {source}**")
            self.status.markdown(doc.page_content)
        self.status.update(state="complete")
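# App flow: require at least one uploaded PDF, then wire up the retriever, chat
# history, memory, LLM, and conversational retrieval chain.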
uploaded_files = st.sidebar.file_uploader(
    label="Upload PDF files", type=["pdf"], accept_multiple_files=True
)
if not uploaded_files:
    st.info("Please upload PDF documents to continue.")
    st.stop()
retriever = configure_retriever(uploaded_files)

# Back the chain's memory with the Streamlit-managed message history (chat_memory=msgs)
# so the rendered transcript and the history the chain sees stay in sync across reruns.
msgs = StreamlitChatMessageHistory()
memory = ConversationBufferMemory(memory_key="chat_history", chat_memory=msgs, return_messages=True)
# Alternative: run the model in-process via LlamaCpp. On my machine this is much
# slower than serving it through LM Studio (which also uses llama.cpp), so it
# stays commented out.
# llm = LlamaCpp(
#     model_path="models/TheBloke/dolphin-2.2.1-mistral-7B-GGUF/dolphin-2.2.1-mistral-7b.Q5_K_M.gguf",
#     temperature=0.0,
#     n_ctx=8192,
#     n_gpu_layers=1,
#     streaming=True,
# )
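# LM Studio exposes an OpenAI-compatible server (default http://localhost:1234/v1),
# so the stock OpenAI client works against the local model; the API key is a dummy
# value since the local server does not check it.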
llm = OpenAI(
    openai_api_base="http://localhost:1234/v1",
    openai_api_key="1234",
    temperature=0.0,
    streaming=True,
)
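# The "stuff" chain type concatenates all retrieved chunks directly into the prompt.
# That is fine with k=2 here, but with chunk_size=10000 the combined context can
# approach the model's context window, so keep those two settings in balance.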
qa_chain = ConversationalRetrievalChain.from_llm(
    llm, chain_type="stuff", retriever=retriever, memory=memory
)
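# Seed an opening assistant message on first load, and let the sidebar button
# reset the transcript at any point.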
if len(msgs.messages) == 0 or st.sidebar.button("Clear message history"):
    msgs.clear()
    msgs.add_ai_message("How can I help you?")

avatars = {"human": "user", "ai": "assistant"}
for msg in msgs.messages:
    st.chat_message(avatars[msg.type]).write(msg.content)
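# Main chat turn: echo the user's question, then run the chain with both callbacks
# so the retrieval step and the streamed answer render inside the assistant message.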
if user_query := st.chat_input(placeholder="Ask me anything!"):
    st.chat_message("user").write(user_query)

    with st.chat_message("assistant"):
        retrieval_handler = PrintRetrievalHandler(st.container())
        stream_handler = StreamHandler(st.empty())
        response = qa_chain.run(user_query, callbacks=[retrieval_handler, stream_handler])