Skip to content

Commit

Permalink
update config
Browse files Browse the repository at this point in the history
  • Loading branch information
ElmiraGhorbani committed Jul 22, 2023
1 parent bdfc783 commit d62b95f
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 34 deletions.
8 changes: 7 additions & 1 deletion chatgpt_long_term_memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from chatgpt_long_term_memory.openai_engine import (OpenAIChatBot,
OpenAIChatConfig,
retry_on_openai_errors)
from chatgpt_long_term_memory.llama_index_helpers import DocIndexer, Retrievers
from chatgpt_long_term_memory.llama_index_helpers.config import (
IndexConfig, RetrieversConfig)
Expand All @@ -10,5 +13,8 @@
"IndexConfig",
"RetrieversConfig",
"ChatMemory",
"ChatMemoryConfig"
"ChatMemoryConfig",
"OpenAIChatConfig",
"OpenAIChatBot",
"retry_on_openai_errors"
]
22 changes: 11 additions & 11 deletions chatgpt_long_term_memory/llama_index_helpers/config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field


class IndexConfig(BaseModel):
    """Configuration for the document indexer (DocIndexer).

    Fix: the class body previously defined every field twice (plain
    defaults followed by Field(default=...) versions). In a Python class
    body the later assignment wins, so only the Field(...) versions were
    ever effective — the superseded duplicates are removed here with no
    behavior change.
    """

    # NOTE(review): machine-specific absolute path committed as a default;
    # TODO make this deployment-configurable (env var / constructor arg).
    root_path: str = Field(
        default="/home/elmira/projects/0-NAVIS/long-term-memory-chatbot")
    # When True, DocIndexer loads documents from <root_path>/resources/data.
    knowledge_base: bool = Field(default=True)
    model_name: str = Field(default="gpt-3.5-turbo")
    # temperature 0 -> deterministic generation.
    temperature: int = Field(default=0)
    context_window: int = Field(default=4096)
    num_outputs: int = Field(default=700)
    max_chunk_overlap: float = Field(default=0.5)
    chunk_size_limit: int = Field(default=600)


class RetrieversConfig(BaseModel):
    """Configuration for index retrieval (Retrievers).

    Fix: removed superseded duplicate field definitions (the later
    Field(...) assignments already won in the class body, so behavior is
    unchanged: top_k is 7, not 3).
    """

    # Number of top results to fetch per query.
    top_k: int = Field(default=7)
    # Minimum similarity score for a retrieved node to be kept.
    similarity_threshold: float = Field(default=0.7)
35 changes: 21 additions & 14 deletions chatgpt_long_term_memory/llama_index_helpers/index_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,28 @@ class DocIndexer:
"""

def __init__(
self, config: IndexConfig):
self.knowledge_base = config.knowledge_base
self, doc_config: IndexConfig, **kw):
super().__init__(**kw)
self.config = doc_config

# Initialize the OpenAI language model
self.llm = OpenAI(model=config.model_name,
temperature=config.temperature)
self.llm = OpenAI(model=self.config.model_name,
temperature=self.config.temperature)
service_context = ServiceContext.from_defaults(llm=self.llm)
set_global_service_context(service_context)

self.root_path = config.root_path
self.data_path = f'{self.root_path}/resources/data'
assert os.path.exists(
self.data_path), f"Path '{self.data_path}' does not exist!"
self.root_path = self.config.root_path
if self.config.knowledge_base:
self.data_path = f'{self.root_path}/resources/data'
assert os.path.exists(
self.data_path), f"Path '{self.data_path}' does not exist!"

# Define prompt helper for the index
self.prompt_helper = PromptHelper(
config.context_window,
config.num_outputs,
config.max_chunk_overlap,
chunk_size_limit=config.chunk_size_limit
self.config.context_window,
self.config.num_outputs,
self.config.max_chunk_overlap,
chunk_size_limit=self.config.chunk_size_limit
)

def load_documents(self, retrieved_documents):
Expand All @@ -70,7 +72,7 @@ def load_documents(self, retrieved_documents):
sorted_data = sorted(retrieved_documents,
key=lambda x: list(x.keys())[0], reverse=True)
doc = list(sorted_data[0].values())[0]
doc = f"question: {doc['user_query']}, answer: {doc['bot_response']}"
doc = f"USER: {doc['user_query']}, ANSWER: {doc['bot_response']}"
return Document(text=doc, doc_id=f"doc_id_{str(uuid.uuid4())}")

def construct_index_general(self, user_id, path, mode):
Expand All @@ -86,15 +88,18 @@ def construct_index_general(self, user_id, path, mode):
VectorStoreIndex: The constructed index.
"""
# If the user's storage directory does not exist, create it.
print(self.config.knowledge_base)
print(path)
if not os.path.exists(path):
os.makedirs(path)

# Load data from the directory for the user's personal knowledge base
if mode == 'kb':
if self.knowledge_base:
if self.config.knowledge_base:
documents = SimpleDirectoryReader(self.data_path).load_data()
else:
# User doesn't have any personal knowledge base
print("kb f")
documents = []

# Create a Document structure for the user's input query and chatbot response
Expand Down Expand Up @@ -125,9 +130,11 @@ def load_index(self, user_id):
index_path = f'{self.root_path}/storages/storage_{user_id}'
status = os.path.exists(f'{index_path}/vector_store.json')
if status:
print("index_path: if ", index_path)
index = load_index_from_storage(
StorageContext.from_defaults(persist_dir=index_path))
else:
print("index_path: else", index_path)
index = self.construct_index_general(
user_id, index_path, mode="kb")
return index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ class Retrievers:
"""

def __init__(self, config: RetrieversConfig):
self.top_k = config.top_k
self.similarity_threshold = config.similarity_threshold
def __init__(self, retrieve_config: RetrieversConfig, **kw):
super().__init__(**kw)
self.config = retrieve_config
self.top_k = self.config.top_k
self.similarity_threshold = self.config.similarity_threshold

def query(self, index, question):
"""
Expand Down
6 changes: 4 additions & 2 deletions chatgpt_long_term_memory/memory/chat_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ class ChatMemory:
"""

def __init__(self, config: ChatMemoryConfig):
def __init__(self, memory_config: ChatMemoryConfig, **kw):
super().__init__(**kw)
self.config = memory_config
self.redis_db = RedisManager(
host=config.redis_host, port=config.redis_port)
host=self.config.redis_host, port=self.config.redis_port)

def get(self, user_id):
"""
Expand Down
6 changes: 3 additions & 3 deletions chatgpt_long_term_memory/memory/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field


class ChatMemoryConfig(BaseModel):
    """Connection settings for the Redis-backed chat memory.

    Fix: removed superseded duplicate field definitions; the effective
    defaults (the later Field(...) assignments) are unchanged.
    """

    # NOTE(review): hard-coded private IP as a default — presumably a
    # docker/network-specific address; TODO make configurable.
    redis_host: str = Field(default="172.16.0.2")
    redis_port: int = Field(default=6379)

0 comments on commit d62b95f

Please sign in to comment.