Skip to content

Commit

Permalink
PR improvements pass 1
Browse files Browse the repository at this point in the history
  • Loading branch information
evan-danswer committed Feb 13, 2025
1 parent 1be2079 commit 8723d7a
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def rerank_documents(
search_request=search_request,
user=graph_config.tooling.search_tool.user, # bit of a hack
llm=graph_config.tooling.fast_llm,
skip_query_analysis=True,
skip_query_analysis=not state.base_search,
db_session=db_session,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
Expand Down Expand Up @@ -65,15 +66,14 @@ def retrieve_documents(

# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
# TODO is there a better way than just using strings?
# At the very least, the strings should be declared in search_tool.py as module level constants
for tool_response in search_tool.run(
query=query_to_retrieve,
override_kwargs={
"force_no_rerank": True,
"alternate_db_session": db_session,
"retrieved_sections_callback": callback_container.append,
},
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=not state.base_search,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
Expand Down
5 changes: 4 additions & 1 deletion backend/onyx/agents/agent_search/shared_graph_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,10 @@ def get_test_config(
using_tool_calling_llm=using_tool_calling_llm,
)

chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
chat_session_id = (
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
or "00000000-0000-0000-0000-000000000000"
)
assert (
chat_session_id is not None
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/context/search/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def retrieval_preprocessing(
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (None, None)
else (False, None)
)

all_query_terms = query.split()
Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def build_user_message_for_non_tool_calling_llm(
""".strip()


class BaseTool(Tool):
class BaseTool(Tool[None]):
def build_next_prompt(
self,
prompt_builder: "AnswerPromptBuilder",
Expand Down
10 changes: 10 additions & 0 deletions backend/onyx/tools/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections.abc import Callable
from typing import Any
from uuid import UUID

from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session

from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection


class ToolResponse(BaseModel):
Expand Down Expand Up @@ -57,5 +60,12 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float


class SearchToolOverrideKwargs(BaseModel):
    """Typed, programmatic overrides passed to the search tool's ``run``.

    These values are supplied by calling code (not by the LLM) via the
    ``override_kwargs`` argument. Every field has a default equal to the
    fallback the search tool uses when no override object is given, so
    callers only need to set the options they actually want to change.
    """

    # A SQLAlchemy ``Session`` is not a pydantic-native type; without this
    # setting pydantic v2 cannot generate a schema for the model and raises
    # when the class is defined.
    model_config = {"arbitrary_types_allowed": True}

    # When True, the search tool is asked not to rerank retrieved results.
    force_no_rerank: bool = False
    # Optional session for the tool to use in place of its own; callers
    # create a fresh one to avoid concurrency issues.
    alternate_db_session: Session | None = None
    # Optional hook that is handed the list of retrieved sections.
    retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
    # When True, the query-analysis step is skipped.
    skip_query_analysis: bool = False


CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
9 changes: 7 additions & 2 deletions backend/onyx/tools/tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import abc
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar

from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
Expand All @@ -14,7 +16,10 @@
from onyx.tools.models import ToolResponse


class Tool(abc.ABC):
OVERRIDE_T = TypeVar("OVERRIDE_T")


class Tool(abc.ABC, Generic[OVERRIDE_T]):
@property
@abc.abstractmethod
def name(self) -> str:
Expand Down Expand Up @@ -58,7 +63,7 @@ def get_args_for_non_tool_calling_llm(

@abc.abstractmethod
def run(
self, override_kwargs: dict[str, Any] | None = None, **llm_kwargs: Any
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
) -> Generator["ToolResponse", None, None]:
raise NotImplementedError

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel):
tool_result: Any # The response data


# override_kwargs is not supported for custom tools
class CustomTool(BaseTool):
def __init__(
self,
Expand Down Expand Up @@ -238,10 +239,6 @@ def _parse_csv(self, csv_text: str) -> List[Dict[str, Any]]:
def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:
assert (
override_kwargs is None
) # override_kwargs is not supported for custom tools

request_body = kwargs.get(REQUEST_BODY)

path_params = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class ImageShape(str, Enum):
LANDSCAPE = "landscape"


class ImageGenerationTool(Tool):
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
_NAME = "run_image_generation"
_DESCRIPTION = "Generate an image from a prompt."
_DISPLAY_NAME = "Image Generation"
Expand Down Expand Up @@ -256,12 +257,8 @@ def _generate_image(
)

def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: str
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
assert (
override_kwargs is None
) # override_kwargs is not supported for image generation tools

prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def internet_search_response_to_search_docs(
]


class InternetSearchTool(Tool):
# override_kwargs is not supported for internet search tools
class InternetSearchTool(Tool[None]):
_NAME = "run_internet_search"
_DISPLAY_NAME = "Internet Search"
_DESCRIPTION = "Perform an internet search for up-to-date information."
Expand Down Expand Up @@ -243,12 +244,8 @@ def _perform_search(self, query: str) -> InternetSearchResponse:
)

def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: str
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
assert (
override_kwargs is None
) # override_kwargs is not supported for internet search tools

query = cast(str, kwargs["internet_search_query"])

results = self._perform_search(query)
Expand Down
20 changes: 6 additions & 14 deletions backend/onyx/tools/tool_implementations/search/search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
Expand Down Expand Up @@ -276,27 +277,18 @@ def _build_response_for_specified_sections(
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

def run(
self, override_kwargs: dict[str, Any] | None = None, **llm_kwargs: Any
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs["query"])
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = cast(bool, override_kwargs.get("force_no_rerank", False))
alternate_db_session = cast(
Session, override_kwargs.get("alternate_db_session")
)
retrieved_sections_callback = cast(
Callable[[list[InferenceSection]], None],
override_kwargs.get("retrieved_sections_callback"),
)
# TODO the main flow (user provided query) should pass through this
# The other ones (expanded queries) should not do query analysis, they're all "semantic"
skip_query_analysis = cast(
bool, override_kwargs.get("skip_query_analysis", False)
)
force_no_rerank = override_kwargs.force_no_rerank
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = override_kwargs.skip_query_analysis

if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
Expand Down
5 changes: 4 additions & 1 deletion backend/onyx/utils/threadpool_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def run_functions_in_parallel(
Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys
are the result_id of the FunctionCall and the values are the results of the call.
"""
results = {}
results: dict[str, Any] = {}

if len(function_calls) == 0:
return results

with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {
Expand Down

0 comments on commit 8723d7a

Please sign in to comment.