-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_localGPT.py
62 lines (49 loc) · 2.16 KB
/
run_localGPT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from langchain.chains import RetrievalQA
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import HuggingFacePipeline
from constants import CHROMA_SETTINGS, PERSIST_DIRECTORY
from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline
from langchain.chains.question_answering import load_qa_chain
from langchain import HuggingFaceHub
import click
from constants import CHROMA_SETTINGS
import os
# Hugging Face Hub API token. Presumably picked up by the HuggingFaceHub LLM
# built in main() -- confirm against the langchain documentation.
# setdefault keeps a token that is already exported in the environment instead
# of overwriting it with this placeholder. Replace the placeholder locally and
# never commit a real key.
os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", "Your_HF_KEY")
from langchain.embeddings import HuggingFaceEmbeddings

# NOTE(review): main() builds its own HuggingFaceInstructEmbeddings and nothing
# in this file reads this module-level object. It is kept only so the module's
# public names stay unchanged -- confirm no importer relies on it, then remove.
embeddings = HuggingFaceEmbeddings()
@click.command()
@click.option('--device_type', default='gpu', help='device to run on, select gpu or cpu')
def main(device_type):
    """Answer questions interactively against the persisted Chroma store.

    Builds instructor embeddings on the selected device, opens the Chroma
    vector store at ``PERSIST_DIRECTORY``, wires its retriever into a
    "stuff" ``RetrievalQA`` chain whose LLM is ``google/flan-t5-xxl`` accessed
    through ``HuggingFaceHub``, and then reads queries from stdin until the
    user types ``exit`` or the input stream ends.

    Args:
        device_type: Value of ``--device_type``. ``cpu`` in any letter case
            selects the CPU; every other value selects CUDA.
    """
    # Compare case-insensitively so values such as "Cpu" are not silently
    # routed to CUDA.
    device = 'cpu' if device_type.lower() == 'cpu' else 'cuda'
    print(f"Running on: {device}")

    # load the instructorEmbeddings
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="hkunlp/instructor-xl",
        model_kwargs={"device": device},
    )

    # load the vectorstore
    db = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=embeddings,
        client_settings=CHROMA_SETTINGS,
    )
    retriever = db.as_retriever()

    # Prepare the LLM: Google Flan XXL generates the natural-language answers.
    llm = HuggingFaceHub(
        repo_id="google/flan-t5-xxl",
        model_kwargs={"temperature": 1, "max_length": 512},
    )
    # Source documents are requested from the chain but not displayed here;
    # they remain available in the result under 'source_documents'.
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )

    # Interactive questions and answers
    while True:
        try:
            query = input("\nEnter a query: ")
        except EOFError:
            # End of input (Ctrl-D or exhausted pipe) ends the session the
            # same way typing "exit" does.
            print()
            break
        if query == "exit":
            break
        if not query.strip():
            # Nothing to ask: skip the retrieval and LLM round trip.
            continue

        # Get the answer from the chain
        res = qa(query)
        answer = res['result']

        # Print the result
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)
# Script entry point: invoke the click command only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()