upgrade: add web search! #40

Merged · 8 commits · Nov 30, 2023
2 changes: 2 additions & 0 deletions 1_🏠_Home.py
@@ -31,6 +31,8 @@
        "To build a new agent, please make sure that 'Create a new agent' is selected.",
        icon="ℹ️",
    )
    if "metaphor_key" in st.secrets:
        st.info("**NOTE**: The ability to add web search is enabled.")


add_sidebar()
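The gate above relies on Streamlit's secrets mechanism: `st.secrets` is populated from a local `.streamlit/secrets.toml` file (or the deployment's secrets settings). A minimal sketch of how the key is expected to be provided, with a placeholder value:

```python
# Minimal sketch of the secrets gate this PR relies on (placeholder key value).
# Locally, st.secrets is typically backed by .streamlit/secrets.toml containing:
#   metaphor_key = "<your-metaphor-api-key>"
import streamlit as st

if "metaphor_key" in st.secrets:
    # Same check as above: when the key is present, the home page shows the note
    # and the builder agent gains the add_web_tool function (see agent_utils.py).
    st.info("**NOTE**: The ability to add web search is enabled.")
```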
153 changes: 129 additions & 24 deletions agent_utils.py
@@ -32,6 +32,9 @@
from constants import AGENT_CACHE_DIR
import shutil

from llama_index.callbacks import CallbackManager
from callback_manager import StreamlitFunctionsCallbackHandler


def _resolve_llm(llm_str: str) -> LLM:
    """Resolve LLM."""
@@ -153,9 +156,25 @@ def load_agent(
    """Load agent."""
    extra_kwargs = extra_kwargs or {}
    if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
        # TODO: use default msg handler
        # TODO: separate this from agent_utils.py...
        def _msg_handler(msg: str) -> None:
            """Message handler."""
            st.info(msg)
            st.session_state.agent_messages.append(
                {"role": "assistant", "content": msg, "msg_type": "info"}
            )

        # add streamlit callbacks (to inject events)
        handler = StreamlitFunctionsCallbackHandler(_msg_handler)
        callback_manager = CallbackManager([handler])
        # get OpenAI Agent
        agent: BaseChatEngine = OpenAIAgent.from_tools(
-           tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
+           tools=tools,
+           llm=llm,
+           system_prompt=system_prompt,
+           **kwargs,
+           callback_manager=callback_manager,
        )
    else:
        if "vector_index" not in extra_kwargs:
@@ -189,8 +208,12 @@ def load_meta_agent(
    extra_kwargs = extra_kwargs or {}
    if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
        # get OpenAI Agent

        agent: BaseAgent = OpenAIAgent.from_tools(
-           tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
+           tools=tools,
+           llm=llm,
+           system_prompt=system_prompt,
+           **kwargs,
        )
    else:
        agent = ReActAgent.from_tools(
Expand Down Expand Up @@ -285,6 +308,66 @@ def construct_agent(
return agent, extra_info


def get_web_agent_tool() -> QueryEngineTool:
"""Get web agent tool.

Wrap with our load and search tool spec.

"""
from llama_hub.tools.metaphor.base import MetaphorToolSpec

# TODO: set metaphor API key
metaphor_tool = MetaphorToolSpec(
api_key=st.secrets.metaphor_key,
)
metaphor_tool_list = metaphor_tool.to_tool_list()

# TODO: LoadAndSearch doesn't work yet
# The search_and_retrieve_documents tool is the third in the tool list,
# as seen above
# wrapped_retrieve = LoadAndSearchToolSpec.from_defaults(
# metaphor_tool_list[2],
# )

# NOTE: requires openai right now
# We don't give the Agent our unwrapped retrieve document tools
# instead passing the wrapped tools
web_agent = OpenAIAgent.from_tools(
# [*wrapped_retrieve.to_tool_list(), metaphor_tool_list[4]],
metaphor_tool_list,
llm=BUILDER_LLM,
verbose=True,
)

# return agent as a tool
# TODO: tune description
web_agent_tool = QueryEngineTool.from_defaults(
web_agent,
name="web_agent",
description="""
This agent can answer questions by searching the web. \
Use this tool if the answer is ONLY likely to be found by searching \
the internet, especially for queries about recent events.
""",
)

return web_agent_tool


def get_tool_objects(tool_names: List[str]) -> List:
"""Get tool objects from tool names."""
# construct additional tools
tool_objs = []
for tool_name in tool_names:
if tool_name == "web_search":
# build web agent
tool_objs.append(get_web_agent_tool())
else:
raise ValueError(f"Tool {tool_name} not recognized.")

return tool_objs


class ParamCache(BaseModel):
"""Cache for RAG agent builder.

@@ -338,7 +421,7 @@ def save_to_disk(self, save_dir: str) -> None:
            "file_names": self.file_names,
            "urls": self.urls,
            # TODO: figure out tools
-           # "tools": [],
+           "tools": self.tools,
            "rag_params": self.rag_params.dict(),
            "agent_id": self.agent_id,
        }
@@ -376,11 +459,13 @@ def load_from_disk(
            file_names=cache_dict["file_names"], urls=cache_dict["urls"]
        )
        # load agent from index
        additional_tools = get_tool_objects(cache_dict["tools"])
        agent, _ = construct_agent(
            cache_dict["system_prompt"],
            cache_dict["rag_params"],
            cache_dict["docs"],
            vector_index=vector_index,
            additional_tools=additional_tools,
            # TODO: figure out tools
        )
        cache_dict["vector_index"] = vector_index
@@ -505,20 +590,14 @@ def load_data(
        self._cache.urls = urls
        return "Data loaded successfully."

    # NOTE: unused
    def add_web_tool(self) -> str:
        """Add a web tool to enable agent to solve a task."""
        # TODO: make this not hardcoded to a web tool
-       # Set up Metaphor tool
-       from llama_hub.tools.metaphor.base import MetaphorToolSpec
-
-       # TODO: set metaphor API key
-       metaphor_tool = MetaphorToolSpec(
-           api_key=os.environ["METAPHOR_API_KEY"],
-       )
-       metaphor_tool_list = metaphor_tool.to_tool_list()
-
-       self._cache.tools.extend(metaphor_tool_list)
+       if "web_search" in self._cache.tools:
+           return "Web tool already added."
+       else:
+           self._cache.tools.append("web_search")
        return "Web tool added successfully."

    def get_rag_params(self) -> Dict:
@@ -557,11 +636,13 @@ def create_agent(self, agent_id: Optional[str] = None) -> str:
        if self._cache.system_prompt is None:
            raise ValueError("Must set system prompt before creating agent.")

        # construct additional tools
        additional_tools = get_tool_objects(self.cache.tools)
        agent, extra_info = construct_agent(
            cast(str, self._cache.system_prompt),
            cast(RAGParams, self._cache.rag_params),
            self._cache.docs,
-           additional_tools=self._cache.tools,
+           additional_tools=additional_tools,
        )

        # if agent_id not specified, randomly generate one
@@ -587,6 +668,7 @@ def update_agent(
        chunk_size: Optional[int] = None,
        embed_model: Optional[str] = None,
        llm: Optional[str] = None,
        additional_tools: Optional[List] = None,
    ) -> None:
        """Update agent.

@@ -609,7 +691,6 @@
        # We call set_rag_params and create_agent, which will
        # update the cache
        # TODO: decouple functions from tool functions exposed to the agent
        rag_params_dict: Dict[str, Any] = {}
        if include_summarization is not None:
            rag_params_dict["include_summarization"] = include_summarization
@@ -623,6 +704,11 @@
            rag_params_dict["llm"] = llm

        self.set_rag_params(**rag_params_dict)

        # update tools
        if additional_tools is not None:
            self.cache.tools = additional_tools

        # this will update the agent in the cache
        self.create_agent()

@@ -655,6 +741,33 @@ def update_agent(
# please make sure to update the LLM above if you change the function below


def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
    """Get list of builder agent tools to pass to the builder agent."""
    # see if metaphor api key is set, otherwise don't add web tool
    # TODO: refactor this later

    if "metaphor_key" in st.secrets:
        fns: List[Callable] = [
            agent_builder.create_system_prompt,
            agent_builder.load_data,
            agent_builder.add_web_tool,
            agent_builder.get_rag_params,
            agent_builder.set_rag_params,
            agent_builder.create_agent,
        ]
    else:
        fns = [
            agent_builder.create_system_prompt,
            agent_builder.load_data,
            agent_builder.get_rag_params,
            agent_builder.set_rag_params,
            agent_builder.create_agent,
        ]

    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
    return fn_tools


# define agent
# @st.cache_resource
def load_meta_agent_and_tools(
@@ -664,15 +777,7 @@ def load_meta_agent_and_tools(
    # think of this as tools for the agent to use
    agent_builder = RAGAgentBuilder(cache)

-   fns: List[Callable] = [
-       agent_builder.create_system_prompt,
-       agent_builder.load_data,
-       # add_web_tool,
-       agent_builder.get_rag_params,
-       agent_builder.set_rag_params,
-       agent_builder.create_agent,
-   ]
-   fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns]
+   fn_tools = _get_builder_agent_tools(agent_builder)

    builder_agent = load_meta_agent(
        fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
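Taken together, the new pieces in this file form one flow: `add_web_tool` records the tool name `"web_search"` on the cache, and `create_agent` later resolves names into tool objects via `get_tool_objects`, which builds a Metaphor-backed web agent and wraps it as a `QueryEngineTool`. A rough sketch of that flow, assuming a `ParamCache` has already been constructed (its construction details are not shown in this diff):

```python
# Rough sketch of the web-search plumbing added above (cache construction assumed).
builder = RAGAgentBuilder(cache)

builder.add_web_tool()   # appends the tool *name* "web_search" to cache.tools,
                         # or returns "Web tool already added." if it is present

# create_agent() then resolves names into tool objects:
#   additional_tools = get_tool_objects(builder.cache.tools)
# where get_tool_objects(["web_search"]) returns [get_web_agent_tool()], i.e. a
# Metaphor-backed OpenAIAgent wrapped as a QueryEngineTool for construct_agent().
builder.create_agent()
```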
70 changes: 70 additions & 0 deletions callback_manager.py
@@ -0,0 +1,70 @@
"""Streaming callback manager."""
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType

from typing import Optional, Dict, Any, List, Callable

STORAGE_DIR = "./storage"  # directory to cache the generated index
DATA_DIR = "./data"  # directory containing the documents to index


class StreamlitFunctionsCallbackHandler(BaseCallbackHandler):
    """Callback handler that outputs streamlit components given events."""

    def __init__(self, msg_handler: Callable[[str], Any]) -> None:
        """Initialize the base callback handler."""
        self.msg_handler = msg_handler
        super().__init__([], [])

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return id of event."""
        if event_type == CBEventType.FUNCTION_CALL:
            if payload is None:
                raise ValueError("Payload cannot be None")
            arguments_str = payload["function_call"]
            tool_str = payload["tool"].name
            print_str = f"Calling function: {tool_str} with args: {arguments_str}\n\n"
            self.msg_handler(print_str)
        else:
            pass
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends."""
        pass
        # TODO: currently we don't need to do anything here
        # if event_type == CBEventType.FUNCTION_CALL:
        #     response = payload["function_call_response"]
        #     # Add this to queue
        #     print_str = (
        #         f"\n\nGot output: {response}\n"
        #         "========================\n\n"
        #     )
        # elif event_type == CBEventType.AGENT_STEP:
        #     # put response into queue
        #     self._queue.put(payload["response"])

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched."""
        pass

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited."""
        pass
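The handler only needs a `msg_handler` callable that accepts the formatted string emitted on each `FUNCTION_CALL` event. A condensed sketch of how `load_agent` in `agent_utils.py` wires it up (with `st.info` stubbed out by `print` for illustration):

```python
# Condensed sketch of the wiring done in load_agent() (see agent_utils.py above).
from llama_index.callbacks import CallbackManager

from callback_manager import StreamlitFunctionsCallbackHandler


def msg_handler(msg: str) -> None:
    """Receive the 'Calling function: ...' strings emitted on FUNCTION_CALL events."""
    print(msg)  # the app uses st.info(msg) plus session-state bookkeeping instead


handler = StreamlitFunctionsCallbackHandler(msg_handler)
callback_manager = CallbackManager([handler])
# callback_manager is then passed to OpenAIAgent.from_tools(..., callback_manager=...),
# so every tool call the agent makes is surfaced as a message in the Streamlit UI.
```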
10 changes: 10 additions & 0 deletions pages/2_⚙️_RAG_Config.py
@@ -24,6 +24,7 @@ def update_agent() -> None:
        "config_agent_builder" in st.session_state.keys()
        and st.session_state.config_agent_builder is not None
    ):
        additional_tools = st.session_state.additional_tools_st.split(",")
        agent_builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder)
        ### Update the agent
        agent_builder.update_agent(
@@ -34,6 +35,7 @@
            chunk_size=st.session_state.chunk_size_st,
            embed_model=st.session_state.embed_model_st,
            llm=st.session_state.llm_st,
            additional_tools=additional_tools,
        )

        # Update Radio Buttons: update selected agent to the new id
@@ -114,6 +116,14 @@ def delete_agent() -> None:
        value=rag_params.include_summarization,
        key="include_summarization_st",
    )

    # add web tool
    additional_tools_st = st.text_input(
        "Additional tools (currently only supports 'web_search')",
        value=",".join(agent_builder.cache.tools),
        key="additional_tools_st",
    )

    top_k_st = st.number_input("Top K", value=rag_params.top_k, key="top_k_st")
    chunk_size_st = st.number_input(
        "Chunk Size", value=rag_params.chunk_size, key="chunk_size_st"
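The new text input is parsed with a plain `split(",")`, so it expects a comma-separated list of tool names that `get_tool_objects` recognizes (currently only `web_search`). An illustrative round trip, with example values standing in for the session state:

```python
# Illustrative round trip for the "Additional tools" field (example values).
raw_value = "web_search"                 # what the user types into the text input
additional_tools = raw_value.split(",")  # -> ["web_search"]

# update_agent(additional_tools=...) stores the names on the cache and rebuilds
# the agent, which in turn calls get_tool_objects(cache.tools) as shown earlier:
#   agent_builder.update_agent(..., additional_tools=additional_tools)
```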