Fix Agent Slowness #3979

Merged 5 commits on Feb 13, 2025
1 change: 1 addition & 0 deletions backend/ee/onyx/server/query_and_chat/query_backend.py
@@ -83,6 +83,7 @@ def handle_search_request(
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=False,
db_session=db_session,
bypass_acl=False,
)
@@ -23,6 +23,7 @@
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -67,9 +68,12 @@ def retrieve_documents(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=not state.base_search,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
14 changes: 11 additions & 3 deletions backend/onyx/agents/agent_search/shared_graph_utils/utils.py
@@ -58,6 +58,7 @@
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
@@ -218,7 +219,10 @@ def get_test_config(
using_tool_calling_llm=using_tool_calling_llm,
)

chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
chat_session_id = (
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
or "00000000-0000-0000-0000-000000000000"
)
assert (
chat_session_id is not None
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
@@ -341,8 +345,12 @@ def retrieve_search_docs(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
force_no_rerank=True,
alternate_db_session=db_session,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=None,
skip_query_analysis=False,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
9 changes: 9 additions & 0 deletions backend/onyx/context/search/pipeline.py
@@ -51,6 +51,7 @@ def __init__(
user: User | None,
llm: LLM,
fast_llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: (
@@ -67,6 +68,7 @@
self.user = user
self.llm = llm
self.fast_llm = fast_llm
self.skip_query_analysis = skip_query_analysis
self.db_session = db_session
self.bypass_acl = bypass_acl
self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -108,6 +110,7 @@ def _run_preprocessing(self) -> None:
search_request=self.search_request,
user=self.user,
llm=self.llm,
skip_query_analysis=self.skip_query_analysis,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
@@ -162,6 +165,12 @@ def _get_sections(self) -> list[InferenceSection]:
that have a corresponding chunk.

This step should be fast for any document index implementation.

Current implementation timing breaks down approximately as:
- 200 ms to get the embedding of the query
- 15 ms to get chunks from the document index
- possibly more to get additional surrounding chunks
- possibly more for query expansion (multilingual)
"""
if self._retrieved_sections is not None:
return self._retrieved_sections
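To make the new required flag concrete, here is a rough sketch of a call site after this change, mirroring gpt_search below; run_search is a hypothetical wrapper, and the import paths for SearchRequest and the exact type of reranked_sections are assumed from context rather than taken from this diff.

from sqlalchemy.orm import Session

from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import SearchPipeline
from onyx.llm.interfaces import LLM


def run_search(
    search_request: SearchRequest,
    llm: LLM,
    fast_llm: LLM,
    db_session: Session,
) -> list[InferenceSection]:
    # skip_query_analysis is now an explicit, required argument: pass True on
    # latency-sensitive paths that do not need query analysis (as gpt_search
    # does below) and False where the analysis should still run.
    return SearchPipeline(
        search_request=search_request,
        user=None,
        llm=llm,
        fast_llm=fast_llm,
        skip_query_analysis=True,
        db_session=db_session,
    ).reranked_sections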
8 changes: 4 additions & 4 deletions backend/onyx/context/search/preprocessing/preprocessing.py
@@ -50,11 +50,11 @@ def retrieval_preprocessing(
search_request: SearchRequest,
user: User | None,
llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False,
skip_query_analysis: bool = False,
base_recency_decay: float = BASE_RECENCY_DECAY,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
base_recency_decay: float = BASE_RECENCY_DECAY,
bypass_acl: bool = False,
) -> SearchQuery:
"""Logic is as follows:
Any global disables apply first
@@ -146,7 +146,7 @@ def retrieval_preprocessing(
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (None, None)
else (False, None)
)

all_query_terms = query.split()
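The (None, None) to (False, None) fallback change above is what makes skipping safe: downstream code sees a concrete is_keyword value instead of None. A toy sketch of that gating, where run_analysis is a stand-in for the parallel LLM-based analysis rather than the project's actual helper:

from collections.abc import Callable


def analyze_query(
    query: str,
    skip_query_analysis: bool,
    run_analysis: Callable[[str], tuple[bool, list[str]]],
) -> tuple[bool, list[str] | None]:
    """Return (is_keyword, extracted_keywords), defaulting when analysis is skipped."""
    if skip_query_analysis:
        # Mirrors the fallback in the diff: assume a semantic (non-keyword)
        # query with no extracted keywords rather than returning None.
        return False, None
    return run_analysis(query)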
2 changes: 1 addition & 1 deletion backend/onyx/natural_language_processing/utils.py
@@ -99,7 +99,7 @@ def _check_tokenizer_cache(

if not tokenizer:
logger.info(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
)
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)

1 change: 1 addition & 0 deletions backend/onyx/server/gpts/api.py
@@ -76,6 +76,7 @@ def gpt_search(
user=None,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=True,
db_session=db_session,
).reranked_sections

2 changes: 1 addition & 1 deletion backend/onyx/tools/base_tool.py
@@ -34,7 +34,7 @@ def build_user_message_for_non_tool_calling_llm(
""".strip()


class BaseTool(Tool):
class BaseTool(Tool[None]):
def build_next_prompt(
self,
prompt_builder: "AnswerPromptBuilder",
13 changes: 13 additions & 0 deletions backend/onyx/tools/models.py
@@ -1,11 +1,14 @@
from collections.abc import Callable
from typing import Any
from uuid import UUID

from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session

from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection


class ToolResponse(BaseModel):
@@ -57,5 +60,15 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float


class SearchToolOverrideKwargs(BaseModel):
force_no_rerank: bool
alternate_db_session: Session | None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
skip_query_analysis: bool

class Config:
arbitrary_types_allowed = True


CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
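As a usage sketch (not part of the diff), this is roughly how a caller builds the new override object and hands it to the search tool, mirroring the agent-search call sites above; run_search_with_overrides and its arguments are hypothetical stand-ins for objects the caller already has.

from collections.abc import Generator

from sqlalchemy.orm import Session

from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SearchTool


def run_search_with_overrides(
    search_tool: SearchTool, db_session: Session, query: str
) -> Generator[ToolResponse, None, None]:
    # All search-specific overrides now travel as one typed, validated object
    # instead of loose **kwargs picked apart inside run().
    overrides = SearchToolOverrideKwargs(
        force_no_rerank=True,
        alternate_db_session=db_session,
        retrieved_sections_callback=None,
        skip_query_analysis=True,
    )
    yield from search_tool.run(query=query, override_kwargs=overrides)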
11 changes: 9 additions & 2 deletions backend/onyx/tools/tool.py
@@ -1,7 +1,9 @@
import abc
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar

from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@@ -14,7 +16,10 @@
from onyx.tools.models import ToolResponse


class Tool(abc.ABC):
OVERRIDE_T = TypeVar("OVERRIDE_T")


class Tool(abc.ABC, Generic[OVERRIDE_T]):
@property
@abc.abstractmethod
def name(self) -> str:
@@ -57,7 +62,9 @@ def get_args_for_non_tool_calling_llm(
"""Actual execution of the tool"""

@abc.abstractmethod
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
def run(
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
) -> Generator["ToolResponse", None, None]:
raise NotImplementedError

@abc.abstractmethod
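To illustrate the new Generic[OVERRIDE_T] parameter, here is a hypothetical minimal subclass; EchoTool and EchoOverrideKwargs are invented for this sketch, and a tool that takes no overrides would simply subclass Tool[None], as BaseTool does above.

from collections.abc import Generator
from typing import Any

from pydantic import BaseModel

from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool


class EchoOverrideKwargs(BaseModel):
    uppercase: bool = False


class EchoTool(Tool[EchoOverrideKwargs]):
    # Remaining abstract members (name, descriptions, prompt plumbing, etc.)
    # are omitted for brevity.

    def run(
        self, override_kwargs: EchoOverrideKwargs | None = None, **llm_kwargs: Any
    ) -> Generator[ToolResponse, None, None]:
        # override_kwargs is now typed per tool instead of being pulled out of
        # an untyped **kwargs bag.
        text = str(llm_kwargs.get("text", ""))
        if override_kwargs is not None and override_kwargs.uppercase:
            text = text.upper()
        yield ToolResponse(id="echo_response", response=text)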
@@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel):
tool_result: Any # The response data


# override_kwargs is not supported for custom tools
class CustomTool(BaseTool):
def __init__(
self,
@@ -235,7 +236,9 @@ def _parse_csv(self, csv_text: str) -> List[Dict[str, Any]]:

"""Actual execution of the tool"""

def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:
request_body = kwargs.get(REQUEST_BODY)

path_params = {}
@@ -79,7 +79,8 @@ class ImageShape(str, Enum):
LANDSCAPE = "landscape"


class ImageGenerationTool(Tool):
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
_NAME = "run_image_generation"
_DESCRIPTION = "Generate an image from a prompt."
_DISPLAY_NAME = "Image Generation"
@@ -255,7 +256,9 @@ def _generate_image(
"An error occurred during image generation. Please try again later."
)

def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format
@@ -106,7 +106,8 @@ def internet_search_response_to_search_docs(
]


class InternetSearchTool(Tool):
# override_kwargs is not supported for internet search tools
class InternetSearchTool(Tool[None]):
_NAME = "run_internet_search"
_DISPLAY_NAME = "Internet Search"
_DESCRIPTION = "Perform an internet search for up-to-date information."
@@ -242,7 +243,9 @@ def _perform_search(self, query: str) -> InternetSearchResponse:
],
)

def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["internet_search_query"])

results = self._perform_search(query)
25 changes: 16 additions & 9 deletions backend/onyx/tools/tool_implementations/search/search_tool.py
@@ -39,6 +39,7 @@
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
@@ -77,7 +78,7 @@ class SearchResponseSummary(SearchQueryInfo):
"""


class SearchTool(Tool):
class SearchTool(Tool[SearchToolOverrideKwargs]):
_NAME = "run_search"
_DISPLAY_NAME = "Search Tool"
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
@@ -275,14 +276,19 @@ def _build_response_for_specified_sections(

yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["query"])
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
retrieved_sections_callback = cast(
Callable[[list[InferenceSection]], None],
kwargs.get("retrieved_sections_callback"),
)
def run(
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs["query"])
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = override_kwargs.force_no_rerank
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = override_kwargs.skip_query_analysis

if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
@@ -324,6 +330,7 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
user=self.user,
llm=self.llm,
fast_llm=self.fast_llm,
skip_query_analysis=skip_query_analysis,
bypass_acl=self.bypass_acl,
db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config,