-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
138 lines (107 loc) · 4.23 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import json
import numpy as np
from openai import OpenAI
import os
from dotenv import load_dotenv
# Load environment variables from a local .env file (API keys, model names,
# optional OpenRouter routing) before any client is constructed.
load_dotenv()
# Initialize OpenAI-compatible clients.
# Embeddings always go to OpenAI proper (OPENAI_API_KEY).
embedding_client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
)
# Use OpenRouter if configured, otherwise fall back to OpenAI.
# Key resolution: OPENROUTER_API_KEY, else OPENAI_API_KEY.
# Base-URL resolution: OPENROUTER_BASE_URL, else OPENAI_BASE_URL,
# else the stock OpenAI endpoint.
completion_client = OpenAI(
api_key=os.getenv("OPENROUTER_API_KEY", os.getenv("OPENAI_API_KEY")),
base_url=os.getenv(
"OPENROUTER_BASE_URL",
os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
),
)
def load_content():
    """Load text entries from content.txt, one entry per non-blank line.

    Returns:
        list[str]: stripped, non-empty lines in file order.

    Raises:
        FileNotFoundError: if content.txt does not exist.
    """
    # Explicit encoding avoids the platform-dependent default (PEP 597).
    with open("content.txt", "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]
def create_embeddings(texts):
    """Create embeddings for the given texts using the configured model.

    Args:
        texts: sequence of strings to embed.

    Returns:
        list[list[float]]: one embedding vector per input text, in the
        same order as *texts*.
    """
    if not texts:
        return []
    model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
    # The embeddings endpoint accepts a list input: one request replaces
    # N per-text round-trips. Sort by the response item's index to
    # guarantee output order matches input order.
    response = embedding_client.embeddings.create(model=model, input=list(texts))
    return [item.embedding for item in sorted(response.data, key=lambda d: d.index)]
def save_embeddings(texts, embeddings):
    """Write texts and their embedding vectors to embedding.json.

    The payload is a single object: {"texts": [...], "embeddings": [...]}.
    """
    with open("embedding.json", "w") as f:
        json.dump({"texts": texts, "embeddings": embeddings}, f)
def load_embeddings():
    """Read the stored {"texts": ..., "embeddings": ...} payload from embedding.json."""
    with open("embedding.json", "r") as f:
        data = json.load(f)
    return data
def find_most_similar(query_embedding, embeddings, texts, top_k=3):
    """Return the *top_k* stored texts most similar to the query embedding.

    Args:
        query_embedding: embedding vector of the query.
        embeddings: list of stored embedding vectors (same dimensionality).
        texts: texts parallel to *embeddings*.
        top_k: number of results to return (capped at len(texts)).

    Returns:
        list[str]: texts ordered from most to least similar.
    """
    # Convert to numpy arrays for vectorized computation.
    embeddings_array = np.array(embeddings)
    query_array = np.array(query_embedding)
    # Cosine similarity: dot(a, q) / (|a| * |q|) for every stored row at once.
    similarities = np.dot(embeddings_array, query_array) / (
        np.linalg.norm(embeddings_array, axis=1) * np.linalg.norm(query_array)
    )
    # argsort is ascending, so the last top_k entries are the best matches;
    # reverse the slice so the MOST similar text comes first (the original
    # returned the top-k least-similar-first).
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    return [texts[i] for i in top_indices]
def answer_question(question):
    """Answer *question* using retrieval-augmented generation.

    Embeds the question, retrieves the most relevant stored texts as
    context, and asks the configured chat model to answer from that
    context only.

    Args:
        question: the user's natural-language question.

    Returns:
        str: the model's answer text.
    """
    # Embed the question with the same model used for the stored corpus.
    embed_model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
    embed_response = embedding_client.embeddings.create(
        model=embed_model, input=question
    )
    query_vector = embed_response.data[0].embedding

    # Retrieve the stored corpus and select the closest entries.
    stored = load_embeddings()
    context_chunks = find_most_similar(
        query_vector, stored["embeddings"], stored["texts"]
    )
    context = "\n\n".join(context_chunks)

    # Ask the chat model to answer strictly from the retrieved context.
    chat_model = os.getenv("COMPLETION_MODEL", "gpt-4o-mini")
    chat_response = completion_client.chat.completions.create(
        model=chat_model,
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant. Answer the question based on the provided context. If the answer cannot be found in the context, say so.",
            },
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"},
        ],
    )
    return chat_response.choices[0].message.content
def main():
    """Build (or load cached) embeddings, then run an interactive Q&A loop.

    On first run this embeds every line of content.txt and caches the
    result in embedding.json; later runs reuse the cache to avoid
    repeat embedding API calls.
    """
    # NOTE: the redundant local `import os.path` was removed — `os` is
    # already imported at module level and exposes `os.path`.
    if os.path.exists("embedding.json"):
        print("Loading existing embeddings...")
        stored_data = load_embeddings()
        texts = stored_data["texts"]
        embeddings = stored_data["embeddings"]
    else:
        # Load content and create embeddings from scratch.
        texts = load_content()
        print("Creating embeddings...")
        embeddings = create_embeddings(texts)
        print("Saving embeddings...")
        save_embeddings(texts, embeddings)
    # Interactive question-answering loop; 'quit' exits.
    while True:
        question = input("\nEnter your question (or 'quit' to exit): ")
        if question.lower() == "quit":
            break
        print("\nFinding answer...")
        answer = answer_question(question)
        print(f"\nAnswer: {answer}")


if __name__ == "__main__":
    main()