-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLlamaDxRAG.py
121 lines (108 loc) · 4.66 KB
/
LlamaDxRAG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from llama_stack_client import LlamaStackClient
from llama_stack_client.types.memory_insert_params import Document
from typing import List
# Embedding configuration
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
CHUNK_SIZE_TOKENS = 512
OVERLAP_SIZE_TOKENS = 10
class LlamaDxRAG:
    """Retrieval helper backed by a persistent ChromaDB collection.

    On construction, all ``.txt``/``.md`` files under ``docs_dir`` are split
    into overlapping word chunks, embedded with a SentenceTransformer model,
    and stored in the Chroma collection named ``memory_bank_id`` (skipped if
    the collection already exists). ``generate_documents`` then retrieves the
    top-k chunks for a query as a single formatted text block.
    """

    def __init__(self, docs_dir: str, chroma_dir: str, memory_bank_id: str):
        """
        Args:
            docs_dir (str): Directory containing the ``.txt``/``.md`` source documents.
            chroma_dir (str): Directory where Chroma persists its data.
            memory_bank_id (str): Name of the Chroma collection to use or create.
        """
        self.docs_dir = docs_dir
        self.memory_bank_id = memory_bank_id
        self.chroma_dir = chroma_dir
        self.initialize_agent()

    def initialize_agent(self):
        """Open the persistent Chroma client and ensure the collection exists.

        If the collection named ``self.memory_bank_id`` is missing, it is
        created with a SentenceTransformer embedding function and populated
        from ``self.docs_dir`` via :meth:`insert_documents`.
        """
        # NOTE(review): newer chromadb releases prefer PersistentClient(path=...);
        # the Settings(persist_directory=...) form is kept for compatibility with
        # the version this file was written against — confirm before upgrading.
        self.chroma_client = chromadb.PersistentClient(
            settings=Settings(
                persist_directory=self.chroma_dir
            )
        )
        existing = self.chroma_client.list_collections()
        if any(col.name == self.memory_bank_id for col in existing):
            print(f"The collection '{self.memory_bank_id}' already exists.")
        else:
            print(f"The collection '{self.memory_bank_id}' does not exist. Creating and initializing...")
            # Create the collection with an explicit embedding model so that
            # inserts and queries use the same embedding space.
            embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=EMBEDDING_MODEL
            )
            collection = self.chroma_client.create_collection(
                name=self.memory_bank_id,
                embedding_function=embedding_function
            )
            self.insert_documents(collection)

    def chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        """Split *text* into overlapping chunks of whitespace-separated words.

        Args:
            text (str): The input text to split.
            chunk_size (int): Maximum number of words in a chunk.
            overlap (int): Number of words shared between consecutive chunks.

        Returns:
            List[str]: The chunks; an empty list for empty/whitespace-only text.

        Raises:
            ValueError: If ``overlap >= chunk_size``. (The original silently
                produced zero chunks for a negative step, or raised an opaque
                ``range()`` error for a zero step.)
        """
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        words = text.split()  # whitespace tokenization, not model tokens
        step = chunk_size - overlap
        return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]

    def insert_documents(self, collections):
        """Chunk every ``.txt``/``.md`` file in ``self.docs_dir`` and add the
        chunks to the given Chroma collection.

        Args:
            collections: The Chroma collection object to insert into.
        """
        documents = []
        metadatas = []
        ids = []
        for filename in os.listdir(self.docs_dir):
            if filename.endswith((".txt", ".md")):
                file_path = os.path.join(self.docs_dir, filename)
                with open(file_path, "r", encoding="utf-8") as file:
                    content = file.read()
                chunks = self.chunk_text(content, CHUNK_SIZE_TOKENS, OVERLAP_SIZE_TOKENS)
                for idx, chunk in enumerate(chunks):
                    # BUG FIX: the id previously used a hard-coded "(unknown)"
                    # prefix, so ids collided across files and Chroma's
                    # unique-id requirement was violated. Use the filename,
                    # matching the metadata stored alongside each chunk.
                    documents.append(chunk)
                    metadatas.append({"filename": filename, "chunk": idx})
                    ids.append(f"{filename}_chunk_{idx}")
        # Only call add() when there is something to insert.
        if documents:
            collections.add(
                documents=documents,
                metadatas=metadatas,
                ids=ids
            )

    def generate_documents(self, query: str, top_k: int) -> str:
        """Query the collection and format the top-k results as one string.

        Args:
            query (str): The search query to execute.
            top_k (int): Number of results to retrieve.

        Returns:
            str: Each retrieved chunk framed by ``=`` separators together
            with its distance score (lower = closer).
        """
        collection = self.chroma_client.get_collection(name=self.memory_bank_id)
        results = collection.query(
            query_texts=[query],
            n_results=top_k
        )
        curated_text = ""
        # results["documents"]/["distances"] hold one list per query text;
        # we issue a single query, so the outer loop runs once.
        for docs, distances in zip(results["documents"], results["distances"]):
            for i, (doc, score) in enumerate(zip(docs, distances)):
                curated_text += f"Document {i+1} (Score: {score:.3f})\n"
                curated_text += "=" * 40 + "\n"
                curated_text += doc + '\n'
                curated_text += "=" * 40 + "\n"
        return curated_text
# def main():
#     docs_dir = './rag_genetics_small'
#     chroma_dir = "./chroma"  # Directory to persist Chroma data
#     llama_rag = LlamaDxRAG(docs_dir=docs_dir, chroma_dir=chroma_dir, memory_bank_id="genetics")
#     query = 'Cystic Fibrosis'
#     llama_rag_genetics = llama_rag.generate_documents(query, top_k=5)