Skip to content

Commit

Permalink
Fix Agent Slowness (#3979)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 authored Feb 13, 2025
1 parent c6434db commit 1a7aca0
Show file tree
Hide file tree
Showing 14 changed files with 87 additions and 28 deletions.
1 change: 1 addition & 0 deletions backend/ee/onyx/server/query_and_chat/query_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def handle_search_request(
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=False,
db_session=db_session,
bypass_acl=False,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
Expand Down Expand Up @@ -67,9 +68,12 @@ def retrieve_documents(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=not state.base_search,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
Expand Down
14 changes: 11 additions & 3 deletions backend/onyx/agents/agent_search/shared_graph_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
Expand Down Expand Up @@ -218,7 +219,10 @@ def get_test_config(
using_tool_calling_llm=using_tool_calling_llm,
)

chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
chat_session_id = (
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
or "00000000-0000-0000-0000-000000000000"
)
assert (
chat_session_id is not None
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
Expand Down Expand Up @@ -341,8 +345,12 @@ def retrieve_search_docs(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
force_no_rerank=True,
alternate_db_session=db_session,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=None,
skip_query_analysis=False,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
Expand Down
9 changes: 9 additions & 0 deletions backend/onyx/context/search/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
user: User | None,
llm: LLM,
fast_llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: (
Expand All @@ -67,6 +68,7 @@ def __init__(
self.user = user
self.llm = llm
self.fast_llm = fast_llm
self.skip_query_analysis = skip_query_analysis
self.db_session = db_session
self.bypass_acl = bypass_acl
self.retrieval_metrics_callback = retrieval_metrics_callback
Expand Down Expand Up @@ -108,6 +110,7 @@ def _run_preprocessing(self) -> None:
search_request=self.search_request,
user=self.user,
llm=self.llm,
skip_query_analysis=self.skip_query_analysis,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
Expand Down Expand Up @@ -162,6 +165,12 @@ def _get_sections(self) -> list[InferenceSection]:
that have a corresponding chunk.
This step should be fast for any document index implementation.
Current implementation timing is approximately broken down in timing as:
- 200 ms to get the embedding of the query
- 15 ms to get chunks from the document index
- possibly more to get additional surrounding chunks
- possibly more for query expansion (multilingual)
"""
if self._retrieved_sections is not None:
return self._retrieved_sections
Expand Down
8 changes: 4 additions & 4 deletions backend/onyx/context/search/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def retrieval_preprocessing(
search_request: SearchRequest,
user: User | None,
llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False,
skip_query_analysis: bool = False,
base_recency_decay: float = BASE_RECENCY_DECAY,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
base_recency_decay: float = BASE_RECENCY_DECAY,
bypass_acl: bool = False,
) -> SearchQuery:
"""Logic is as follows:
Any global disables apply first
Expand Down Expand Up @@ -146,7 +146,7 @@ def retrieval_preprocessing(
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (None, None)
else (False, None)
)

all_query_terms = query.split()
Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/natural_language_processing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _check_tokenizer_cache(

if not tokenizer:
logger.info(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
)
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)

Expand Down
1 change: 1 addition & 0 deletions backend/onyx/server/gpts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def gpt_search(
user=None,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=True,
db_session=db_session,
).reranked_sections

Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def build_user_message_for_non_tool_calling_llm(
""".strip()


class BaseTool(Tool):
class BaseTool(Tool[None]):
def build_next_prompt(
self,
prompt_builder: "AnswerPromptBuilder",
Expand Down
13 changes: 13 additions & 0 deletions backend/onyx/tools/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections.abc import Callable
from typing import Any
from uuid import UUID

from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session

from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection


class ToolResponse(BaseModel):
Expand Down Expand Up @@ -57,5 +60,15 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float


class SearchToolOverrideKwargs(BaseModel):
force_no_rerank: bool
alternate_db_session: Session | None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
skip_query_analysis: bool

class Config:
arbitrary_types_allowed = True


CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
11 changes: 9 additions & 2 deletions backend/onyx/tools/tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import abc
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar

from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
Expand All @@ -14,7 +16,10 @@
from onyx.tools.models import ToolResponse


class Tool(abc.ABC):
OVERRIDE_T = TypeVar("OVERRIDE_T")


class Tool(abc.ABC, Generic[OVERRIDE_T]):
@property
@abc.abstractmethod
def name(self) -> str:
Expand Down Expand Up @@ -57,7 +62,9 @@ def get_args_for_non_tool_calling_llm(
"""Actual execution of the tool"""

@abc.abstractmethod
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
def run(
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
) -> Generator["ToolResponse", None, None]:
raise NotImplementedError

@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel):
tool_result: Any # The response data


# override_kwargs is not supported for custom tools
class CustomTool(BaseTool):
def __init__(
self,
Expand Down Expand Up @@ -235,7 +236,9 @@ def _parse_csv(self, csv_text: str) -> List[Dict[str, Any]]:

"""Actual execution of the tool"""

def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:
request_body = kwargs.get(REQUEST_BODY)

path_params = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class ImageShape(str, Enum):
LANDSCAPE = "landscape"


class ImageGenerationTool(Tool):
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
_NAME = "run_image_generation"
_DESCRIPTION = "Generate an image from a prompt."
_DISPLAY_NAME = "Image Generation"
Expand Down Expand Up @@ -255,7 +256,9 @@ def _generate_image(
"An error occurred during image generation. Please try again later."
)

def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def internet_search_response_to_search_docs(
]


class InternetSearchTool(Tool):
# override_kwargs is not supported for internet search tools
class InternetSearchTool(Tool[None]):
_NAME = "run_internet_search"
_DISPLAY_NAME = "Internet Search"
_DESCRIPTION = "Perform an internet search for up-to-date information."
Expand Down Expand Up @@ -242,7 +243,9 @@ def _perform_search(self, query: str) -> InternetSearchResponse:
],
)

def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["internet_search_query"])

results = self._perform_search(query)
Expand Down
25 changes: 16 additions & 9 deletions backend/onyx/tools/tool_implementations/search/search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
Expand Down Expand Up @@ -77,7 +78,7 @@ class SearchResponseSummary(SearchQueryInfo):
"""


class SearchTool(Tool):
class SearchTool(Tool[SearchToolOverrideKwargs]):
_NAME = "run_search"
_DISPLAY_NAME = "Search Tool"
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
Expand Down Expand Up @@ -275,14 +276,19 @@ def _build_response_for_specified_sections(

yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["query"])
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
retrieved_sections_callback = cast(
Callable[[list[InferenceSection]], None],
kwargs.get("retrieved_sections_callback"),
)
def run(
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs["query"])
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = override_kwargs.force_no_rerank
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = override_kwargs.skip_query_analysis

if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
Expand Down Expand Up @@ -324,6 +330,7 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
user=self.user,
llm=self.llm,
fast_llm=self.fast_llm,
skip_query_analysis=skip_query_analysis,
bypass_acl=self.bypass_acl,
db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config,
Expand Down

0 comments on commit 1a7aca0

Please sign in to comment.