Skip to content

Commit

Permalink
fixing chat
Browse files Browse the repository at this point in the history
  • Loading branch information
jerowe committed Jun 17, 2024
1 parent f017a8a commit 816200f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 1 deletion.
2 changes: 2 additions & 0 deletions aws_bedrock_utilities/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from langchain.agents import Tool
from langchain.chains import RetrievalQA
from langchain_aws import ChatBedrock
import functools
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import CharacterTextSplitter
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -162,6 +163,7 @@ def get_models_args(self, model: str):
logging.warning(f"Model args not found for {model}")
return {}

@functools.cache
def get_llm(self, model_id: str):
args = self.get_models_args(model_id)
llm = ChatBedrock(
Expand Down
2 changes: 1 addition & 1 deletion aws_bedrock_utilities/models/bedrock_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run_chat(
prompt=prompt,
llm=self.get_llm(model_id=model_id),
verbose=True,
# memory=memory,
memory=ConversationBufferMemory(ai_prefix="AI Assistant"),
)

answer = conversation.invoke(query)
Expand Down
3 changes: 3 additions & 0 deletions aws_bedrock_utilities/models/pgvector_knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
UnstructuredXMLLoader,
UnstructuredRSTLoader,
UnstructuredExcelLoader,
UnstructuredPowerPointLoader,
DataFrameLoader,
)
import logging
Expand Down Expand Up @@ -167,6 +168,8 @@ def get_loader(filepath: str) -> Dict[str, Any]:
loader = UnstructuredEPubLoader(filepath)
elif file_ext in ["doc", "docx"]:
loader = Docx2txtLoader(filepath)
elif file_ext == 'pptx':
loader = UnstructuredPowerPointLoader(filepath, mode="elements")
elif file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(filepath)
elif file_ext == "json":
Expand Down

0 comments on commit 816200f

Please sign in to comment.