Commit
* update welcome svg
* fix loading chatglm3 (#1937)
* update welcome message
* begin rag project with llama index
* rag version one
* rag beta release
* add social worker (proto)
* fix llamaindex version

Co-authored-by: binary-husky <qingxu.fu@outlook.com>
Co-authored-by: binary-husky <96192199+binary-husky@users.noreply.github.com>
Co-authored-by: moetayuko <loli@yuko.moe>
Parent: 16f4fd6
Commit: dd66ca2
Showing 19 changed files with 1,103 additions and 12 deletions.
@@ -0,0 +1 @@
RAG forgot to trigger the save!
@@ -0,0 +1,75 @@
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker

RAG_WORKER_REGISTER = {}

MAX_HISTORY_ROUND = 5
MAX_CONTEXT_TOKEN_LIMIT = 4096
REMEMBER_PREVIEW = 1000

@CatchException
def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):

    # 1. retrieve the per-user rag worker from the global registry (create it on first use)
    user_name = chatbot.get_user()
    if user_name in RAG_WORKER_REGISTER:
        rag_worker = RAG_WORKER_REGISTER[user_name]
    else:
        rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
            user_name,
            llm_kwargs,
            checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag'),
            auto_load_checkpoint=True)

    chatbot.append([txt, '正在召回知识 ...'])
    yield from update_ui(chatbot=chatbot, history=history)  # refresh the UI

    # 2. clip the history to reduce token consumption
    # 2-1. limit the number of chat rounds
    txt_origin = txt

    if len(history) > MAX_HISTORY_ROUND * 2:
        history = history[-(MAX_HISTORY_ROUND * 2):]
    txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
    input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])

    # 2-2. if the input was clipped, add the full input to the vector store before retrieval
    if input_is_clipped_flag:
        yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0)  # refresh the UI
        # save the original input to the vector store
        rag_worker.add_text_to_vector_store(txt_origin)
        yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0)  # refresh the UI

        if len(txt_origin) > REMEMBER_PREVIEW:
            HALF = REMEMBER_PREVIEW // 2
            i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
            if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
                txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
            i_say = txt_clip
        else:
            i_say_to_remember = i_say = txt_clip
    else:
        i_say_to_remember = i_say = txt_clip

    # 3. search the vector store and build the prompt
    nodes = rag_worker.retrieve_from_store_with_query(i_say)
    prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)

    # 4. query the llm
    if len(chatbot) != 0: chatbot.pop(-1)  # pop the temporary message; it is re-added inside `request_gpt_model_in_new_thread_with_ui_alive`
    model_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
        inputs=prompt, inputs_show_user=i_say,
        llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
        sys_prompt=system_prompt,
        retry_times_at_unknown_error=0
    )

    # 5. remember what has been asked / answered
    yield from update_ui_lastest_msg(model_say + '</br></br>' + '对话记忆中, 请稍等 ...', chatbot, history, delay=0.5)  # refresh the UI
    rag_worker.remember_qa(i_say_to_remember, model_say)
    history.extend([i_say, model_say])

    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0)  # refresh the UI
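
The step 2-2 branch above stores the full long input in the vector store but keeps only a head-and-tail preview of it in chat memory. A minimal, self-contained sketch of that truncation on made-up data (the input length and clip below are illustrative assumptions, not values from the commit):

    # illustrative only: reproduce the head/tail preview built in step 2-2
    REMEMBER_PREVIEW = 1000
    HALF = REMEMBER_PREVIEW // 2

    txt_origin = "x" * 12000     # assumed over-long user input
    preview = (
        txt_origin[:HALF]
        + f" ...\n...(省略{len(txt_origin) - REMEMBER_PREVIEW}字)...\n... "
        + txt_origin[-HALF:]
    )
    print(len(preview))          # roughly REMEMBER_PREVIEW characters plus the ellipsis note
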
@@ -0,0 +1,65 @@
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
import pickle, os

SOCIAL_NETWOK_WORKER_REGISTER = {}

class SocialNetwork():
    def __init__(self):
        self.people = []

class SocialNetworkWorker():
    def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
        self.user_name = user_name
        self.checkpoint_dir = checkpoint_dir
        if auto_load_checkpoint:
            self.social_network = self.load_from_checkpoint(checkpoint_dir)
        else:
            self.social_network = SocialNetwork()

    def does_checkpoint_exist(self, checkpoint_dir=None):
        import os, glob
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        if not os.path.exists(checkpoint_dir): return False
        if len(glob.glob(os.path.join(checkpoint_dir, "social_network.pkl"))) == 0: return False
        return True

    def save_to_checkpoint(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "wb+") as f:
            pickle.dump(self.social_network, f)
        return

    def load_from_checkpoint(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
            with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "rb") as f:
                social_network = pickle.load(f)
            return social_network
        else:
            return SocialNetwork()


@CatchException
def I人助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, num_day=5):

    # 1. retrieve the per-user worker from the global registry (create it on first use)
    user_name = chatbot.get_user()
    checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
    if user_name in SOCIAL_NETWOK_WORKER_REGISTER:
        social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name]
    else:
        social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name] = SocialNetworkWorker(
            user_name,
            llm_kwargs,
            checkpoint_dir=checkpoint_dir,
            auto_load_checkpoint=True
        )

    # 2. append a placeholder entry and save the checkpoint (prototype behavior)
    social_network_worker.social_network.people.append("张三")
    social_network_worker.save_to_checkpoint(checkpoint_dir)
    chatbot.append(["good", "work"])
    yield from update_ui(chatbot=chatbot, history=history)  # refresh the UI
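
The prototype above persists its state with pickle. A self-contained sketch of the same round trip, using a stand-in class and a temporary directory (both illustrative, not part of the commit):

    import os, pickle, tempfile

    class SocialNetwork:
        def __init__(self):
            self.people = []

    checkpoint_dir = tempfile.mkdtemp()   # stand-in for get_log_folder(...)
    sn = SocialNetwork()
    sn.people.append("张三")

    # save the object to <checkpoint_dir>/social_network.pkl
    with open(os.path.join(checkpoint_dir, "social_network.pkl"), "wb") as f:
        pickle.dump(sn, f)

    # load it back
    with open(os.path.join(checkpoint_dir, "social_network.pkl"), "rb") as f:
        restored = pickle.load(f)

    print(restored.people)   # ['张三']

One caveat of this design: pickle records the class by module path, so the checkpoint can only be loaded in a process where the same SocialNetwork class is importable.
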
@@ -0,0 +1,122 @@
import llama_index
from llama_index.core import Document
from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize

DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
---------------------
{context_str}
---------------------
Answer the user request below (use the context information if necessary, otherwise you can ignore it):
---------------------
{query_str}
"""

QUESTION_ANSWER_RECORD = """\
{{
    "type": "This is a previous conversation with the user",
    "question": "{question}",
    "answer": "{answer}",
}}
"""


class SaveLoad():

    def does_checkpoint_exist(self, checkpoint_dir=None):
        import os, glob
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        if not os.path.exists(checkpoint_dir): return False
        if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
        return True

    def save_to_checkpoint(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)

    def load_from_checkpoint(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
            print('loading checkpoint from disk')
            from llama_index.core import StorageContext, load_index_from_storage
            storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
            self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
            return self.vs_index
        else:
            return self.create_new_vs()

    def create_new_vs(self):
        return GptacVectorStoreIndex.default_vector_store(embed_model=self.embed_model)


class LlamaIndexRagWorker(SaveLoad):
    def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
        self.debug_mode = True
        self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
        self.user_name = user_name
        self.checkpoint_dir = checkpoint_dir
        if auto_load_checkpoint:
            self.vs_index = self.load_from_checkpoint(checkpoint_dir)
        else:
            self.vs_index = self.create_new_vs()

    def assign_embedding_model(self):
        pass

    def inspect_vector_store(self):
        # this function is for debugging: print every node currently held in the docstore
        self.vs_index.storage_context.index_store.to_dict()
        docstore = self.vs_index.storage_context.docstore.docs
        vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
        print('\n++ --------inspect_vector_store begin--------')
        print(vector_store_preview)
        print('oo --------inspect_vector_store end--------')
        return vector_store_preview

    def add_documents_to_vector_store(self, document_list):
        documents = [Document(text=t) for t in document_list]
        documents_nodes = run_transformations(
            documents,  # type: ignore
            self.vs_index._transformations,
            show_progress=True
        )
        self.vs_index.insert_nodes(documents_nodes)
        if self.debug_mode: self.inspect_vector_store()

    def add_text_to_vector_store(self, text):
        node = TextNode(text=text)
        documents_nodes = run_transformations(
            [node],
            self.vs_index._transformations,
            show_progress=True
        )
        self.vs_index.insert_nodes(documents_nodes)
        if self.debug_mode: self.inspect_vector_store()

    def remember_qa(self, question, answer):
        formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
        self.add_text_to_vector_store(formatted_str)

    def retrieve_from_store_with_query(self, query):
        if self.debug_mode: self.inspect_vector_store()
        retriever = self.vs_index.as_retriever()
        return retriever.retrieve(query)

    def build_prompt(self, query, nodes):
        context_str = self.generate_node_array_preview(nodes)
        return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)

    def generate_node_array_preview(self, nodes):
        buf = "\n".join([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)])
        if self.debug_mode: print(buf)
        return buf
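
Taken together, the worker exposes a small add / retrieve / build-prompt / remember cycle that the Rag问答 plugin drives. An assumed usage sketch (the llm_kwargs placeholder, user name, and checkpoint path are illustrative; a real run needs the framework's llm_kwargs and an embedding-capable API key):

    from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker

    llm_kwargs = {}   # placeholder; the real dict is assembled by the framework at runtime

    worker = LlamaIndexRagWorker(
        user_name="demo_user",                     # illustrative
        llm_kwargs=llm_kwargs,
        checkpoint_dir="./rag_checkpoint_demo",    # illustrative path
        auto_load_checkpoint=True,
    )

    # index some text, retrieve against it, and build the final prompt
    worker.add_text_to_vector_store("GPT-Academic stores long inputs in a vector index.")
    nodes = worker.retrieve_from_store_with_query("Where are long inputs stored?")
    prompt = worker.build_prompt(query="Where are long inputs stored?", nodes=nodes)

    # record the exchange and persist the index for the next session
    worker.remember_qa("Where are long inputs stored?", "In a vector index.")
    worker.save_to_checkpoint()
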