Skip to content

Commit

Permalink
Merge pull request #36 from weni-ai/feature/ElasticsearchStore
Browse files: browse the repository at this point in the history
Feature/elasticsearch store
  • Loading branch information
AlisoSouza authored Jun 4, 2024
2 parents ec78848 + 1d7a0e0 commit c9e1d65
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
7 changes: 4 additions & 3 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import FastAPI
from langchain.embeddings import HuggingFaceHubEmbeddings, CohereEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import ElasticVectorSearch, VectorStore
from langchain.vectorstores import ElasticVectorSearch, VectorStore, ElasticsearchStore

from app.handlers import IDocumentHandler
from app.handlers.products import ProductsHandler
Expand Down Expand Up @@ -74,10 +74,11 @@ def __init__(self, config: AppConfig):
self.products_handler = ProductsHandler(self.products_indexer)
self.api.include_router(self.products_handler.router)

self.content_base_vectorstore = ElasticVectorSearch(
elasticsearch_url=config.es_url,
self.content_base_vectorstore = ElasticsearchStore(
es_url=config.es_url,
index_name=config.content_base_index_name,
embedding=self.embeddings,
strategy=ElasticsearchStore.ExactRetrievalStrategy()
)
self.content_base_vectorstore.client = Elasticsearch(
hosts=config.es_url, timeout=int(config.es_timeout)
Expand Down
16 changes: 14 additions & 2 deletions app/store/elasticsearch_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,16 @@ class ContentBaseElasticsearchVectorStoreIndex(ElasticsearchVectorStoreIndex):

def save(self, docs: list[Document]) -> list[str]:
    """Index ``docs`` into the content-bases Elasticsearch index.

    Configuration is read from the environment on every call:

    * ``INDEX_CONTENTBASES_NAME`` - target index (default ``"content_bases"``).
    * ``ELASTICSEARCH_URL`` - cluster URL handed to the store as ``es_url``.
    * ``DEFAULT_CHUNK_SIZE`` - number of actions per bulk request (default 75).

    Args:
        docs: Documents to embed and index.

    Returns:
        Whatever ``self.vectorstore.from_documents`` returns.
        NOTE(review): the annotation says ``list[str]``, but the LangChain
        ``from_documents`` constructors normally return a store instance -
        confirm what callers actually rely on.

    Raises:
        ValueError: if ``DEFAULT_CHUNK_SIZE`` is set to a non-integer value.
    """
    index = os.environ.get("INDEX_CONTENTBASES_NAME", "content_bases")

    # Environment values are always strings. Without the int() coercion a
    # configured DEFAULT_CHUNK_SIZE would be forwarded as text (e.g. "75")
    # while the built-in default stays an int, so the bulk helper would see
    # two different types depending on deployment configuration.
    chunk_size = int(os.environ.get("DEFAULT_CHUNK_SIZE", 75))

    return self.vectorstore.from_documents(
        docs,
        self.vectorstore.embeddings,
        es_url=os.environ.get("ELASTICSEARCH_URL"),
        index_name=index,
        bulk_kwargs={
            "chunk_size": chunk_size,
            # Upper bound, in bytes, on the size of a single bulk request.
            "max_chunk_bytes": 200000000,
        },
    )

def query_search(self, search_filter: dict) -> list[dict]:
Expand Down Expand Up @@ -123,7 +132,10 @@ def search_delete(self, search_filter: dict, scroll_id: str = None) -> tuple[str
return scroll_id, hits

def search(self, search: str, filter=None, threshold=0.1) -> list[Document]:
    """Return the documents most similar to ``search`` above a score threshold.

    Args:
        search: Free-text query to embed and compare against stored documents.
        filter: Optional mapping; its ``"content_base_uuid"`` entry restricts
            the search to documents whose ``metadata.content_base_uuid``
            matches. When ``None`` (the declared default) no filter is sent
            to the vector store.
        threshold: Exclusive lower bound on the similarity score; hits scoring
            at or below it are discarded.

    Returns:
        At most five documents, in the order the vector store returned them,
        keeping only those whose score is strictly greater than ``threshold``.
    """
    # The signature advertises filter=None, so calling .get() on it
    # unconditionally would raise AttributeError for the default argument.
    if filter is None:
        q = None
    else:
        content_base_uuid = filter.get("content_base_uuid")
        q = {"match": {"metadata.content_base_uuid": content_base_uuid}}

    scored_docs = self.vectorstore.similarity_search_with_score(
        query=search, k=5, filter=q
    )
    # Each hit is a (document, score) pair; only the document is returned.
    return [doc for doc, score in scored_docs if score > threshold]

def delete(self, ids: list[str] = []) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion app/tests/test_elasticsearch_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_search(self):
search="test", filter=query_filter
)
self.vectorstore.similarity_search_with_score.assert_called_once_with(
query="test", k=5, filter=query_filter
query="test", k=5, filter={'match': {'metadata.content_base_uuid': 'dfff32e7-dce6-40f7-a86e-8f9618887977'}}
)
self.assertEqual(1, len(results))
self.assertEqual(results[0].page_content, "test doc")
Expand Down

0 comments on commit c9e1d65

Please sign in to comment.