-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvectorize_student_example.py
195 lines (159 loc) · 6.82 KB
/
vectorize_student_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
import random
import chromadb
from chromadb import Collection
from powerschool_adapter.powerschool import PowerSchool
import requests
import embedding
from dotenv import load_dotenv
from faker import Faker
import colorama
from colorama import Fore, Back
# Load configuration from a local .env file (if present) before reading env vars.
load_dotenv()
# Shared Faker instance used to fabricate mock profile data.
fake = Faker()
# Reset terminal colours automatically after every coloured print.
colorama.init(autoreset=True)

POWERSCHOOL_SERVER_ADDRESS = os.getenv("POWERSCHOOL_SERVER_ADDRESS")
POWERSCHOOL_CLIENT_ID = os.getenv("POWERSCHOOL_CLIENT_ID")
POWERSCHOOL_CLIENT_SECRET = os.getenv("POWERSCHOOL_CLIENT_SECRET")
# Fall back to chromadb.HttpClient's own defaults when unset or empty, instead of
# crashing at import time on int(None) / int("").
CHROMA_DB_HOST = os.getenv("CHROMA_DB_HOST") or "localhost"
CHROMA_DB_PORT = int(os.getenv("CHROMA_DB_PORT") or "8000")
# Endpoint that turns text into an embedding vector (used by query()).
EMBEDDING_API_URL = os.getenv("EMBEDDING_API_URL")

chroma_client = chromadb.HttpClient(host=CHROMA_DB_HOST, port=CHROMA_DB_PORT)
# Items a student may have checked out; two of each kind are sampled per student.
_SCHOOL_DEVICES = ("Macbook", "iPad", "Lego Kit", "Calculator", "Headphones", "Camera")
_SCHOOL_BOOKS = ("Math Book", "Science Book", "History Book", "English Book", "Chinese Book", "Art Book")
# PowerSchool gender code -> label; any other value (blank, None, ...) becomes "Unspecified"
# rather than being silently reported as "Male".
_GENDER_LABELS = {"F": "Female", "M": "Male"}


def _fake_iso_date(start: str, end: str) -> str:
    """Return a random date between Faker-relative bounds (e.g. "-1y", "today") as YYYY-MM-DD."""
    return fake.date_between(start_date=start, end_date=end).isoformat()


def _fake_billing_profile() -> dict:
    """Return fabricated tuition and capital fees (in RMB) for 2025 and 2024."""
    return {
        year: {
            "tuition": f"{fake.random_int(min=100000, max=300000)} RMB",
            "capital_fee": f"{fake.random_int(min=10000, max=30000)} RMB",
        }
        for year in ("2025", "2024")
    }


def _fake_checkout_properties() -> list:
    """Return two random devices and two random books, each with a check-out date in the last year."""
    items = random.sample(_SCHOOL_DEVICES, 2) + random.sample(_SCHOOL_BOOKS, 2)
    return [{"type": item, "checkout_date": _fake_iso_date("-1y", "today")} for item in items]


def _fake_vaccination_profile() -> dict:
    """Return a fabricated HepB and DTaP/Tdap vaccination record."""
    return {
        "HepB": {
            "1st": _fake_iso_date("-20y", "-15y"),
            "2nd": _fake_iso_date("-15y", "-10y"),
            "3rd": _fake_iso_date("-10y", "-5y"),
        },
        "DTaP/Tdap": {
            "DTaP_series_completed": "pre-age 7",
            "Tdap_booster": _fake_iso_date("-5y", "today"),
        },
    }


def _build_student_document(student, student_number) -> str:
    """Return the text embedded for one student: real PowerSchool fields plus fake billing/checkout/vaccination data."""
    student_profile = {
        "name": student["lastfirst"],
        "grade": student["grade_level"],
        "student_number": student_number,
        "gender": _GENDER_LABELS.get(student["gender"], "Unspecified"),
        "enroll_status": "enrolled",
        "district_entry_date": student["districtentrydate"],
        "entry_date": student["entrydate"],
        "exit_date": student["exitdate"],
        "parents": {"father": fake.name_male(), "mother": fake.name_female()},
    }
    # Space-separated reprs of each sub-profile; generated dates are ISO strings so the
    # text reads "2024-05-01" instead of "datetime.date(2024, 5, 1)".
    return (
        f"{student_profile} {_fake_billing_profile()} "
        f"{_fake_checkout_properties()} {_fake_vaccination_profile()}"
    )


