diff --git a/aws_bedrock_utilities/models/base.py b/aws_bedrock_utilities/models/base.py index 69a3a15..48fd2b0 100644 --- a/aws_bedrock_utilities/models/base.py +++ b/aws_bedrock_utilities/models/base.py @@ -48,6 +48,7 @@ from langchain.agents import Tool from langchain.chains import RetrievalQA from langchain_aws import ChatBedrock +import functools from langchain_community.document_loaders import PyPDFLoader from langchain_text_splitters import CharacterTextSplitter from pydantic import BaseModel, Field @@ -162,6 +163,7 @@ def get_models_args(self, model: str): logging.warning(f"Model args not found for {model}") return {} + @functools.cache def get_llm(self, model_id: str): args = self.get_models_args(model_id) llm = ChatBedrock( diff --git a/aws_bedrock_utilities/models/bedrock_chat.py b/aws_bedrock_utilities/models/bedrock_chat.py index 72620c1..fcf050e 100644 --- a/aws_bedrock_utilities/models/bedrock_chat.py +++ b/aws_bedrock_utilities/models/bedrock_chat.py @@ -74,7 +74,7 @@ def run_chat( prompt=prompt, llm=self.get_llm(model_id=model_id), verbose=True, - # memory=memory, + memory=ConversationBufferMemory(ai_prefix="AI Assistant"), ) answer = conversation.invoke(query) diff --git a/aws_bedrock_utilities/models/pgvector_knowledgebase.py b/aws_bedrock_utilities/models/pgvector_knowledgebase.py index 93f71ba..987b0a5 100644 --- a/aws_bedrock_utilities/models/pgvector_knowledgebase.py +++ b/aws_bedrock_utilities/models/pgvector_knowledgebase.py @@ -36,6 +36,7 @@ UnstructuredXMLLoader, UnstructuredRSTLoader, UnstructuredExcelLoader, + UnstructuredPowerPointLoader, DataFrameLoader, ) import logging @@ -167,6 +168,8 @@ def get_loader(filepath: str) -> Dict[str, Any]: loader = UnstructuredEPubLoader(filepath) elif file_ext in ["doc", "docx"]: loader = Docx2txtLoader(filepath) + elif file_ext == 'pptx': + loader = UnstructuredPowerPointLoader(filepath, mode="elements") elif file_ext in ["xls", "xlsx"]: loader = UnstructuredExcelLoader(filepath) elif file_ext == "json":