From 1a7aca06b9eee9ef0661b45a56e39bce9a06cbab Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 13 Feb 2025 15:54:34 -0800 Subject: [PATCH] Fix Agent Slowness (#3979) --- .../server/query_and_chat/query_backend.py | 1 + .../nodes/retrieve_documents.py | 10 +++++--- .../agent_search/shared_graph_utils/utils.py | 14 ++++++++--- backend/onyx/context/search/pipeline.py | 9 +++++++ .../search/preprocessing/preprocessing.py | 8 +++--- .../onyx/natural_language_processing/utils.py | 2 +- backend/onyx/server/gpts/api.py | 1 + backend/onyx/tools/base_tool.py | 2 +- backend/onyx/tools/models.py | 13 ++++++++++ backend/onyx/tools/tool.py | 11 ++++++-- .../custom/custom_tool.py | 5 +++- .../images/image_generation_tool.py | 7 ++++-- .../internet_search/internet_search_tool.py | 7 ++++-- .../search/search_tool.py | 25 ++++++++++++------- 14 files changed, 87 insertions(+), 28 deletions(-) diff --git a/backend/ee/onyx/server/query_and_chat/query_backend.py b/backend/ee/onyx/server/query_and_chat/query_backend.py index 34fc9dbaf3f..2910ac3a76a 100644 --- a/backend/ee/onyx/server/query_and_chat/query_backend.py +++ b/backend/ee/onyx/server/query_and_chat/query_backend.py @@ -83,6 +83,7 @@ def handle_search_request( user=user, llm=llm, fast_llm=fast_llm, + skip_query_analysis=False, db_session=db_session, bypass_acl=False, ) diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py index 4fe84d0381c..b0347f75eef 100644 --- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py +++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py @@ -23,6 +23,7 @@ from onyx.context.search.models import InferenceSection from onyx.db.engine import get_session_context_manager from onyx.tools.models import SearchQueryInfo +from onyx.tools.models import SearchToolOverrideKwargs from onyx.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, ) @@ -67,9 +68,12 @@ def retrieve_documents( with get_session_context_manager() as db_session: for tool_response in search_tool.run( query=query_to_retrieve, - force_no_rerank=True, - alternate_db_session=db_session, - retrieved_sections_callback=callback_container.append, + override_kwargs=SearchToolOverrideKwargs( + force_no_rerank=True, + alternate_db_session=db_session, + retrieved_sections_callback=callback_container.append, + skip_query_analysis=not state.base_search, + ), ): # get retrieved docs to send to the rest of the graph if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py index e4539d50e5e..86c7c0b490e 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -58,6 +58,7 @@ ) from onyx.prompts.prompt_utils import handle_onyx_date_awareness from onyx.tools.force import ForceUseTool +from onyx.tools.models import SearchToolOverrideKwargs from onyx.tools.tool_constructor import SearchToolConfig from onyx.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, @@ -218,7 +219,10 @@ def get_test_config( using_tool_calling_llm=using_tool_calling_llm, ) - chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID") + chat_session_id = ( + os.environ.get("ONYX_AS_CHAT_SESSION_ID") + or "00000000-0000-0000-0000-000000000000" + ) assert ( chat_session_id is not None ), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests" @@ -341,8 +345,12 @@ def retrieve_search_docs( with get_session_context_manager() as db_session: for tool_response in search_tool.run( query=question, - force_no_rerank=True, - alternate_db_session=db_session, + override_kwargs=SearchToolOverrideKwargs( + force_no_rerank=True, + alternate_db_session=db_session, + retrieved_sections_callback=None, + skip_query_analysis=False, + ), ): # get retrieved docs to send to the rest of the graph if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py index b03c401fe3e..faf7a898892 100644 --- a/backend/onyx/context/search/pipeline.py +++ b/backend/onyx/context/search/pipeline.py @@ -51,6 +51,7 @@ def __init__( user: User | None, llm: LLM, fast_llm: LLM, + skip_query_analysis: bool, db_session: Session, bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION retrieval_metrics_callback: ( @@ -67,6 +68,7 @@ def __init__( self.user = user self.llm = llm self.fast_llm = fast_llm + self.skip_query_analysis = skip_query_analysis self.db_session = db_session self.bypass_acl = bypass_acl self.retrieval_metrics_callback = retrieval_metrics_callback @@ -108,6 +110,7 @@ def _run_preprocessing(self) -> None: search_request=self.search_request, user=self.user, llm=self.llm, + skip_query_analysis=self.skip_query_analysis, db_session=self.db_session, bypass_acl=self.bypass_acl, ) @@ -162,6 +165,12 @@ def _get_sections(self) -> list[InferenceSection]: that have a corresponding chunk. This step should be fast for any document index implementation. + + Current implementation timing is approximately broken down in timing as: + - 200 ms to get the embedding of the query + - 15 ms to get chunks from the document index + - possibly more to get additional surrounding chunks + - possibly more for query expansion (multilingual) """ if self._retrieved_sections is not None: return self._retrieved_sections diff --git a/backend/onyx/context/search/preprocessing/preprocessing.py b/backend/onyx/context/search/preprocessing/preprocessing.py index da228f5f1fb..2e63ed0e39e 100644 --- a/backend/onyx/context/search/preprocessing/preprocessing.py +++ b/backend/onyx/context/search/preprocessing/preprocessing.py @@ -50,11 +50,11 @@ def retrieval_preprocessing( search_request: SearchRequest, user: User | None, llm: LLM, + skip_query_analysis: bool, db_session: Session, - bypass_acl: bool = False, - skip_query_analysis: bool = False, - base_recency_decay: float = BASE_RECENCY_DECAY, favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER, + base_recency_decay: float = BASE_RECENCY_DECAY, + bypass_acl: bool = False, ) -> SearchQuery: """Logic is as follows: Any global disables apply first @@ -146,7 +146,7 @@ def retrieval_preprocessing( is_keyword, extracted_keywords = ( parallel_results[run_query_analysis.result_id] if run_query_analysis - else (None, None) + else (False, None) ) all_query_terms = query.split() diff --git a/backend/onyx/natural_language_processing/utils.py b/backend/onyx/natural_language_processing/utils.py index 7b68b20d8e9..3c4d1392088 100644 --- a/backend/onyx/natural_language_processing/utils.py +++ b/backend/onyx/natural_language_processing/utils.py @@ -99,7 +99,7 @@ def _check_tokenizer_cache( if not tokenizer: logger.info( - f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}" + f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}" ) tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL) diff --git a/backend/onyx/server/gpts/api.py b/backend/onyx/server/gpts/api.py index 58796d6199b..ea2aad20b1e 100644 --- a/backend/onyx/server/gpts/api.py +++ b/backend/onyx/server/gpts/api.py @@ -76,6 +76,7 @@ def gpt_search( user=None, llm=llm, fast_llm=fast_llm, + skip_query_analysis=True, db_session=db_session, ).reranked_sections diff --git a/backend/onyx/tools/base_tool.py b/backend/onyx/tools/base_tool.py index 16ec5d92aa0..4b8479b75bc 100644 --- a/backend/onyx/tools/base_tool.py +++ b/backend/onyx/tools/base_tool.py @@ -34,7 +34,7 @@ def build_user_message_for_non_tool_calling_llm( """.strip() -class BaseTool(Tool): +class BaseTool(Tool[None]): def build_next_prompt( self, prompt_builder: "AnswerPromptBuilder", diff --git a/backend/onyx/tools/models.py b/backend/onyx/tools/models.py index a8918b691e4..1e343e74cb3 100644 --- a/backend/onyx/tools/models.py +++ b/backend/onyx/tools/models.py @@ -1,11 +1,14 @@ +from collections.abc import Callable from typing import Any from uuid import UUID from pydantic import BaseModel from pydantic import model_validator +from sqlalchemy.orm import Session from onyx.context.search.enums import SearchType from onyx.context.search.models import IndexFilters +from onyx.context.search.models import InferenceSection class ToolResponse(BaseModel): @@ -57,5 +60,15 @@ class SearchQueryInfo(BaseModel): recency_bias_multiplier: float +class SearchToolOverrideKwargs(BaseModel): + force_no_rerank: bool + alternate_db_session: Session | None + retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None + skip_query_analysis: bool + + class Config: + arbitrary_types_allowed = True + + CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID" MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID" diff --git a/backend/onyx/tools/tool.py b/backend/onyx/tools/tool.py index 4a8ba80996e..2c7f53647f0 100644 --- a/backend/onyx/tools/tool.py +++ b/backend/onyx/tools/tool.py @@ -1,7 +1,9 @@ import abc from collections.abc import Generator from typing import Any +from typing import Generic from typing import TYPE_CHECKING +from typing import TypeVar from onyx.llm.interfaces import LLM from onyx.llm.models import PreviousMessage @@ -14,7 +16,10 @@ from onyx.tools.models import ToolResponse -class Tool(abc.ABC): +OVERRIDE_T = TypeVar("OVERRIDE_T") + + +class Tool(abc.ABC, Generic[OVERRIDE_T]): @property @abc.abstractmethod def name(self) -> str: @@ -57,7 +62,9 @@ def get_args_for_non_tool_calling_llm( """Actual execution of the tool""" @abc.abstractmethod - def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]: + def run( + self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any + ) -> Generator["ToolResponse", None, None]: raise NotImplementedError @abc.abstractmethod diff --git a/backend/onyx/tools/tool_implementations/custom/custom_tool.py b/backend/onyx/tools/tool_implementations/custom/custom_tool.py index a235383a71c..932989e44e5 100644 --- a/backend/onyx/tools/tool_implementations/custom/custom_tool.py +++ b/backend/onyx/tools/tool_implementations/custom/custom_tool.py @@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel): tool_result: Any # The response data +# override_kwargs is not supported for custom tools class CustomTool(BaseTool): def __init__( self, @@ -235,7 +236,9 @@ def _parse_csv(self, csv_text: str) -> List[Dict[str, Any]]: """Actual execution of the tool""" - def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: + def run( + self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any + ) -> Generator[ToolResponse, None, None]: request_body = kwargs.get(REQUEST_BODY) path_params = {} diff --git a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py index f4e19e1c283..3185b4a001d 100644 --- a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py @@ -79,7 +79,8 @@ class ImageShape(str, Enum): LANDSCAPE = "landscape" -class ImageGenerationTool(Tool): +# override_kwargs is not supported for image generation tools +class ImageGenerationTool(Tool[None]): _NAME = "run_image_generation" _DESCRIPTION = "Generate an image from a prompt." _DISPLAY_NAME = "Image Generation" @@ -255,7 +256,9 @@ def _generate_image( "An error occurred during image generation. Please try again later." ) - def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: + def run( + self, override_kwargs: None = None, **kwargs: str + ) -> Generator[ToolResponse, None, None]: prompt = cast(str, kwargs["prompt"]) shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE)) format = self.output_format diff --git a/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py b/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py index 474fa2d675f..1c6b3f21cc1 100644 --- a/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py +++ b/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py @@ -106,7 +106,8 @@ def internet_search_response_to_search_docs( ] -class InternetSearchTool(Tool): +# override_kwargs is not supported for internet search tools +class InternetSearchTool(Tool[None]): _NAME = "run_internet_search" _DISPLAY_NAME = "Internet Search" _DESCRIPTION = "Perform an internet search for up-to-date information." @@ -242,7 +243,9 @@ def _perform_search(self, query: str) -> InternetSearchResponse: ], ) - def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: + def run( + self, override_kwargs: None = None, **kwargs: str + ) -> Generator[ToolResponse, None, None]: query = cast(str, kwargs["internet_search_query"]) results = self._perform_search(query) diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index 2666b2014a4..11d147526a0 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -39,6 +39,7 @@ from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase from onyx.tools.message import ToolCallSummary from onyx.tools.models import SearchQueryInfo +from onyx.tools.models import SearchToolOverrideKwargs from onyx.tools.models import ToolResponse from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict @@ -77,7 +78,7 @@ class SearchResponseSummary(SearchQueryInfo): """ -class SearchTool(Tool): +class SearchTool(Tool[SearchToolOverrideKwargs]): _NAME = "run_search" _DISPLAY_NAME = "Search Tool" _DESCRIPTION = SEARCH_TOOL_DESCRIPTION @@ -275,14 +276,19 @@ def _build_response_for_specified_sections( yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) - def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: - query = cast(str, kwargs["query"]) - force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False)) - alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None)) - retrieved_sections_callback = cast( - Callable[[list[InferenceSection]], None], - kwargs.get("retrieved_sections_callback"), - ) + def run( + self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any + ) -> Generator[ToolResponse, None, None]: + query = cast(str, llm_kwargs["query"]) + force_no_rerank = False + alternate_db_session = None + retrieved_sections_callback = None + skip_query_analysis = False + if override_kwargs: + force_no_rerank = override_kwargs.force_no_rerank + alternate_db_session = override_kwargs.alternate_db_session + retrieved_sections_callback = override_kwargs.retrieved_sections_callback + skip_query_analysis = override_kwargs.skip_query_analysis if self.selected_sections: yield from self._build_response_for_specified_sections(query) @@ -324,6 +330,7 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: user=self.user, llm=self.llm, fast_llm=self.fast_llm, + skip_query_analysis=skip_query_analysis, bypass_acl=self.bypass_acl, db_session=alternate_db_session or self.db_session, prompt_config=self.prompt_config,