diff --git "a/1_\360\237\217\240_Home.py" "b/1_\360\237\217\240_Home.py" index 672413c..795f271 100644 --- "a/1_\360\237\217\240_Home.py" +++ "b/1_\360\237\217\240_Home.py" @@ -31,6 +31,8 @@ "To build a new agent, please make sure that 'Create a new agent' is selected.", icon="ℹ️", ) +if "metaphor_key" in st.secrets: + st.info("**NOTE**: The ability to add web search is enabled.") add_sidebar() diff --git a/agent_utils.py b/agent_utils.py index d65f0a8..c6ff234 100644 --- a/agent_utils.py +++ b/agent_utils.py @@ -32,6 +32,9 @@ from constants import AGENT_CACHE_DIR import shutil +from llama_index.callbacks import CallbackManager +from callback_manager import StreamlitFunctionsCallbackHandler + def _resolve_llm(llm_str: str) -> LLM: """Resolve LLM.""" @@ -153,9 +156,25 @@ def load_agent( """Load agent.""" extra_kwargs = extra_kwargs or {} if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): + # TODO: use default msg handler + # TODO: separate this from agent_utils.py... + def _msg_handler(msg: str) -> None: + """Message handler.""" + st.info(msg) + st.session_state.agent_messages.append( + {"role": "assistant", "content": msg, "msg_type": "info"} + ) + + # add streamlit callbacks (to inject events) + handler = StreamlitFunctionsCallbackHandler(_msg_handler) + callback_manager = CallbackManager([handler]) # get OpenAI Agent agent: BaseChatEngine = OpenAIAgent.from_tools( - tools=tools, llm=llm, system_prompt=system_prompt, **kwargs + tools=tools, + llm=llm, + system_prompt=system_prompt, + **kwargs, + callback_manager=callback_manager, ) else: if "vector_index" not in extra_kwargs: @@ -189,8 +208,12 @@ def load_meta_agent( extra_kwargs = extra_kwargs or {} if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): # get OpenAI Agent + agent: BaseAgent = OpenAIAgent.from_tools( - tools=tools, llm=llm, system_prompt=system_prompt, **kwargs + tools=tools, + llm=llm, + system_prompt=system_prompt, + **kwargs, ) else: agent = ReActAgent.from_tools( @@ -285,6 +308,66 @@ def construct_agent( return agent, extra_info +def get_web_agent_tool() -> QueryEngineTool: + """Get web agent tool. + + Wrap with our load and search tool spec. + + """ + from llama_hub.tools.metaphor.base import MetaphorToolSpec + + # TODO: set metaphor API key + metaphor_tool = MetaphorToolSpec( + api_key=st.secrets.metaphor_key, + ) + metaphor_tool_list = metaphor_tool.to_tool_list() + + # TODO: LoadAndSearch doesn't work yet + # The search_and_retrieve_documents tool is the third in the tool list, + # as seen above + # wrapped_retrieve = LoadAndSearchToolSpec.from_defaults( + # metaphor_tool_list[2], + # ) + + # NOTE: requires openai right now + # We don't give the Agent our unwrapped retrieve document tools + # instead passing the wrapped tools + web_agent = OpenAIAgent.from_tools( + # [*wrapped_retrieve.to_tool_list(), metaphor_tool_list[4]], + metaphor_tool_list, + llm=BUILDER_LLM, + verbose=True, + ) + + # return agent as a tool + # TODO: tune description + web_agent_tool = QueryEngineTool.from_defaults( + web_agent, + name="web_agent", + description=""" + This agent can answer questions by searching the web. \ +Use this tool if the answer is ONLY likely to be found by searching \ +the internet, especially for queries about recent events. 
+ """, + ) + + return web_agent_tool + + +def get_tool_objects(tool_names: List[str]) -> List: + """Get tool objects from tool names.""" + # construct additional tools + tool_objs = [] + for tool_name in tool_names: + if tool_name == "web_search": + # build web agent + tool_objs.append(get_web_agent_tool()) + else: + raise ValueError(f"Tool {tool_name} not recognized.") + + return tool_objs + + class ParamCache(BaseModel): """Cache for RAG agent builder. @@ -338,7 +421,7 @@ def save_to_disk(self, save_dir: str) -> None: "file_names": self.file_names, "urls": self.urls, # TODO: figure out tools - # "tools": [], + "tools": self.tools, "rag_params": self.rag_params.dict(), "agent_id": self.agent_id, } @@ -376,11 +459,13 @@ def load_from_disk( file_names=cache_dict["file_names"], urls=cache_dict["urls"] ) # load agent from index + additional_tools = get_tool_objects(cache_dict["tools"]) agent, _ = construct_agent( cache_dict["system_prompt"], cache_dict["rag_params"], cache_dict["docs"], vector_index=vector_index, + additional_tools=additional_tools, # TODO: figure out tools ) cache_dict["vector_index"] = vector_index @@ -505,20 +590,14 @@ def load_data( self._cache.urls = urls return "Data loaded successfully." - # NOTE: unused def add_web_tool(self) -> str: """Add a web tool to enable agent to solve a task.""" # TODO: make this not hardcoded to a web tool # Set up Metaphor tool - from llama_hub.tools.metaphor.base import MetaphorToolSpec - - # TODO: set metaphor API key - metaphor_tool = MetaphorToolSpec( - api_key=os.environ["METAPHOR_API_KEY"], - ) - metaphor_tool_list = metaphor_tool.to_tool_list() - - self._cache.tools.extend(metaphor_tool_list) + if "web_search" in self._cache.tools: + return "Web tool already added." + else: + self._cache.tools.append("web_search") return "Web tool added successfully." def get_rag_params(self) -> Dict: @@ -557,11 +636,13 @@ def create_agent(self, agent_id: Optional[str] = None) -> str: if self._cache.system_prompt is None: raise ValueError("Must set system prompt before creating agent.") + # construct additional tools + additional_tools = get_tool_objects(self.cache.tools) agent, extra_info = construct_agent( cast(str, self._cache.system_prompt), cast(RAGParams, self._cache.rag_params), self._cache.docs, - additional_tools=self._cache.tools, + additional_tools=additional_tools, ) # if agent_id not specified, randomly generate one @@ -587,6 +668,7 @@ def update_agent( chunk_size: Optional[int] = None, embed_model: Optional[str] = None, llm: Optional[str] = None, + additional_tools: Optional[List] = None, ) -> None: """Update agent. 
@@ -587,6 +668,7 @@ def update_agent(
         chunk_size: Optional[int] = None,
         embed_model: Optional[str] = None,
         llm: Optional[str] = None,
+        additional_tools: Optional[List] = None,
     ) -> None:
         """Update agent.
@@ -609,7 +691,6 @@ def update_agent(
         # We call set_rag_params and create_agent, which will
         # update the cache
         # TODO: decouple functions from tool functions exposed to the agent
-
         rag_params_dict: Dict[str, Any] = {}
         if include_summarization is not None:
             rag_params_dict["include_summarization"] = include_summarization
@@ -623,6 +704,11 @@ def update_agent(
             rag_params_dict["llm"] = llm
         self.set_rag_params(**rag_params_dict)
+
+        # update tools
+        if additional_tools is not None:
+            self.cache.tools = additional_tools
+
         # this will update the agent in the cache
         self.create_agent()
@@ -655,6 +741,33 @@ def update_agent(
 # please make sure to update the LLM above if you change the function below
 
 
+def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
+    """Get list of builder agent tools to pass to the builder agent."""
+    # see if metaphor api key is set, otherwise don't add web tool
+    # TODO: refactor this later
+
+    if "metaphor_key" in st.secrets:
+        fns: List[Callable] = [
+            agent_builder.create_system_prompt,
+            agent_builder.load_data,
+            agent_builder.add_web_tool,
+            agent_builder.get_rag_params,
+            agent_builder.set_rag_params,
+            agent_builder.create_agent,
+        ]
+    else:
+        fns = [
+            agent_builder.create_system_prompt,
+            agent_builder.load_data,
+            agent_builder.get_rag_params,
+            agent_builder.set_rag_params,
+            agent_builder.create_agent,
+        ]
+
+    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
+    return fn_tools
+
+
 # define agent
 # @st.cache_resource
 def load_meta_agent_and_tools(
@@ -664,15 +777,7 @@ def load_meta_agent_and_tools(
     # think of this as tools for the agent to use
     agent_builder = RAGAgentBuilder(cache)
 
-    fns: List[Callable] = [
-        agent_builder.create_system_prompt,
-        agent_builder.load_data,
-        # add_web_tool,
-        agent_builder.get_rag_params,
-        agent_builder.set_rag_params,
-        agent_builder.create_agent,
-    ]
-    fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns]
+    fn_tools = _get_builder_agent_tools(agent_builder)
 
     builder_agent = load_meta_agent(
         fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
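For reference, `FunctionTool.from_defaults` infers each builder tool's name and description from the wrapped function's signature and docstring, which is why the builder methods above all carry descriptive docstrings. A minimal sketch with a toy function (the function body is illustrative):

```python
from llama_index.tools import FunctionTool


def add_web_tool() -> str:
    """Add a web tool to enable agent to solve a task."""
    return "Web tool added successfully."


# The tool's name and description come from the function name and docstring.
tool = FunctionTool.from_defaults(fn=add_web_tool)
```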
diff --git a/callback_manager.py b/callback_manager.py
new file mode 100644
index 0000000..c8bf6da
--- /dev/null
+++ b/callback_manager.py
@@ -0,0 +1,70 @@
+"""Streaming callback manager."""
+from llama_index.callbacks.base_handler import BaseCallbackHandler
+from llama_index.callbacks.schema import CBEventType
+
+from typing import Optional, Dict, Any, List, Callable
+
+STORAGE_DIR = "./storage"  # directory to cache the generated index
+DATA_DIR = "./data"  # directory containing the documents to index
+
+
+class StreamlitFunctionsCallbackHandler(BaseCallbackHandler):
+    """Callback handler that outputs streamlit components given events."""
+
+    def __init__(self, msg_handler: Callable[[str], Any]) -> None:
+        """Initialize the base callback handler."""
+        self.msg_handler = msg_handler
+        super().__init__([], [])
+
+    def on_event_start(
+        self,
+        event_type: CBEventType,
+        payload: Optional[Dict[str, Any]] = None,
+        event_id: str = "",
+        parent_id: str = "",
+        **kwargs: Any,
+    ) -> str:
+        """Run when an event starts and return id of event."""
+        if event_type == CBEventType.FUNCTION_CALL:
+            if payload is None:
+                raise ValueError("Payload cannot be None")
+            arguments_str = payload["function_call"]
+            tool_str = payload["tool"].name
+            print_str = f"Calling function: {tool_str} with args: {arguments_str}\n\n"
+            self.msg_handler(print_str)
+        else:
+            pass
+        return event_id
+
+    def on_event_end(
+        self,
+        event_type: CBEventType,
+        payload: Optional[Dict[str, Any]] = None,
+        event_id: str = "",
+        **kwargs: Any,
+    ) -> None:
+        """Run when an event ends."""
+        pass
+        # TODO: currently we don't need to do anything here
+        # if event_type == CBEventType.FUNCTION_CALL:
+        #     response = payload["function_call_response"]
+        #     # Add this to queue
+        #     print_str = (
+        #         f"\n\nGot output: {response}\n"
+        #         "========================\n\n"
+        #     )
+        # elif event_type == CBEventType.AGENT_STEP:
+        #     # put response into queue
+        #     self._queue.put(payload["response"])
+
+    def start_trace(self, trace_id: Optional[str] = None) -> None:
+        """Run when an overall trace is launched."""
+        pass
+
+    def end_trace(
+        self,
+        trace_id: Optional[str] = None,
+        trace_map: Optional[Dict[str, List[str]]] = None,
+    ) -> None:
+        """Run when an overall trace is exited."""
+        pass
diff --git "a/pages/2_\342\232\231\357\270\217_RAG_Config.py" "b/pages/2_\342\232\231\357\270\217_RAG_Config.py"
index e2e8d70..58d992e 100644
--- "a/pages/2_\342\232\231\357\270\217_RAG_Config.py"
+++ "b/pages/2_\342\232\231\357\270\217_RAG_Config.py"
@@ -24,6 +24,7 @@ def update_agent() -> None:
         "config_agent_builder" in st.session_state.keys()
         and st.session_state.config_agent_builder is not None
     ):
+        additional_tools = st.session_state.additional_tools_st.split(",")
         agent_builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder)
         ### Update the agent
         agent_builder.update_agent(
@@ -34,6 +35,7 @@ def update_agent() -> None:
             chunk_size=st.session_state.chunk_size_st,
             embed_model=st.session_state.embed_model_st,
             llm=st.session_state.llm_st,
+            additional_tools=additional_tools,
         )
 
         # Update Radio Buttons: update selected agent to the new id
@@ -114,6 +116,14 @@ def delete_agent() -> None:
         value=rag_params.include_summarization,
         key="include_summarization_st",
     )
+
+    # add web tool
+    additional_tools_st = st.text_input(
+        "Additional tools (currently only supports 'web_search')",
+        value=",".join(agent_builder.cache.tools),
+        key="additional_tools_st",
+    )
+
     top_k_st = st.number_input("Top K", value=rag_params.top_k, key="top_k_st")
     chunk_size_st = st.number_input(
         "Chunk Size", value=rag_params.chunk_size, key="chunk_size_st"
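One caveat on the new text input: `update_agent` splits the raw string on commas, and `"".split(",")` returns `[""]` rather than `[]`, so clearing the field would feed an empty tool name into `get_tool_objects` and raise a `ValueError`. A hedged guard, not part of this diff:

```python
raw = st.session_state.additional_tools_st
# Drop empty fragments so an empty field or a trailing comma yields no tools.
additional_tools = [name.strip() for name in raw.split(",") if name.strip()]
```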
st.write(message["content"]) + + # display prior messages + display_messages() # don't process selected for now if prompt := st.chat_input( diff --git a/pyproject.toml b/pyproject.toml index 6870398..26b9b88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "rags" -version = "0.0.2" +version = "0.0.3" description = "Build RAG with natural language." authors = ["Jerry Liu"] # New attributes @@ -36,6 +36,8 @@ isort = "5.11.4" pytest-asyncio = "^0.21.1" ruff = "0.0.285" mypy = "0.991" +referencing = "0.30.2" +jsonschema-specifications = "2023.7.1" [build-system] requires = ["poetry>=0.12", "poetry-core>=1.0.0"]