Commit f017a8a
jerowe committed Jun 3, 2024
1 parent c20956f commit f017a8a
1 changed file: aws_bedrock_utilities/models/pgvector_knowledgebase.py (21 additions, 10 deletions)
@@ -181,7 +181,7 @@ def get_loader(filepath: str) -> Dict[str, Any]:
 
 
 class BedrockPGWrapper(BedrockBase):
-    def __init__(self, **kwargs):
+    def __init__(self, collection_name: str = "default", **kwargs):
         super().__init__(**kwargs)
         if "prompt_template" not in kwargs:
             self.prompt_template = self.kb_prompt_template
@@ -192,6 +192,7 @@ def __init__(self, **kwargs):
             model_id="amazon.titan-embed-image-v1", client=self.bedrock_client
         )
         self.s3 = boto3.client("s3")
+        self.collection_name = collection_name
 
     @property
     def connection_string(self):
@@ -221,7 +222,7 @@ def conn(self):
 
     @property
     def cur(self):
-        conn = self.connection_string()
+        conn = self.connection_string
         return conn.cursor()
 
     @property
@@ -231,6 +232,17 @@ def embeddings(self):
         )
         return embeddings
 
+    @property
+    def vectorstore(self):
+        vectorstore = PGVector(
+            embeddings=self.embeddings,
+            collection_name=self.collection_name,
+            connection=self.connection_string,
+            use_jsonb=True,
+        )
+
+        return vectorstore
+
     def create_vectorstore(self, collection_name: str):
         vectorstore = PGVector(
             embeddings=self.embeddings,
@@ -251,6 +263,7 @@ def run_ingestion_job(
         documents=List[Document],
         collection_name: str = "default",
     ):
+        logging.info("Starting ingestion job")
         y = len(documents)
         filtered_docs = []
         for d in documents:
@@ -261,13 +274,9 @@ def run_ingestion_job(
                 ids.append(hashlib.sha256(d.page_content.encode()).hexdigest())
 
         if len(filtered_docs):
-            vectorstore = self.create_vectorstore(collection_name=collection_name)
-            # texts = [i.page_content for i in filtered_docs]
-            # metadatas = [i.metadata for i in filtered_docs]
-            # logging.info(f"Adding N: {len(filtered_docs)}")
             try:
                 with funcy.print_durations("load psql"):
-                    vectorstore.add_documents(documents=filtered_docs, ids=ids)
+                    self.vectorstore.add_documents(documents=filtered_docs, ids=ids)
             except Exception as e:
                 logging.warning(f"{e}")
             # logging.info(f"Complete {x}/{y}")
@@ -281,7 +290,10 @@ def setup_local_injestion_job(
     ):
         files = glob.glob(glob_pattern)
         docs = []
+        x = 0
+        total_chunks = math.ceil(len(files) / chunk_size)
         for p in partition_all(chunk_size, files):
+            logging.info(f"Loading x: {x} of {total_chunks}")
             for file in p:
                 doc = self.load_local_file_to_document(file)
                 docs = docs + doc
@@ -328,15 +340,14 @@ def run_kb_chat(
         collection_name: str,
         prompt_template=None,
         model_id="anthropic.claude-3-sonnet-20240229-v1:0",
+        search_kwargs: Optional[Dict[str, Any]] = None,
     ) -> RAGResults:
         if not prompt_template:
             prompt_template = self.kb_prompt_template
         prompt = PromptTemplate(
             template=prompt_template, input_variables=["context", "input"]
         )
-        retriever = self.create_vectorstore(
-            collection_name=collection_name
-        ).as_retriever()
+        retriever = self.vectorstore.as_retriever(search_kwargs=search_kwargs)
 
         combine_docs_chain = create_stuff_documents_chain(
             llm=self.get_llm(model_id=model_id),
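For orientation, below is a minimal usage sketch of the API surface this commit touches: the new collection_name constructor argument, the cached vectorstore property, run_ingestion_job, and the new search_kwargs parameter on run_kb_chat. The Document import path and the positional query argument to run_kb_chat are assumptions, since they are not visible in this diff.

# Hypothetical usage sketch; names not visible in the diff are assumptions.
from langchain_core.documents import Document  # assumed import path

from aws_bedrock_utilities.models.pgvector_knowledgebase import BedrockPGWrapper

# The collection name is now fixed at construction time and reused by the
# cached `vectorstore` property added in this commit.
kb = BedrockPGWrapper(collection_name="my-docs")

docs = [
    Document(
        page_content="PGVector stores embeddings in Postgres.",
        metadata={"source": "notes.md"},
    )
]

# Ingest documents; per the diff, ids are SHA-256 hashes of the page content.
kb.run_ingestion_job(documents=docs, collection_name="my-docs")

# Ask a question; search_kwargs is forwarded to the retriever, e.g. to set
# how many chunks are retrieved.
result = kb.run_kb_chat(
    "How are embeddings stored?",  # assumed query argument
    collection_name="my-docs",
    search_kwargs={"k": 4},
)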