def vectorize_students(collection: Collection) -> None:
    """Add one embedded profile document per PowerSchool student to *collection*.

    Name, grade, gender, student number, enrollment dates and DCID come from the
    PowerSchool ``students`` table. Parents, fees, checked-out items and vaccinations
    are fabricated with Faker/random for demo purposes.

    Args:
        collection: Chroma collection that receives one document per student,
            keyed by the student's DCID.

    Raises:
        SystemExit: With status 1 when any PowerSchool credential is unset or empty.
    """
    if not (POWERSCHOOL_SERVER_ADDRESS and POWERSCHOOL_CLIENT_ID and POWERSCHOOL_CLIENT_SECRET):
        print("PowerSchool environment variables were not set.")
        # Non-zero status so callers/shells see the failure (bare exit() reported success).
        raise SystemExit(1)

    powerschool = PowerSchool(
        server_address=POWERSCHOOL_SERVER_ADDRESS,
        client_id=POWERSCHOOL_CLIENT_ID,
        client_secret=POWERSCHOOL_CLIENT_SECRET,
    )
    response = (
        powerschool.table("students")
        .projection(["DCID", "STUDENT_NUMBER", "LAST_NAME", "FIRST_NAME", "LASTFIRST", "GENDER", "GRADE_LEVEL", "entrydate", "DISTRICTENTRYDATE", "EXITDATE"])
        .set_method("GET")
        .send()
        .squash_table_response()
    )

    for student in response.to_list():
        # Use a fake 4-digit student number only when PowerSchool has none.
        student_number = student["student_number"] or fake.random_int(min=1000, max=9999)
        document = _build_student_document(student, student_number)
        collection.add(
            documents=[document],
            metadatas=[{"source": "student and billing profile"}],
            embeddings=[embedding.create_embeddings(document)],
            uris=[f"student_{student_number}"],
            # Chroma validates that IDs are strings; DCID may arrive as a number.
            ids=[str(student["dcid"])],
        )
def query(collection: Collection, query_text, n_results: int = 2, timeout: float = 60.0):
    """Embed *query_text* via the embedding API and return the nearest matches in *collection*.

    The payload shape ({"model", "prompt"} -> {"embedding"}) matches an Ollama
    ``/api/embeddings`` endpoint; EMBEDDING_API_URL is presumably that endpoint.

    Args:
        collection (Collection): The Chroma collection to query.
        query_text (str): The text to embed and search for.
        n_results (int): Maximum number of matches to return (default 2).
        timeout (float): Seconds to wait for the embedding API before giving up.

    Returns:
        dict: The query results from ``collection.query``.

    Raises:
        Exception: If the embedding API answers with a status other than 200.
        ValueError: If the API response contains no ``embedding``.
        requests.Timeout: If the API does not respond within *timeout* seconds.
    """
    payload = {
        # Intended to be the same model embedding.create_embeddings uses for the
        # stored documents; vectors from different models are not comparable.
        "model": "nomic-embed-text",
        "prompt": query_text,
    }
    # requests waits indefinitely without a timeout, hanging on a dead server.
    response = requests.post(EMBEDDING_API_URL, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to get embeddings: {response.status_code}, {response.text}")

    query_embedding = response.json().get("embedding")
    if not query_embedding:
        raise ValueError("Embedding not found in API response.")

    return collection.query(query_embeddings=[query_embedding], n_results=n_results)
def main():
    """Rebuild the "students" Chroma collection from PowerSchool, then print samples and a demo query."""
    collection_name = "students"
    sample_size = 2

    print(Back.BLACK + Fore.LIGHTYELLOW_EX + f"Chroma Heartbeat: {chroma_client.heartbeat()}")
    print(Back.BLACK + Fore.LIGHTCYAN_EX + f"Chroma Version: {chroma_client.get_version()}")
    existing = chroma_client.list_collections()
    print(Back.BLACK + Fore.LIGHTYELLOW_EX + f"Chroma Collections: {existing}")

    # Depending on the chromadb version, list_collections() yields Collection objects
    # or plain names; compare by name so the stale collection is actually dropped.
    existing_names = {getattr(c, "name", c) for c in existing}
    if collection_name in existing_names:
        chroma_client.delete_collection(collection_name)
        print(Back.BLACK + Fore.LIGHTYELLOW_EX + "Students collection has been deleted.")
    else:
        print(Back.BLACK + Fore.LIGHTYELLOW_EX + "Students collection does not exist.")

    collection = chroma_client.get_or_create_collection(name=collection_name)
    print(Back.BLACK + Fore.LIGHTGREEN_EX + "Students collection has been created or retrieved.")

    vectorize_students(collection)

    # Fetch as many records as the message below announces.
    all_data = collection.get(limit=sample_size)
    print(Back.BLACK + Fore.LIGHTMAGENTA_EX + f"Retrieving the first {sample_size} records as examples from Chroma DB: {all_data}")
    print(Back.BLACK + Fore.LIGHTRED_EX + "Querying the database and searching for a record that is not from the examples:")
    results = query(collection, "Who is Chong, Hannah?")
    print(Back.BLACK + Fore.LIGHTGREEN_EX + f"Results: {results}")
# Run the demo only when executed as a script, never on import.
if __name__ == "__main__":
    main()