-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
133 lines (107 loc) · 4.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
from langchain.vectorstores.chroma import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms.ollama import Ollama
from langchain_community.embeddings import OllamaEmbeddings
import os
import shutil
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema.document import Document
# Directory holding the persistent Chroma vector store.
CHROMA_PATH = "chroma"
# Directory scanned for the source PDF documents to index.
DATA_PATH = "data"
# Prompt sent to the LLM; {context} and {question} are filled in at query time
# by ChatPromptTemplate.format in query_rag.
PROMPT_TEMPLATE = """
Answer the question based only on the following context:
{context}
---
Now, Answer the question based on the above context: {question}
"""
def get_embedding_function():
    """Return the Ollama embedding model shared by indexing and querying."""
    return OllamaEmbeddings(model="nomic-embed-text")
class PopulateDatabase():
    """Build or refresh the Chroma vector store from PDFs under DATA_PATH.

    Each PDF page is split into overlapping text chunks; every chunk gets a
    stable id of the form "<source>:<page>:<chunk-index>", and only chunks
    whose ids are not already present in the store are added.
    """

    def __init__(self, reset = False, chunk_size = 1000, chunk_overlap = 80) -> None:
        # reset: delete the existing Chroma directory before repopulating.
        self.reset = reset
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def clear_database(self):
        """Delete the on-disk Chroma store, if one exists."""
        if os.path.exists(CHROMA_PATH):
            shutil.rmtree(CHROMA_PATH)

    def load_documents(self):
        """Load every PDF under DATA_PATH and return the resulting Documents."""
        document_loader = PyPDFDirectoryLoader(DATA_PATH)
        return document_loader.load()

    def calculate_chunk_ids(self, chunks: "list[Document]"):
        """Stamp each chunk's metadata with a unique "<source>:<page>:<index>" id.

        The index restarts at 0 whenever the (source, page) pair changes and
        increments for consecutive chunks from the same page. The chunks are
        mutated in place and the same list is returned.
        """
        last_page_id = None
        curr_chunk_index = 0
        for chunk in chunks:
            source = chunk.metadata.get("source")
            page = chunk.metadata.get("page")
            curr_page_id = f"{source}:{page}"
            # Same page as the previous chunk -> next index; new page -> reset to 0.
            if curr_page_id == last_page_id:
                curr_chunk_index += 1
            else:
                curr_chunk_index = 0
            chunk.metadata["id"] = f"{curr_page_id}:{curr_chunk_index}"
            last_page_id = curr_page_id
        return chunks

    def add_to_chromadb(self, chunks: "list[Document]"):
        """Add any not-yet-stored chunks to the persistent Chroma database."""
        # Load (or create) the existing database.
        db = Chroma(
            persist_directory=CHROMA_PATH, embedding_function=get_embedding_function()
        )
        # Stamp each chunk with its stable id before de-duplicating.
        chunks_with_ids = self.calculate_chunk_ids(chunks)
        # include=[] fetches only the stored ids, not embeddings or documents.
        existing_items = db.get(include=[])
        existing_ids = set(existing_items["ids"])
        print(f"Number of existing documents in the DB : {len(existing_ids)}")
        # Keep only chunks whose id is not already in the store.
        new_chunks = [c for c in chunks_with_ids if c.metadata["id"] not in existing_ids]
        if new_chunks:
            print(f"Adding new documents : {len(new_chunks)}")
            new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks]
            db.add_documents(new_chunks, ids = new_chunk_ids)
            db.persist()
        else:
            print("No new documents to add right now!")

    def split_documents(self, documents: "list[Document]"):
        """Split documents into overlapping chunks sized by the constructor args."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size = self.chunk_size,
            chunk_overlap = self.chunk_overlap,
            length_function = len,
            is_separator_regex=False
        )
        return text_splitter.split_documents(documents)

    def execute(self):
        """Run the full pipeline: optional reset, then load, split, and store."""
        if self.reset:
            print("Clearing the database!")
            self.clear_database()
        # Create/update the data store.
        documents = self.load_documents()
        chunks = self.split_documents(documents)
        self.add_to_chromadb(chunks)
def query_rag(query_text: str, k: int = 3):
    """Answer a question with retrieval-augmented generation.

    Retrieves the k most similar chunks from the persistent Chroma store,
    assembles them into PROMPT_TEMPLATE, and asks the llama3 Ollama model
    for an answer grounded in that context.

    Args:
        query_text: The user's question.
        k: Number of chunks to retrieve (default 3, matching the previous
           hard-coded behavior).

    Returns:
        A (response_text, sources) tuple: the model's answer and the chunk
        ids of the retrieved context (may contain None if a chunk has no id).
    """
    # Prepare the DB
    print("Preparing the DB")
    embedding_function = get_embedding_function()
    db = Chroma(
        persist_directory=CHROMA_PATH, embedding_function=embedding_function
    )
    # Search the DB
    print("Searching the DB")
    results = db.similarity_search_with_score(query_text, k = k)
    context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(context = context_text, question = query_text)
    model = Ollama(model="llama3")
    response_text = model.invoke(prompt)
    sources = [doc.metadata.get("id", None) for doc, _score in results]
    formatted_response = f"Response: {response_text} \n\nSources: {sources}"
    print(formatted_response)
    return response_text, sources