-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
133 lines (96 loc) · 3.61 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Utility functions for the agentic RAG tutorial."""
import copy
import os
import time
import urllib
import urllib.error  # `import urllib` alone does not bind the `error` submodule

from flytekit import current_context
def set_openai_api_key():
    """Load the Flyte 'openai_api_key' secret into the OPENAI_API_KEY env var."""
    api_key = current_context().secrets.get(key="openai_api_key")
    os.environ["OPENAI_API_KEY"] = api_key
def split_text_into_lines(text: str, chars_per_line: int) -> str:
    """Wrap ``text`` into newline-joined lines of at most ``chars_per_line`` chars.

    Lines break at the last space at or before the limit; a single word
    longer than the limit is hard-split at the limit so the loop always
    makes progress (the original implementation looped forever in that
    case: the backward scan hit ``i == 0`` without consuming any text).

    Args:
        text: non-empty text to wrap.
        chars_per_line: positive maximum line width.

    Returns:
        The wrapped text with ``\\n`` between lines.

    Raises:
        ValueError: if ``chars_per_line`` is not positive or ``text`` is empty.
    """
    # Explicit raises instead of `assert`: asserts vanish under `python -O`.
    if chars_per_line <= 0:
        raise ValueError("chars_per_line must be positive")
    if not text:
        raise ValueError("text must be non-empty")
    segments = []
    while text:
        if len(text) <= chars_per_line:
            segments.append(text)
            break
        # Walk backwards from the limit to the nearest space.
        i = chars_per_line
        while i > 0 and text[i] != " ":
            i -= 1
        if i == 0:
            # No space in the first chars_per_line chars: hard-split the
            # over-long word so we always consume input.
            segments.append(text[:chars_per_line])
            text = text[chars_per_line:]
        else:
            segments.append(text[:i])
            # Skip the break space so continuation lines don't start with
            # a leading blank (the original kept it).
            text = text[i + 1:]
    return "\n".join(segments)
def generate_data_card(docs: list, head: int = 5, chars_per_line: int = 80) -> str:
    """Render a markdown "data card" previewing the first ``head`` documents.

    Args:
        docs: documents exposing ``page_content`` (str) and ``metadata``
            (dict) attributes (e.g. langchain ``Document`` objects —
            TODO confirm against callers).
        head: number of leading documents to include in the preview.
        chars_per_line: wrap width applied to each content preview.

    Returns:
        A markdown string: a header with the total chunk count, then one
        section per previewed chunk.
    """
    previews = []
    for i, doc in enumerate(docs[:head]):
        # Strip backtick fences so the content can't break out of the
        # markdown code block, then wrap it for readability.
        content = split_text_into_lines(doc.page_content.replace("```", ""), chars_per_line)
        previews.append(
            f"""\n\n---
### 📖 Chunk {i}
**Page metadata:**
{doc.metadata}
**Content:**
```
{content}
```
"""
        )
    # Bug fix: report the total chunk count (len(docs)); the original used
    # len(_docs), which is capped at `head` and under-reports the store size.
    return f"""# 📚 Vector store knowledge base.
This artifact is a vector store of {len(docs)} document chunks.
## Preview
{"".join(previews)}
"""
def get_pubmed_loader(*args, **kwargs):
    """Build a PubMed document loader whose API client retries on HTTP errors.

    Wraps langchain's ``PubMedLoader``/``PubMedAPIWrapper`` so that article
    retrieval retries on ``urllib.error.HTTPError`` (presumably NCBI rate
    limiting — confirm) instead of failing on the first error.

    Args:
        *args: forwarded to ``PubMedLoader``.
        **kwargs: forwarded to ``PubMedLoader``; must contain
            ``load_max_docs``, which is reused as the wrapper's
            ``top_k_results``.

    Returns:
        A ``PubMedLoader`` instance with the retrying API wrapper installed.
    """
    from langchain_community.document_loaders import PubMedLoader as _PubMedLoader
    from langchain_community.utilities.pubmed import PubMedAPIWrapper as _PubMedAPIWrapper

    class PubMedAPIWrapper(_PubMedAPIWrapper):
        def retrieve_article(self, uid: str, webenv: str) -> dict:
            # Remember the configured delay so any backoff applied by the
            # base class can be undone once a call succeeds.
            initial_sleep = copy.copy(self.sleep_time)
            last_error = None
            for _ in range(self.max_retry):
                try:
                    article = super().retrieve_article(uid, webenv)
                except urllib.error.HTTPError as err:
                    last_error = err
                    time.sleep(self.sleep_time)
                else:
                    # Reset the sleep time on success.
                    self.sleep_time = initial_sleep
                    return article
            # Bug fix: the original fell off the loop and silently returned
            # None after exhausting retries; surface the failure instead.
            raise RuntimeError(
                f"failed to retrieve PubMed article {uid!r} "
                f"after {self.max_retry} retries"
            ) from last_error

    class PubMedLoader(_PubMedLoader):
        def __init__(self, *args, max_retry: int = 100, sleep_time: float = 0.5, **kwargs):
            super().__init__(*args, **kwargs)
            # Replace the stock client with the retrying wrapper.
            self._client = PubMedAPIWrapper(  # type: ignore[call-arg]
                top_k_results=kwargs["load_max_docs"],  # type: ignore[arg-type]
                max_retry=max_retry,
                sleep_time=sleep_time,
            )

    return PubMedLoader(*args, **kwargs)
def parse_doc(doc):
    """Normalize a PubMed document's metadata in place and return the doc.

    Ensures ``metadata["Title"]`` is a flat string (PubMed sometimes returns
    it as a dict of parts) and mirrors ``uid`` into the conventional
    ``source`` key.
    """
    title = doc.metadata["Title"]
    # Flatten dict-valued titles into a single space-joined string.
    doc.metadata["Title"] = " ".join(title.values()) if isinstance(title, dict) else title
    doc.metadata["source"] = doc.metadata["uid"]
    return doc
def get_vector_store_retriever(path: str):
    """Create a retriever tool backed by a persisted Chroma vector store.

    Args:
        path: directory where the "rag-chroma" Chroma collection is persisted.

    Returns:
        A langchain retriever tool that searches the collection with
        OpenAI embeddings.
    """
    from langchain.tools.retriever import create_retriever_tool
    from langchain_community.vectorstores import Chroma
    from langchain_openai import OpenAIEmbeddings

    store = Chroma(
        collection_name="rag-chroma",
        persist_directory=path,
        embedding_function=OpenAIEmbeddings(),
    )
    return create_retriever_tool(
        store.as_retriever(),
        "retrieve_pubmed_research",
        "Search and return information about pubmed research papers relating "
        "to the user query.",
    )
# NOTE(review): exact duplicate of set_openai_api_key defined earlier in this
# file; this redefinition shadows the first and one copy should be removed.
def set_openai_api_key():
    """Store the Flyte 'openai_api_key' secret in the OPENAI_API_KEY env var."""
    os.environ["OPENAI_API_KEY"] = current_context().secrets.get(key="openai_api_key")