From 0fc8fd12bd2e8c6a61ad3cbe7e2e64525667af8b Mon Sep 17 00:00:00 2001 From: Holt Skinner <13262395+holtskinner@users.noreply.github.com> Date: Wed, 8 Nov 2023 20:52:50 -0600 Subject: [PATCH] feat: Vertex AI Search - Add Snippet Retrieval for Non-Advanced Website Data Stores (#13020) https://cloud.google.com/generative-ai-app-builder/docs/snippets#snippets --------- Co-authored-by: Eugene Yurtsev --- .../retrievers/google_vertex_ai_search.py | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/libs/langchain/langchain/retrievers/google_vertex_ai_search.py b/libs/langchain/langchain/retrievers/google_vertex_ai_search.py index 53144ffb5f9ef..8dc8cd6d8b9f5 100644 --- a/libs/langchain/langchain/retrievers/google_vertex_ai_search.py +++ b/libs/langchain/langchain/retrievers/google_vertex_ai_search.py @@ -36,7 +36,7 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel): """ Defines the Vertex AI Search data type 0 - Unstructured data 1 - Structured data - 2 - Website data (with Advanced Website Indexing) + 2 - Website data """ @root_validator(pre=True) @@ -154,7 +154,7 @@ def _convert_unstructured_search_response( return documents def _convert_website_search_response( - self, results: Sequence[SearchResult] + self, results: Sequence[SearchResult], chunk_type: str ) -> List[Document]: """Converts a sequence of search results to a list of LangChain documents.""" from google.protobuf.json_format import MessageToDict @@ -173,24 +173,26 @@ def _convert_website_search_response( doc_metadata["id"] = document_dict["id"] doc_metadata["source"] = derived_struct_data.get("link", "") - chunk_type = "extractive_answers" - if chunk_type not in derived_struct_data: continue + text_field = "snippet" if chunk_type == "snippets" else "content" + for chunk in derived_struct_data[chunk_type]: documents.append( Document( - page_content=chunk.get("content", ""), metadata=doc_metadata + page_content=chunk.get(text_field, ""), metadata=doc_metadata ) ) if not documents: - print( - f"No {chunk_type} could be found.\n" - "Make sure that your data store is using Advanced Website Indexing.\n" - "https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing" # noqa: E501 - ) + print(f"No {chunk_type} could be found.") + if chunk_type == "extractive_answers": + print( + "Make sure that your data store is using Advanced Website " + "Indexing.\n" + "https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing" # noqa: E501 + ) return documents @@ -206,7 +208,7 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr filter: Optional[str] = None """Filter expression.""" get_extractive_answers: bool = False - """If True return Extractive Answers, otherwise return Extractive Segments.""" + """If True return Extractive Answers, otherwise return Extractive Segments or Snippets.""" # noqa: E501 max_documents: int = Field(default=5, ge=1, le=100) """The maximum number of documents to return.""" max_extractive_answer_count: int = Field(default=1, ge=1, le=5) @@ -307,12 +309,15 @@ def _create_search_request(self, query: str) -> SearchRequest: content_search_spec = SearchRequest.ContentSearchSpec( extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec( max_extractive_answer_count=self.max_extractive_answer_count, - ) + ), + snippet_spec=SearchRequest.ContentSearchSpec.SnippetSpec( + return_snippet=True + ), ) else: raise NotImplementedError( "Only data store type 0 (Unstructured), 1 (Structured)," - "or 2 (Website with Advanced Indexing) are supported currently." + "or 2 (Website) are supported currently." + f" Got {self.engine_data_type}" ) @@ -354,11 +359,16 @@ def _get_relevant_documents( elif self.engine_data_type == 1: documents = self._convert_structured_search_response(response.results) elif self.engine_data_type == 2: - documents = self._convert_website_search_response(response.results) + chunk_type = ( + "extractive_answers" if self.get_extractive_answers else "snippets" + ) + documents = self._convert_website_search_response( + response.results, chunk_type + ) else: raise NotImplementedError( "Only data store type 0 (Unstructured), 1 (Structured)," - "or 2 (Website with Advanced Indexing) are supported currently." + "or 2 (Website) are supported currently." + f" Got {self.engine_data_type}" ) @@ -431,7 +441,9 @@ def _get_relevant_documents( response = self._client.converse_conversation(request) if self.engine_data_type == 2: - return self._convert_website_search_response(response.search_results) + return self._convert_website_search_response( + response.search_results, "extractive_answers" + ) return self._convert_unstructured_search_response( response.search_results, "extractive_answers"